<?xml version="1.0" encoding="utf-8" standalone="yes"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:content="http://purl.org/rss/1.0/modules/content/">
  <channel>
    <title>Zhou Xin's Personal Blog</title>
    <link>https://www.zhouxin.space/</link>
    <description>Recent content on Zhou Xin's Personal Blog</description>
    <generator>Hugo -- 0.153.3</generator>
    <language>zh-cn</language>
    <lastBuildDate>Mon, 02 Feb 2026 00:17:00 +0800</lastBuildDate>
    <atom:link href="https://www.zhouxin.space/index.xml" rel="self" type="application/rss+xml" />
    <item>
      <title>CS336 Study Notes, Lecture 8: Simple Implementations of Parallelism Strategies</title>
      <link>https://www.zhouxin.space/notes/notes-on-cs336-lecture-8-parallelism2/</link>
      <pubDate>Tue, 20 Jan 2026 13:01:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-cs336-lecture-8-parallelism2/</guid>
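      <description>&lt;blockquote&gt;
&lt;p&gt;TL;DR This lecture turns distributed-communication concepts into runnable code: it uses PyTorch and NCCL to show how collective communication is actually invoked, and hand-writes simplified versions of three parallelism strategies. From benchmarking All-Reduce to sharding an MLP, it exposes the core logic behind communication overhead and computation patterns.&lt;/p&gt;</description>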
      <content:encoded><![CDATA[<blockquote>
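<p>TL;DR This lecture turns distributed-communication concepts into runnable code: it uses PyTorch and NCCL to show how collective communication is actually invoked, and hand-writes simplified versions of three parallelism strategies. From benchmarking All-Reduce to sharding an MLP, it exposes the core logic behind communication overhead and computation patterns.</p>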
</blockquote>
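<h2 id="通信原语从接口到底层">Communication primitives: from the API down to the hardware</h2>
<p>PyTorch's interface for collective communication is short and direct. Once the process group is initialized, every process issues the same collective call, and NCCL handles the underlying data movement.</p>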
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="c1"># Initialize the process group (every process must run this)</span>
</span></span><span class="line"><span class="cl"><span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">&#34;MASTER_ADDR&#34;</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&#34;localhost&#34;</span>
</span></span><span class="line"><span class="cl"><span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">&#34;MASTER_PORT&#34;</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&#34;15623&#34;</span>
</span></span><span class="line"><span class="cl"><span class="n">dist</span><span class="o">.</span><span class="n">init_process_group</span><span class="p">(</span><span class="s2">&#34;nccl&#34;</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="n">rank</span><span class="p">,</span> <span class="n">world_size</span><span class="o">=</span><span class="n">world_size</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># All-Reduce example</span>
</span></span><span class="line"><span class="cl"><span class="n">tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">0.</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> <span class="o">+</span> <span class="n">rank</span>
</span></span><span class="line"><span class="cl"><span class="n">dist</span><span class="o">.</span><span class="n">all_reduce</span><span class="p">(</span><span class="n">tensor</span><span class="o">=</span><span class="n">tensor</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">dist</span><span class="o">.</span><span class="n">ReduceOp</span><span class="o">.</span><span class="n">SUM</span><span class="p">)</span>  <span class="c1"># modifies tensor in place</span>
</span></span><span class="line"><span class="cl"><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&#34;Rank </span><span class="si">{</span><span class="n">rank</span><span class="si">}</span><span class="s2">: </span><span class="si">{</span><span class="n">tensor</span><span class="si">}</span><span class="s2">&#34;</span><span class="p">)</span>  <span class="c1"># every process prints the same summed result</span>
</span></span></code></pre></td></tr></table>
</div>
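</div><p>Benchmarking reveals the bandwidth the hardware actually delivers. To measure the effective bandwidth of All-Reduce, compute the total amount of data transferred and the total elapsed time:</p>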
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="c1"># Core benchmarking logic</span>
</span></span><span class="line"><span class="cl"><span class="n">start_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="n">dist</span><span class="o">.</span><span class="n">all_reduce</span><span class="p">(</span><span class="n">tensor</span><span class="o">=</span><span class="n">tensor</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">dist</span><span class="o">.</span><span class="n">ReduceOp</span><span class="o">.</span><span class="n">SUM</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="n">end_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">duration</span> <span class="o">=</span> <span class="n">end_time</span> <span class="o">-</span> <span class="n">start_time</span>
</span></span><span class="line"><span class="cl"><span class="n">size_bytes</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">*</span> <span class="n">tensor</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="n">sent_bytes</span> <span class="o">=</span> <span class="n">size_bytes</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="p">(</span><span class="n">world_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>  <span class="c1"># 2x: counts both sending and receiving</span>
</span></span><span class="line"><span class="cl"><span class="n">bandwidth</span> <span class="o">=</span> <span class="n">sent_bytes</span> <span class="o">/</span> <span class="p">(</span><span class="n">world_size</span> <span class="o">*</span> <span class="n">duration</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&#34;Bandwidth: </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">bandwidth</span> <span class="o">/</span> <span class="mi">1024</span><span class="o">**</span><span class="mi">3</span><span class="p">)</span><span class="si">}</span><span class="s2"> GiB/s&#34;</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
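</div><h2 id="数据并行梯度同步的集体操作">Data parallelism: gradient synchronization as a collective operation</h2>
<p>The core of data parallelism is synchronizing gradients after the backward pass. Each process handles a slice of the data and computes local gradients, then an All-Reduce produces the globally averaged gradient.</p>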
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">data_parallelism_main</span><span class="p">(</span><span class="n">rank</span><span class="p">,</span> <span class="n">world_size</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">setup</span><span class="p">(</span><span class="n">rank</span><span class="p">,</span> <span class="n">world_size</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># Data sharding: each process takes a slice of the batch</span>
</span></span><span class="line"><span class="cl">    <span class="n">batch_size</span><span class="p">,</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span>  <span class="c1"># data is [batch_size, dim]</span>
</span></span><span class="line"><span class="cl">    <span class="n">local_batch_size</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">//</span> <span class="n">world_size</span>
</span></span><span class="line"><span class="cl">    <span class="n">local_data</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="n">rank</span><span class="o">*</span><span class="n">local_batch_size</span><span class="p">:(</span><span class="n">rank</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">local_batch_size</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># Full model replica</span>
</span></span><span class="line"><span class="cl">    <span class="n">params</span> <span class="o">=</span> <span class="p">[</span><span class="n">get_init_params</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">rank</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">    <span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">AdamW</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># Forward pass (on the local data)</span>
</span></span><span class="line"><span class="cl">    <span class="n">x</span> <span class="o">=</span> <span class="n">local_data</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">params</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">@</span> <span class="n">param</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">gelu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">loss</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">square</span><span class="p">()</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># Backward pass</span>
</span></span><span class="line"><span class="cl">    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># Key step: synchronize gradients</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">params</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">dist</span><span class="o">.</span><span class="n">all_reduce</span><span class="p">(</span><span class="n">tensor</span><span class="o">=</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">dist</span><span class="o">.</span><span class="n">ReduceOp</span><span class="o">.</span><span class="n">AVG</span><span class="p">,</span> <span class="n">async_op</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># Parameter update (every process applies the same update, so replicas stay identical)</span>
</span></span><span class="line"><span class="cl">    <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
</span></span></code></pre></td></tr></table>
</div>
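</div><h2 id="张量并行层内矩阵的切分与聚合">Tensor parallelism: splitting and gathering matrices within a layer</h2>
<p>Tensor parallelism splits each weight matrix by columns; each process computes a slice of the output, and an All-Gather then assembles the full result.</p>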
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">tensor_parallelism_main</span><span class="p">(</span><span class="n">rank</span><span class="p">,</span> <span class="n">world_size</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">setup</span><span class="p">(</span><span class="n">rank</span><span class="p">,</span> <span class="n">world_size</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">num_dim</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">local_num_dim</span> <span class="o">=</span> <span class="n">num_dim</span> <span class="o">//</span> <span class="n">world_size</span>  <span class="c1"># shard the feature dimension</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># Each process holds only a slice of the parameters</span>
</span></span><span class="line"><span class="cl">    <span class="n">params</span> <span class="o">=</span> <span class="p">[</span><span class="n">get_init_params</span><span class="p">(</span><span class="n">num_dim</span><span class="p">,</span> <span class="n">local_num_dim</span><span class="p">,</span> <span class="n">rank</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># Forward pass</span>
</span></span><span class="line"><span class="cl">    <span class="n">x</span> <span class="o">=</span> <span class="n">data</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># Local matrix multiply</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">@</span> <span class="n">params</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>  <span class="c1"># output shape: [batch_size, local_num_dim]</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">gelu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1"># Gather the partial results from every process</span>
</span></span><span class="line"><span class="cl">        <span class="n">activations</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">world_size</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">        <span class="n">dist</span><span class="o">.</span><span class="n">all_gather</span><span class="p">(</span><span class="n">tensor_list</span><span class="o">=</span><span class="n">activations</span><span class="p">,</span> <span class="n">tensor</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">async_op</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1"># Concatenate to recover the full features</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">activations</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># shape: [batch_size, num_dim]</span>
</span></span></code></pre></td></tr></table>
</div>
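</div><h2 id="流水线并行层间激活值的流动">Pipeline parallelism: activations flowing between layers</h2>
<p>Pipeline parallelism splits the model by depth and passes activations from one process to the next. Using micro-batches reduces pipeline bubbles.</p>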
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">pipeline_parallelism_main</span><span class="p">(</span><span class="n">rank</span><span class="p">,</span> <span class="n">world_size</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">,</span> <span class="n">num_micro_batches</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">setup</span><span class="p">(</span><span class="n">rank</span><span class="p">,</span> <span class="n">world_size</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># Assign layers to this rank</span>
</span></span><span class="line"><span class="cl">    <span class="n">local_num_layers</span> <span class="o">=</span> <span class="n">num_layers</span> <span class="o">//</span> <span class="n">world_size</span>
</span></span><span class="line"><span class="cl">    <span class="n">local_params</span> <span class="o">=</span> <span class="p">[</span><span class="n">get_init_params</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">rank</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">local_num_layers</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># Split into micro-batches</span>
</span></span><span class="line"><span class="cl">    <span class="n">micro_batch_size</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">//</span> <span class="n">num_micro_batches</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">micro_batches</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="n">num_micro_batches</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">micro_batches</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">micro_batch_size</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> 
</span></span><span class="line"><span class="cl">                        <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_micro_batches</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># Forward pipeline</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">micro_batch</span> <span class="ow">in</span> <span class="n">micro_batches</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># Receive from the previous rank (unless this is the first)</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">rank</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">dist</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">tensor</span><span class="o">=</span><span class="n">micro_batch</span><span class="p">,</span> <span class="n">src</span><span class="o">=</span><span class="n">rank</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1"># Local computation</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">local_params</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">micro_batch</span> <span class="o">=</span> <span class="n">micro_batch</span> <span class="o">@</span> <span class="n">param</span>
</span></span><span class="line"><span class="cl">            <span class="n">micro_batch</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">gelu</span><span class="p">(</span><span class="n">micro_batch</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1"># Send to the next rank (unless this is the last)</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">rank</span> <span class="o">&lt;</span> <span class="n">world_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">dist</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">tensor</span><span class="o">=</span><span class="n">micro_batch</span><span class="p">,</span> <span class="n">dst</span><span class="o">=</span><span class="n">rank</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
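</div><h2 id="总结代码揭示的模式">Summary: the patterns the code reveals</h2>
<p>At the code level, choosing a hybrid parallelism strategy comes down to composing communication primitives. All-Reduce for data parallelism, All-Gather for tensor parallelism, and Send/Recv for pipeline parallelism together form the communication backbone of distributed training.</p>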
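<p>To get a feel for how differently the three primitives load the network, the sketch below estimates per-rank traffic for one step of each strategy. It is a back-of-the-envelope calculation, not a measurement: it assumes ring-style collectives and fp32 values, and the sizes (<code>world_size</code>, <code>batch_size</code>, <code>dim</code>, <code>num_layers</code>) as well as the helper names are made up for illustration. The All-Reduce term uses the same <code>2 * (world_size - 1) / world_size</code> factor as the bandwidth benchmark.</p>

```python
# Rough per-rank communication volume for one step of each strategy.
# Assumptions (not from the lecture code): ring collectives, fp32 values,
# and the toy sizes chosen below.


def all_reduce_bytes(size_bytes: int, world_size: int) -> float:
    """Bytes each rank sends in a ring All-Reduce of a size_bytes buffer."""
    return 2 * (world_size - 1) / world_size * size_bytes


def all_gather_bytes(shard_bytes: int, world_size: int) -> int:
    """Bytes each rank receives when gathering one shard from every peer."""
    return (world_size - 1) * shard_bytes


def send_recv_bytes(activation_bytes: int) -> int:
    """Bytes crossing one pipeline boundary for one batch of activations."""
    return activation_bytes


world_size = 4
batch_size, dim, num_layers = 128, 1024, 4
bytes_per_value = 4  # fp32

# Data parallelism: one All-Reduce over all gradients per step.
param_bytes = num_layers * dim * dim * bytes_per_value
dp = all_reduce_bytes(param_bytes, world_size)

# Tensor parallelism: one All-Gather of activation shards per layer.
shard_bytes = batch_size * (dim // world_size) * bytes_per_value
tp = num_layers * all_gather_bytes(shard_bytes, world_size)

# Pipeline parallelism: activations handed to the next stage.
pp = send_recv_bytes(batch_size * dim * bytes_per_value)

print(f"data parallel:     {dp / 2**20:.1f} MiB per rank")
print(f"tensor parallel:   {tp / 2**20:.1f} MiB per rank")
print(f"pipeline parallel: {pp / 2**20:.1f} MiB per boundary")
```

<p>With these toy sizes the gradient All-Reduce dominates, but the ratio depends entirely on the chosen batch size and model width, so the numbers should be read as an illustration of the bookkeeping rather than a general ranking.</p>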
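<p>Real systems layer optimizations on top of these basic patterns: overlapping communication with computation, gradient checkpointing, asynchronous execution, and so on. Understanding the basic implementations is nonetheless the starting point for analyzing and optimizing distributed training performance.</p>
<p>Note: the code above is a simplified teaching implementation that omits engineering details such as error handling, device management, and performance tuning. Real projects should use PyTorch's DistributedDataParallel or FullyShardedDataParallel, or a third-party framework such as DeepSpeed.</p>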
]]></content:encoded>
    </item>
    <item>
      <title>CS336 Study Notes, Lecture 7: Parallelism Strategies</title>
      <link>https://www.zhouxin.space/notes/notes-on-cs336-lecture-7-parallelism/</link>
      <pubDate>Sat, 17 Jan 2026 17:23:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-cs336-lecture-7-parallelism/</guid>
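      <description>&lt;h2 id=&#34;分布式通信原语&#34;&gt;Distributed communication primitives&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Broadcast: the root process $P_r$ holds a message buffer $B$ of size $K$. When the operation completes, every process $P_i$ in the communication group holds an identical copy of $B$.&lt;/li&gt;
&lt;li&gt;Scatter: the root process $P_r$ holds a large vector $V$ and splits it evenly into $N$ chunks $v_0, v_1, \dots, v_{N-1}$. When the operation completes, $P_i$ holds only its own chunk $v_i$.&lt;/li&gt;
&lt;li&gt;Gather: the inverse of Scatter. Each process $P_i$ holds a chunk $v_i$. When the operation completes, the root process $P_r$ has concatenated the chunks in rank order, forming the full vector $V = [v_0, v_1, \dots, v_{N-1}]$ in its memory.&lt;/li&gt;
&lt;li&gt;All-Gather: each process $P_i$ holds a chunk $v_i$. The operation is equivalent to a Gather followed by a Broadcast. When it completes, every process $P_i$ holds the full vector $V = [v_0, v_1, \dots, v_{N-1}]$.&lt;/li&gt;
&lt;li&gt;Reduce: each process $P_i$ holds a vector $X_i$. When the operation completes, the root process $P_r$ holds the result $R = X_0 \oplus X_1 \oplus \dots \oplus X_{N-1}$, where $\oplus$ is the reduction operator (for example, sum).&lt;/li&gt;
&lt;li&gt;All-Reduce: logically equivalent to Reduce + Broadcast. When the operation completes, every process $P_i$ holds the same reduced result, shown here with sum as the reduction: $R = \sum_{j=0}^{N-1} X_j$.&lt;/li&gt;
&lt;li&gt;Reduce-Scatter: logically equivalent to first reducing the global data into a result $R$ and then scattering $R$ across the processes. When the operation completes, $P_i$ holds the $i$-th chunk of the result vector.&lt;br&gt;
&lt;img alt=&#34;Diagram of common collective communication primitives&#34; loading=&#34;lazy&#34; src=&#34;https://pics.zhouxin.space/20260118124019476.webp&#34;&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;llm-训练中不同的并行策略&#34;&gt;Parallelism strategies in LLM training&lt;/h2&gt;
&lt;h3 id=&#34;朴素数据并行-dp&#34;&gt;Naive data parallelism (DP)&lt;/h3&gt;
&lt;p&gt;In data parallelism every GPU holds a full replica of the model. The input batch is split into mini-batches that are fed to different GPUs; each GPU computes gradients as usual, an All-Reduce combines them into the global gradient, and then the parameters are updated.&lt;/p&gt;</description>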
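      <content:encoded><![CDATA[<h2 id="分布式通信原语">Distributed communication primitives</h2>
<ul>
<li>Broadcast: the root process $P_r$ holds a message buffer $B$ of size $K$. When the operation completes, every process $P_i$ in the communication group holds an identical copy of $B$.</li>
<li>Scatter: the root process $P_r$ holds a large vector $V$ and splits it evenly into $N$ chunks $v_0, v_1, \dots, v_{N-1}$. When the operation completes, $P_i$ holds only its own chunk $v_i$.</li>
<li>Gather: the inverse of Scatter. Each process $P_i$ holds a chunk $v_i$. When the operation completes, the root process $P_r$ has concatenated the chunks in rank order, forming the full vector $V = [v_0, v_1, \dots, v_{N-1}]$ in its memory.</li>
<li>All-Gather: each process $P_i$ holds a chunk $v_i$. The operation is equivalent to a Gather followed by a Broadcast. When it completes, every process $P_i$ holds the full vector $V = [v_0, v_1, \dots, v_{N-1}]$.</li>
<li>Reduce: each process $P_i$ holds a vector $X_i$. When the operation completes, the root process $P_r$ holds the result $R = X_0 \oplus X_1 \oplus \dots \oplus X_{N-1}$, where $\oplus$ is the reduction operator (for example, sum).</li>
<li>All-Reduce: logically equivalent to Reduce + Broadcast. When the operation completes, every process $P_i$ holds the same reduced result, shown here with sum as the reduction: $R = \sum_{j=0}^{N-1} X_j$.</li>
<li>Reduce-Scatter: logically equivalent to first reducing the global data into a result $R$ and then scattering $R$ across the processes. When the operation completes, $P_i$ holds the $i$-th chunk of the result vector.<br>
<img alt="Diagram of common collective communication primitives" loading="lazy" src="https://pics.zhouxin.space/20260118124019476.webp"></li>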
</ul>
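<p>The semantics of these collectives can be mimicked without any GPUs. The following is a minimal pure-Python sketch in which each "rank" is just one entry of a list and no real communication takes place; the function names are illustrative and are not the <code>torch.distributed</code> API. It assumes the vector length is divisible by the number of ranks, and it checks that an All-Reduce gives the same result as a Reduce-Scatter followed by an All-Gather.</p>

```python
# Pure-Python stand-ins for the collectives above. Each "rank" is one list
# entry; no real communication happens. Function names are illustrative.


def reduce_scatter(inputs: list[list[float]]) -> list[list[float]]:
    """Rank i ends up with the i-th chunk of the element-wise sum."""
    n = len(inputs)
    total = [sum(values) for values in zip(*inputs)]
    chunk = len(total) // n
    return [total[i * chunk:(i + 1) * chunk] for i in range(n)]


def all_gather(shards: list[list[float]]) -> list[list[float]]:
    """Every rank ends up with all shards concatenated in rank order."""
    full = [value for shard in shards for value in shard]
    return [list(full) for _ in shards]


def all_reduce(inputs: list[list[float]]) -> list[list[float]]:
    """Every rank ends up with the element-wise sum of all inputs."""
    total = [sum(values) for values in zip(*inputs)]
    return [list(total) for _ in inputs]


# Four ranks; rank r holds [0, 1, 2, 3] + r.
world_size = 4
inputs = [[float(v + rank) for v in range(4)] for rank in range(world_size)]

direct = all_reduce(inputs)
composed = all_gather(reduce_scatter(inputs))

print(direct[0])           # the summed vector held by rank 0
print(direct == composed)  # whether both routes agree on every rank
```

<p>Real implementations differ in how the data moves (for example ring or tree algorithms), but the end state on each rank is what the definitions above describe.</p>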
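<h2 id="llm-训练中不同的并行策略">Parallelism strategies in LLM training</h2>
<h3 id="朴素数据并行-dp">Naive data parallelism (DP)</h3>
<p>In data parallelism every GPU holds a full replica of the model. The input batch is split into mini-batches that are fed to different GPUs; each GPU computes gradients as usual, an All-Reduce combines them into the global gradient, and then the parameters are updated.</p>
<ul>
<li>Compute scaling: each GPU processes B/M samples, where B is the batch size and M the number of GPUs.</li>
<li>Communication cost: <code>2x#params</code>, from the All-Reduce.</li>
<li>Memory scaling: none. Every GPU must store the full model.</li>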
</ul>
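<p>That an All-Reduce over per-GPU gradients recovers the global gradient can be checked on a toy problem. The sketch below is an assumption-laden illustration: the scalar model <code>w * x</code>, the mean-squared-error loss, and the data are all invented, and the "All-Reduce" is just a Python average. The equality relies on every GPU receiving the same number of samples.</p>

```python
# Toy check: averaging per-shard gradients reproduces the full-batch gradient.
# The scalar model, loss, and data are assumptions made for illustration.


def grad_mse(w: float, xs: list[float], ys: list[float]) -> float:
    """d/dw of mean((w * x - y) ** 2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]

num_gpus = 4                 # M
local = len(xs) // num_gpus  # B / M samples per GPU

# Each "GPU" computes the gradient on its own slice of the batch.
local_grads = [
    grad_mse(w, xs[r * local:(r + 1) * local], ys[r * local:(r + 1) * local])
    for r in range(num_gpus)
]

# All-Reduce with averaging: every GPU ends up holding this same value.
global_grad = sum(local_grads) / num_gpus

# Reference: the gradient computed on the whole batch at once.
full_grad = grad_mse(w, xs, ys)
print(global_grad, full_grad)
```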
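<h3 id="zero">ZeRO</h3>
<p>As the figure below shows, a model with $\Phi$ parameters trained with AMP (automatic mixed precision) uses roughly the following GPU memory:</p>
<ul>
<li>Parameters: 2$\Phi$ (the 2 is bytes per parameter; BF16 takes two bytes)</li>
<li>Gradients: 2$\Phi$</li>
<li>Optimizer state: 12$\Phi$ (FP32 parameters, FP32 momentum, FP32 variance)<br>
<img alt="Breakdown of model-state memory" loading="lazy" src="https://pics.zhouxin.space/20260118141653335.webp"></li>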
</ul>
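<p>In naive DP every GPU stores all three of these in full. ZeRO, proposed by Microsoft's DeepSpeed team, removes the redundancy in these three parts step by step.</p>
<ul>
<li>Stage 1: shard the optimizer state<br>
In Stage 1 the optimizer state is sharded across GPUs. Each GPU updates only the slice of parameters it owns, and an All-Gather then gives every GPU a full copy of the updated parameters.</li>
</ul>
<p>Concretely, the steps are:</p>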
<ol>
<li>
<p>Each GPU computes gradients on its own mini-batch;</p>
</li>
<li>
<p>Reduce-Scatter the gradients, which costs <code>#params</code> of communication;<br>
<img alt="Reduce Scatter diagram" loading="lazy" src="https://pics.zhouxin.space/20260118143442592.webp"></p>
</li>
<li>
<p>Each GPU updates its slice of the parameters using the optimizer state it maintains;</p>
</li>
<li>
<p>All-Gather the parameters, which costs another <code>#params</code> of communication.<br>
<img alt="All Gather diagram" loading="lazy" src="https://pics.zhouxin.space/20260118143513214.webp"></p>
</li>
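<p>Compared with naive DP, Stage 1 simply replaces the single All-Reduce with a Reduce-Scatter plus an All-Gather and inserts the parameter update between the two. As noted earlier, an All-Reduce is equivalent to a Reduce-Scatter followed by an All-Gather, so we obtain a significant memory saving <strong>without any extra communication</strong>.</p>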
<p><img alt="DDP v.s. ZeRO Stage 1" loading="lazy" src="https://pics.zhouxin.space/20260118143929942.webp"></p>
<ul>
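<li>Stage 2: shard gradients and optimizer state<br>
Stage 2 builds on Stage 1 by also sharding the gradients: each GPU keeps only a small buffer for the gradient slice it is responsible for. In Stage 1 we compute all gradients first and then Reduce-Scatter them once. To shrink the gradient footprint, Stage 2 must instead Reduce-Scatter each gradient as soon as it is computed, sending it to the GPU that owns that slice, and then discard the gradients a GPU does not own once the computation graph no longer needs them.</li>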
</ul>
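<p>Concretely, the steps are:</p>
<ol>
<li>As soon as one layer's gradients are computed, Reduce-Scatter them to the GPU that owns them;</li>
<li>Free a gradient buffer as soon as the computation graph no longer needs it;</li>
<li>Each GPU updates its own slice of the parameters;</li>
<li>All-Gather the parameters.</li>
</ol>
<p>The communication cost is the same as in Stage 1. The single Reduce-Scatter over all gradients in Stage 1 is split into many smaller ones in Stage 2, but they cover disjoint slices, so the total volume is unchanged.</p>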
<ul>
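<li>Stage 3 (FSDP): shard gradients + optimizer state + parameters<br>
In Stage 3 the parameters are sharded as well, so the forward pass has to fetch the full parameters through communication. As the figure below shows, an All-Gather collects the full weights before each forward and backward computation.<br>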
<img alt="Stage 3 简单示意图" loading="lazy" src="https://pics.zhouxin.space/20260118174636468.webp"></li>
</ul>
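<p>An All-Gather before every computation is a lot of communication, yet it is not unacceptably slow, because Stage 3 overlaps communication with computation. As the figure below shows, apart from the first forward computation FWD0, the communication needed by each forward step can overlap with the previous step's computation; if computing one layer takes longer than fetching the next layer's weights, the communication latency is hidden completely.<br>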
<img alt="PixPin_2026-01-18_17-56-09.webp" loading="lazy" src="https://pics.zhouxin.space/20260118175611054.webp"></p>
<p>Stage 3 communicates <code>3 x #params</code> in total: one All-Gather each for the forward and backward passes plus the Reduce-Scatter of the gradients. That is 1.5x the volume of Stage 2.</p>
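<p>The memory and communication figures for the ZeRO stages can be put side by side with a small calculator. This is only a sketch of the accounting described above (2 + 2 + 12 bytes per parameter); the function name and the 7.5B-parameter, 64-GPU setting are illustrative assumptions, and activations are ignored.</p>

```python
def zero_memory_bytes(phi, n_gpus, stage):
    """Approximate per-GPU bytes for parameters + gradients + optimizer state."""
    params, grads, opt = 2 * phi, 2 * phi, 12 * phi
    if stage >= 1:
        opt /= n_gpus      # Stage 1 shards the optimizer state
    if stage >= 2:
        grads /= n_gpus    # Stage 2 also shards the gradients
    if stage >= 3:
        params /= n_gpus   # Stage 3 also shards the parameters
    return params + grads + opt

# Communication volume per step in units of #params, as stated in the text
# (stage 0 stands for naive DP).
COMM_VOLUME = {0: 2, 1: 2, 2: 2, 3: 3}

phi, n_gpus = 7.5e9, 64    # hypothetical model size and GPU count
for stage in range(4):
    gb = zero_memory_bytes(phi, n_gpus, stage) / 1e9
    print(f"stage {stage}: {gb:.1f} GB per GPU, comm = {COMM_VOLUME[stage]} x #params")
```

<p>Under these assumptions, Stages 1 and 2 reduce memory at no extra communication cost, while Stage 3 trades 1.5x communication for memory that shrinks linearly with the number of GPUs.</p>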
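<h3 id="dp-的限制">Limits of DP</h3>
<p>Data parallelism is not a cure-all. It has the following limits:</p>
<ol>
<li>The DP degree cannot exceed the batch size. A larger batch is not always better: once the batch size is too large, the number of training steps no longer drops proportionally, so the model does not learn any faster. Because every DP rank needs at least one sample, this caps how far DP can scale.<br>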
<img alt="训练速度与 Batch Size 的关系" loading="lazy" src="https://pics.zhouxin.space/20260118181104365.webp"></li>
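<li>Each ZeRO stage has trade-offs. Stages 1 and 2 add almost no overhead but do not shard the model parameters; Stage 3 adds communication overhead and still does not reduce activation memory.</li>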
</ol>
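<h3 id="模型并行-data-parallelism">Model parallelism</h3>
<p>ZeRO Stage 3 shards the parameters, which addresses the memory-scaling weakness of data parallelism. Model parallelism also splits the parameters across devices, but the devices exchange activations rather than parameters. It comes in two flavors: pipeline parallelism and tensor parallelism.</p>
<h3 id="流水线并行-pipeline-parallelism">Pipeline parallelism</h3>
<p>Pipeline parallelism places different layers on different devices, so data flows through the GPUs like an assembly line. The idea is intuitive but hard to implement well. As the figure below shows, if a device does nothing while it waits for the others, each device is busy only 1/n of the time, which is extremely inefficient.<br>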
<img alt="低效的流水线并行" loading="lazy" src="https://pics.zhouxin.space/20260118183529571.webp"></p>
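<p>The fix is to split a batch into several micro-batches and feed them into the pipeline one after another. In the figure below the data is split into four micro-batches; some bubbles remain, but overall efficiency is better than in the naive version. The bubble ratio (idle time relative to useful compute time) is $(n_\text{stages}-1)/n_\text{micro}$, so more micro-batches mean a smaller bubble. As with data parallelism, however, the number of micro-batches is bounded by the batch size.<br>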
<img alt="流水线并行示意图" loading="lazy" src="https://pics.zhouxin.space/20260118184022732.webp"></p>
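<p>A quick way to get a feel for the bubble formula is to evaluate it. The sketch below assumes every micro-batch takes the same time on every stage; the helper names are invented for this example. <code>bubble_ratio</code> is the expression from the text, and <code>idle_fraction</code> expresses the same bubble as a share of total wall-clock time.</p>

```python
def bubble_ratio(n_stages, n_micro):
    # Idle time relative to useful compute time: (n_stages - 1) / n_micro.
    return (n_stages - 1) / n_micro

def idle_fraction(n_stages, n_micro):
    # Idle time relative to total time of one pipeline flush.
    return (n_stages - 1) / (n_micro + n_stages - 1)

for n_micro in (1, 4, 16, 64):
    print(n_micro, bubble_ratio(4, n_micro), round(idle_fraction(4, n_micro), 3))
```

<p>With a single micro-batch on four stages the idle fraction is 3/4, which matches the "busy only 1/n of the time" picture of the naive pipeline.</p>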
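<p>If PP is not that great, why use it?</p>
<ol>
<li>Compared with DP, PP saves memory: each device stores only the parameters of its own layers, and therefore only the activations of those layers.</li>
<li>PP needs only point-to-point communication, which suits machines connected by slow links.</li>
</ol>
<p>The lecture closes this part with Zero Bubble Pipeline. The backward pass can be split into two steps: B, the gradient with respect to the input, and W, the gradient with respect to the weights. B is the output gradient that the previous layer depends on, so it must be computed as early as possible; W is not needed by any other device and only has to finish before the parameter update. The core idea of Zero Bubble Pipeline is to prioritize B and fill idle pipeline slots with W. This is only a high-level description; the lecturer stressed repeatedly that both Zero Bubble Pipeline and plain PP are hard to implement in practice.</p>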
<p><img alt="PixPin_2026-01-18_18-53-29.webp" loading="lazy" src="https://pics.zhouxin.space/20260118185332777.webp"></p>
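<h3 id="张量并行-tensor-parallelism">Tensor parallelism</h3>
<p>If pipeline parallelism cuts the model along the vertical axis, tensor parallelism cuts along the horizontal axis: one large matrix multiplication is split across two devices. As the figure below shows, <code>X@A = X1@A1 + X2@A2 = Y</code>.<br>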
<img alt="矩乘切分" loading="lazy" src="https://pics.zhouxin.space/20260119093934738.webp"><br>
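The next figure shows the standard split for an MLP: the weight of the first matmul is split by columns and that of the second by rows.</p>
<p>Forward pass: the input X is copied to every device by the identity operator f; GPU 1 computes XA1 and GPU 2 computes XA2. Because the activation function is element-wise, no communication is needed to apply it. GPU 1 now holds the column slice Y1 and GPU 2 holds Y2, which line up exactly with the row slices of B, so the second matmul can proceed without an All-Gather. Finally the All-Reduce operator g sums the partial outputs into the full Z.</p>
<p>The backward pass mirrors the forward pass, except that g becomes the identity and f becomes an All-Reduce.<br>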
<img alt="MLP 网络中 TP 切分方式" loading="lazy" src="https://pics.zhouxin.space/20260119094422341.webp"></p>
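<p>The column-then-row split can be checked numerically. The sketch below is an assumption-laden toy: two "devices" are simulated in one process, ReLU stands in for the element-wise activation, and the matrices are small hand-picked integers. It verifies that summing the two partial outputs (the job of the All-Reduce g) reproduces the unsplit MLP.</p>

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def relu(m):
    return [[max(0, v) for v in row] for row in m]

def add(m1, m2):
    return [[u + v for u, v in zip(r1, r2)] for r1, r2 in zip(m1, m2)]

X = [[1, -2], [3, 4]]                       # input, replicated on both devices (operator f)
A = [[1, 0, -1, 2], [2, 1, 0, -1]]          # first weight, 2 x 4
B = [[1, 2], [0, 1], [-1, 0], [2, 1]]       # second weight, 4 x 2

Z_ref = matmul(relu(matmul(X, A)), B)       # unsplit reference

A1, A2 = [r[:2] for r in A], [r[2:] for r in A]   # column split of A
B1, B2 = B[:2], B[2:]                             # row split of B
Z1 = matmul(relu(matmul(X, A1)), B1)        # computed on "GPU 1"
Z2 = matmul(relu(matmul(X, A2)), B2)        # computed on "GPU 2"
Z = add(Z1, Z2)                             # All-Reduce (operator g)
assert Z == Z_ref
```

<p>No communication is needed between the two matmuls because the activation acts on each column slice independently.</p>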
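<p>TP communicates heavily in every layer, so it demands fast interconnects. As the figure below shows, per-GPU throughput drops sharply once the TP degree exceeds 8. GPUs within one node communicate quickly, so TP is usually confined to the 8 GPUs of a single node.</p>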
<p><img alt="单卡吞吐量随 TP 数变化" loading="lazy" src="https://pics.zhouxin.space/20260119100413858.webp"></p>
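<h3 id="序列并行-sequence-parallelism">Sequence parallelism</h3>
<p>The TP and PP schemes covered so far are designed around splitting model parameters; TP in particular leaves part of each layer's activations replicated on every device. During training, the activation memory of a model can be estimated with the following formula:</p>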


<div>$$

\text{Activations memory per layer} = sbh \left(34 &#43; 5 \frac{as}{h}\right)

$$</div>

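<p>The variables are:</p>
<ul>
<li>$s$: sequence length</li>
<li>$b$: micro-batch size</li>
<li>$h$: hidden dimension size</li>
<li>$a$: number of attention heads</li>
</ul>
<p>The activation memory has two parts. The first is the linear term, which comes from the MLP, the linear projections in attention, and LayerNorm. The second is the quadratic term $5bas^2$, which comes from the attention mechanism itself (attention scores, Softmax outputs and so on) and can be eliminated by the recomputation in Flash Attention.</p>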
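<p>With TP, the part of the linear term with coefficient 24 is divided by $t$, the TP degree, and so is the attention term, because the attention heads are split across devices. The remaining coefficient 10 comes from the non-matmul parts: LayerNorm (4), Dropout (2), and the inputs to attention and the MLP (4).</p>


<div>$$

\text{Activations memory per layer} = sbh \left(10&#43;\frac{24}{t} &#43; 5 \frac{as}{ht}\right)

$$</div>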

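<p>Take LayerNorm as an example: it normalizes each token position independently, so it can be split along the sequence dimension to reduce memory. In the figure below, LayerNorm and Dropout run sequence-parallel. In the forward pass, the LayerNorm output goes through an All-Gather $g$ so that the full tensor enters attention, and the input to Dropout goes through a Reduce-Scatter $\bar{g}$ that splits the full tensor along the sequence. In the backward pass, $g$ and $\bar{g}$ swap roles.<br>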
<img alt="序列并行示意图" loading="lazy" src="https://pics.zhouxin.space/20260119104646827.webp"></p>
<p>With TP + SP + Flash Attention (selective recomputation) combined, the per-layer activation memory drops to $sbh \left(\frac{34}{t}\right)$.</p>
<table>
  <thead>
      <tr>
          <th><strong>Configuration</strong></th>
          <th><strong>Activations Memory Per Transformer Layer</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>no parallelism</td>
          <td>$sbh \left(34 + 5\frac{as}{h}\right)$</td>
      </tr>
      <tr>
          <td>tensor parallel (baseline)</td>
          <td>$sbh \left(10 + \frac{24}{t} + 5\frac{as}{ht}\right)$</td>
      </tr>
      <tr>
          <td>tensor + sequence parallel</td>
          <td>$sbh \left(\frac{34}{t} + 5\frac{as}{ht}\right)$</td>
      </tr>
      <tr>
          <td>tensor parallel + selective activation recomputation</td>
          <td>$sbh \left(10 + \frac{24}{t}\right)$</td>
      </tr>
      <tr>
          <td>tensor parallel + sequence parallel + selective activation recomputation</td>
          <td>$sbh \left(\frac{34}{t}\right)$</td>
      </tr>
  </tbody>
</table>
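<p>The rows of the table can be folded into one function. The sketch below simply transcribes the formulas above; the function name and the GPT-3-like shape used in the example (s = 2048, h = 12288, a = 96, t = 8) are assumptions chosen for illustration.</p>

```python
def activation_bytes_per_layer(s, b, h, a, t=1,
                               sequence_parallel=False,
                               selective_recompute=False):
    # Linear term: SP spreads the full 34 over t devices; plain TP only the 24.
    linear = 34 / t if sequence_parallel else 10 + 24 / t
    # Quadratic attention term: removed by selective recomputation.
    quadratic = 0.0 if selective_recompute else 5 * a * s / (h * t)
    return s * b * h * (linear + quadratic)

s, b, h, a, t = 2048, 1, 12288, 96, 8        # hypothetical configuration
base = activation_bytes_per_layer(s, b, h, a)
best = activation_bytes_per_layer(s, b, h, a, t,
                                  sequence_parallel=True,
                                  selective_recompute=True)
print(f"no parallelism: {base / 2**30:.2f} GiB, TP+SP+recompute: {best / 2**30:.2f} GiB")
```

<p>With <code>t=1</code> and both flags off, the function reduces to the "no parallelism" row, which is a convenient sanity check.</p>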
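<h3 id="ring-attention-和-专家并行">Ring Attention and expert parallelism</h3>
<p>The lecture ends with a brief look at Ring Attention and expert parallelism.</p>
<p>Ring Attention targets very long sequences whose KV Cache does not fit on a single device. The core idea is to shard K and V, pass the shards $KV_i$ around the GPUs in a ring, and compute the corresponding partial results at each step.</p>
<p>Expert parallelism spreads the experts across GPUs and faces challenges from load balancing and from overlapping computation with communication.</p>
<h2 id="使用并行策略扩大-lm-的规模并训练">Scaling up and training LMs with parallelism</h2>
<h3 id="3d-并行">3D parallelism</h3>
<p>3D parallelism combines DP, PP and TP. A simple rule of thumb:</p>
<ul>
<li>First split the model until it fits on a single device
<ul>
<li>Use TP, with a TP degree no larger than the number of GPUs per node</li>
<li>Across nodes use PP or ZeRO Stage 3, depending on the available bandwidth</li>
</ul>
</li>
<li>Then scale out with DP until all GPUs are used</li>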
</ul>
]]></content:encoded>
    </item>
    <item>
      <title>CS336 Study Notes, Lecture 6: Kernels and Triton</title>
      <link>https://www.zhouxin.space/notes/notes-on-cs336-lecture-6-kernels-and-triton/</link>
      <pubDate>Tue, 13 Jan 2026 12:53:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-cs336-lecture-6-kernels-and-triton/</guid>
      <description>&lt;hr&gt;
&lt;blockquote&gt;
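&lt;p&gt;TL;DR This is Lecture 6 of the CS336 notes series. Starting from why kernel fusion is needed, it compares three implementation routes: hand-written CUDA, Triton, and PyTorch 2.0 compilation. It focuses on how Triton's block-level abstraction simplifies GPU memory management and uses GeLU, Softmax and Matmul to show how shared memory (SRAM) and tiling break through the memory wall.&lt;/p&gt;</description>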
      <content:encoded><![CDATA[<hr>
<blockquote>
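<p>TL;DR This is Lecture 6 of the CS336 notes series. Starting from why kernel fusion is needed, it compares three implementation routes: hand-written CUDA, Triton, and PyTorch 2.0 compilation. It focuses on how Triton's block-level abstraction simplifies GPU memory management and uses GeLU, Softmax and Matmul to show how shared memory (SRAM) and tiling break through the memory wall.</p>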
</blockquote>
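<h2 id="算子融合-kernel-fusion">Kernel fusion</h2>
<h3 id="动机warehouse-vs-factory">Motivation: warehouse vs factory</h3>
<p>The lecturer uses a classic analogy:</p>
<ul>
<li>DRAM (GPU global memory) is a huge warehouse: large capacity but slow access.</li>
<li>SRAM (shared memory / registers) is the factory floor: small capacity but fast.</li>
</ul>
<p>Without kernel fusion, running a deep learning model shuttles data back and forth between the warehouse and the factory:</p>
<ol>
<li>Read $x$ from DRAM -&gt; multiply -&gt; write back to DRAM</li>
<li>Read from DRAM -&gt; add -&gt; write back to DRAM</li>
<li>&hellip;</li>
</ol>
<p>In this pattern memory bandwidth becomes the bottleneck. The core idea of kernel fusion is to move the data into the factory once, perform the whole chain of computations (multiply, add, tanh and so on), and write the result back to the warehouse in one go.</p>
<h3 id="实例gelu">Example: GeLU</h3>
<p>Take the GeLU activation: its formula involves multiplications, additions and a tanh.</p>
<ul>
<li>Manual implementation: composed from basic PyTorch operators (<code>x * 0.5 * ...</code>); every operation triggers a kernel launch and a round trip to global memory.</li>
<li>Fused implementation: PyTorch's built-in <code>F.gelu</code> is fused and triggers a single kernel launch.</li>
</ul>
<p>Benchmarks show the fused version is much faster than the manual one (about 7-8x in the lecture's code), and the profile of the manual version is cluttered with many small kernel calls.</p>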
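<h2 id="cuda-kernels打开黑盒">CUDA kernels: opening the black box</h2>
<p>To chase maximum performance we can write CUDA C++ directly.</p>
<h3 id="执行模型">Execution model</h3>
<p>CUDA's execution hierarchy mirrors the hardware:</p>
<ul>
<li>Grid: the whole computation, made up of many thread blocks.</li>
<li>Thread block: scheduled onto a single SM; threads in a block can share Shared Memory and synchronize.</li>
<li>Thread: the smallest execution unit, handling a single data element.</li>
</ul>
<h3 id="代码与限制">Code and limitations</h3>
<p><code>torch.utils.cpp_extension.load_inline</code> makes it easy to inline CUDA code from Python. Manually managing thread indices (<code>blockIdx</code>, <code>threadIdx</code>) and memory can pay off in performance, but the barrier to entry is high:</p>
<ul>
<li>Coalesced memory access has to be arranged by hand.</li>
<li>Data movement through Shared Memory has to be managed by hand.</li>
<li>The code is verbose and error-prone (off-by-one errors).</li>
</ul>
<h2 id="tritonpython-时代的-gpu-编程">Triton: GPU programming in the Python era</h2>
<p>Triton, released by OpenAI in 2021, aims to lower the barrier to GPU programming. It introduces a block-level programming abstraction, so developers think in terms of blocks of data rather than individual threads.</p>
<h3 id="triton-vs-cuda">Triton vs CUDA</h3>
<p>The Triton compiler automatically handles many pain points that need manual optimization in CUDA:</p>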
<table>
  <thead>
      <tr>
          <th>Feature</th>
          <th>CUDA</th>
          <th>Triton</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Memory coalescing</td>
          <td>Manual</td>
          <td>Automatic</td>
      </tr>
      <tr>
          <td>Shared memory management</td>
          <td>Manual</td>
          <td>Automatic</td>
      </tr>
      <tr>
          <td>Scheduling within an SM</td>
          <td>Manual</td>
          <td>Automatic</td>
      </tr>
      <tr>
          <td>Scheduling across SMs</td>
          <td>Manual</td>
          <td>Manual</td>
      </tr>
  </tbody>
</table>
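<h3 id="实现-gelu">Implementing GeLU</h3>
<p>In Triton, <code>tl.program_id</code> gives the ID of the current block, from which we compute the pointer offsets of the data that this block handles. The computation is fully vectorized, and the code reads much like plain Python.</p>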
<div class="highlight"><div class="chroma">
<table class="lntable"><tr>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-Python" data-lang="Python"><span class="line"><span class="cl"><span class="nd">@triton.jit</span>
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">triton_gelu_kernel</span><span class="p">(</span><span class="n">x_ptr</span><span class="p">,</span> <span class="n">y_ptr</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># 1. Compute the range of data handled by this block</span>
</span></span><span class="line"><span class="cl">    <span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">offsets</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o">&lt;</span> <span class="n">n_elements</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># 2. Load the data (Triton handles coalescing automatically)</span>
</span></span><span class="line"><span class="cl">    <span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># 3. Compute (compiled to efficient PTX instructions)</span>
</span></span><span class="line"><span class="cl">    <span class="n">y</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">x</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="o">...</span> <span class="p">)</span> 
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1"># 4. Write back</span>
</span></span><span class="line"><span class="cl">    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">y_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
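</div><p>Inspecting the generated PTX shows that the Triton compiler applies thread coarsening automatically: one thread may process several elements, which increases instruction-level parallelism.</p>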
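<p>The indexing pattern of the kernel (program ID, offsets, mask) can be emulated on the CPU with plain Python. This is not Triton code and says nothing about performance; it is a sketch meant only to show how the mask protects the last, partially filled block. The tanh approximation of GeLU used here is an assumption, since the kernel above elides the formula, and the function names are invented for this example.</p>

```python
import math

def gelu_tanh(x):
    # Assumed tanh approximation of GeLU.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def gelu_blockwise(xs, block_size):
    n = len(xs)
    ys = [0.0] * n
    num_programs = (n + block_size - 1) // block_size   # ceiling division over blocks
    for pid in range(num_programs):                     # one iteration per program ID
        offsets = [pid * block_size + i for i in range(block_size)]
        mask = [n > off for off in offsets]             # guard against out-of-range offsets
        for off, ok in zip(offsets, mask):
            if ok:
                ys[off] = gelu_tanh(xs[off])
    return ys

xs = [0.5 * i - 2.0 for i in range(10)]                 # 10 elements, not a multiple of 4
assert gelu_blockwise(xs, 4) == [gelu_tanh(x) for x in xs]
```

<p>With 10 elements and a block size of 4, the third program covers offsets 8 to 11 and the mask disables the last two.</p>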
<h2 id="pytorch-compilation">PyTorch Compilation</h2>
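<p>This is the killer feature of PyTorch 2.0.</p>
<ul>
<li>How it works: with <code>torch.compile(model)</code>, PyTorch captures the computation graph, analyzes the dependencies between operators, and calls the Triton backend to generate fused kernels automatically.</li>
<li>Result: in the GeLU example, the kernel generated by <code>torch.compile</code> performs almost identically to the hand-written Triton kernel and far outperforms the manual implementation.</li>
<li>Profiling: the profiler shows kernel names such as <code>triton_poi_fused_add_mul_tanh...</code>, a sign that automatic fusion took effect.</li>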
</ul>
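<h2 id="进阶计算softmax-与-matmul">Going further: Softmax and Matmul</h2>
<h3 id="triton-softmaxreduce-操作">Triton Softmax: a reduction</h3>
<p>Softmax is a typical row-wise operation: $y_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$.</p>
<ul>
<li>Naive implementation: makes several passes over global memory (max -&gt; subtract max -&gt; exp -&gt; sum -&gt; divide). For an $M \times N$ matrix the number of memory accesses is as high as $5MN + M$.</li>
<li>Triton implementation:
<ol>
<li>Each block handles one row of the matrix.</li>
<li>Load the whole row into SRAM.</li>
<li>Compute the max, exp and sum in SRAM (using <code>tl.max</code>, <code>tl.sum</code>).</li>
<li>Write the result back.</li>
</ol>
<ul>
<li>Benefit: global memory traffic drops to a single read and a single write of the $MN$ elements, giving a severalfold speedup.</li>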
</ul>
</li>
</ul>
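<h3 id="triton-matmul分块">Triton Matmul: tiling</h3>
<p>Matrix multiplication ($C = A \times B$) is compute-intensive, but for large matrices memory bandwidth can still become the bottleneck.</p>
<ul>
<li>Naive approach: computing each element of $C$ reads a full row of $A$ and a full column of $B$, so the same data is read many times.</li>
<li>Tiling strategy:
<ol>
<li>Split $A$ and $B$ into small tiles.</li>
<li>Load the tiles into Shared Memory.</li>
<li>Reuse the data in Shared Memory to accumulate partial sums for a part of $C$.</li>
</ol>
</li>
</ul>
<p>Tiling turns accesses to slow DRAM into accesses to fast Shared Memory. Triton also supports an L2 cache optimization (grouped ordering): reordering the execution of blocks raises the L2 hit rate as much as possible.</p>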
]]></content:encoded>
    </item>
    <item>
      <title>CS336 Study Notes, Lecture 5: GPUs</title>
      <link>https://www.zhouxin.space/notes/notes-on-cs336-lecture-5-gpus/</link>
      <pubDate>Mon, 12 Jan 2026 08:54:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-cs336-lecture-5-gpus/</guid>
      <description>&lt;blockquote&gt;
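&lt;p&gt;&lt;strong&gt;TL;DR&lt;/strong&gt; This is Lecture 5 of the CS336 notes series. Starting from the design philosophy of GPUs, it reviews the SM/SP compute architecture and the memory hierarchy, then uses the Roofline model to explain core optimizations such as memory coalescing, kernel fusion, recomputation and tiling. Applying these ideas, it sketches how FlashAttention uses Online Softmax to get past the memory-bandwidth bottleneck and achieve IO-aware speedups.&lt;/p&gt;</description>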
      <content:encoded><![CDATA[<blockquote>
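<p><strong>TL;DR</strong> This is Lecture 5 of the CS336 notes series. Starting from the design philosophy of GPUs, it reviews the SM/SP compute architecture and the memory hierarchy, then uses the Roofline model to explain core optimizations such as memory coalescing, kernel fusion, recomputation and tiling. Applying these ideas, it sketches how FlashAttention uses Online Softmax to get past the memory-bandwidth bottleneck and achieve IO-aware speedups.</p>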
</blockquote>
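<h2 id="gpu-架构">GPU architecture</h2>
<h3 id="cpu-与-gpu-设计理念">CPU vs GPU design philosophy</h3>
<p>Architecturally, a CPU devotes a large area to control logic, while a GPU is dominated by compute units. This reflects a difference in design goals: CPUs aim to minimize execution latency, GPUs aim to maximize compute throughput.<br>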
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20260112091305.png"></p>
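<h3 id="gpu-计算单元架构">GPU compute units</h3>
<p>As the figure below shows, a GPU contains many streaming multiprocessors (SMs). An SM is the unit that executes a "block" of the programming model; it has its own control logic, and, resources permitting, several blocks can reside on one SM concurrently. Each SM consists of many streaming processors (SPs); each SP runs one "thread" of the programming model, and the SPs execute the same instruction on different data.<br>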
<img alt="GPU 计算单元解剖图" loading="lazy" src="https://pics.zhouxin.space/20260112091518.png"></p>
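<h3 id="gpu-内存架构">GPU memory hierarchy</h3>
<p>In short, the closer a memory is to the SM, the faster it is and the lower its latency. L1 and shared memory sit inside the SM and have the lowest latency; the L2 cache is on the die, with about 10x the latency of shared memory; global memory lives in memory chips next to the GPU die and has the highest latency.<br>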
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20260112093312.png"></p>
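<h3 id="gpu-的执行模型">GPU execution model</h3>
<p>The GPU execution model has three levels: block, warp and thread.</p>
<ul>
<li>Block: made up of several warps, with its own shared memory</li>
<li>Warp: 32 consecutive threads; threads are scheduled in units of warps</li>
<li>Thread: the actual execution unit; threads execute the same instruction on different data, the so-called SIMT model<br>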
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20260112093926.png"></li>
</ul>
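<h3 id="gpu-编程模型的优势">Strengths of the GPU programming model</h3>
<ul>
<li>Easy to scale up: just add more SMs</li>
<li>Thanks to the SIMT model, easy (?) to program</li>
<li>Threads are lightweight, so switching between them is cheap</li>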
</ul>
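<h3 id="一些趋势">Some trends</h3>
<ul>
<li>GPU compute is highly specialized for matrix multiplication<br>
The figure below shows that, over the generations of NVIDIA GPUs, matmul throughput has grown far more than non-matmul throughput. To get the most out of the hardware, neural networks should be built from matmul-based operators wherever possible.<br>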
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601121952861.webp"></li>
<li>Compute has grown faster than communication and memory<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601121953561.webp"></li>
</ul>
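<h2 id="让机器学习任务在-gpu-上跑的更快">Making ML workloads run faster on GPUs</h2>
<h3 id="roofline-模型">Roofline model</h3>
<p>Plotting throughput against arithmetic intensity gives a roof-shaped curve that shows where the bottleneck lies at each intensity. On the rising slope the workload is memory-bound; on the flat part it is compute-bound.<br>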
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601122117367.webp"></p>
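<h3 id="技巧零控制流分歧">Trick 0: control divergence</h3>
<p>Threads in a warp execute the same instruction. If a branch condition differs among the threads of one warp, the warp has to go through every required branch in turn, which hurts its performance considerably. This came up in my earlier PMPP notes; see <a href="https://www.zhouxin.space/notes/note-on-programming-massively-parallel-processors-a-hands-on-approach-4th-edition-part-1/#45-control-divergence-%E6%8E%A7%E5%88%B6%E6%B5%81%E5%88%86%E6%AD%A7">Programming Massively Parallel Processors A Hands-on Approach 4th Edition 学习笔记 Part 1 | 周鑫的个人博客</a></p>
<p>This is the only trick in this section that is unrelated to memory.</p>
<h3 id="技巧一低精度计算">Trick 1: low-precision computation</h3>
<p>Lower-precision data types take fewer bytes, which directly raises arithmetic intensity. For the ReLU operator shown below, switching from FP32 to FP16 doubles the arithmetic intensity.<br>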
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601122202833.webp"></p>
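<p>The effect of precision on arithmetic intensity, and what the roofline then predicts, can be worked through in a few lines. The sketch below counts one operation plus one read and one write per element for ReLU; the peak throughput and bandwidth values are round, made-up numbers rather than the specification of any real GPU.</p>

```python
def relu_intensity(bytes_per_element):
    flops = 1                               # one comparison per element
    bytes_moved = 2 * bytes_per_element     # read the input, write the output
    return flops / bytes_moved              # FLOPs per byte

def attainable_flops(intensity, peak_flops, bandwidth_bytes_per_s):
    # Roofline: limited either by compute or by memory bandwidth.
    return min(peak_flops, intensity * bandwidth_bytes_per_s)

PEAK = 300e12        # hypothetical peak compute, FLOP/s
BANDWIDTH = 2e12     # hypothetical memory bandwidth, bytes/s

for name, size in (("FP32", 4), ("FP16", 2)):
    i = relu_intensity(size)
    print(name, i, attainable_flops(i, PEAK, BANDWIDTH))
```

<p>Both cases sit far down the rising slope of the roofline, so halving the bytes per element doubles the attainable throughput.</p>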
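<p>Low precision brings rounding error, so Tensor Cores accumulate matmul results in a full-precision accumulator. This keeps the low memory traffic of low-precision inputs while retaining the accuracy of high-precision accumulation.<br>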
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601122213041.webp"></p>
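<h3 id="技巧二算子融合">Trick 2: kernel fusion</h3>
<p>As on the left of the figure below, with many separate operators and no fusion, the same data is moved back and forth repeatedly as the output of one operator and the input of the next. Fusing several small operators into one larger kernel moves the data once, computes many times, and touches memory only when necessary.<br>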
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601122217816.webp"></p>
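<h3 id="技巧三recomputation">Trick 3: recomputation</h3>
<p>In a forward pass we normally store every layer's output so that the backward pass can compute gradients. With recomputation, the backward pass recomputes the forward activations instead, which reduces data movement. In the example below, forward + backward through a stack of three sigmoids takes 8 memory accesses; with recomputation it takes only 5, a reduction of about 38%. This is the same technique as the recomputation used to save memory during training; only the motivation differs: speed in one case, memory in the other.<br>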
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601122232751.webp"></p>
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601122235415.webp"></p>
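<h3 id="技巧四内存合并访问">Trick 4: memory coalescing</h3>
<p>PMPP also covers this technique in depth: <a href="https://www.zhouxin.space/notes/note-on-programming-massively-parallel-processors-a-hands-on-approach-4th-edition-part-1/#61-memory-coalescing-%E5%86%85%E5%AD%98%E5%90%88%E5%B9%B6%E8%AE%BF%E9%97%AE">Programming Massively Parallel Processors A Hands-on Approach 4th Edition 学习笔记 Part 1 | 周鑫的个人博客</a>.<br>
The physical structure of DRAM makes it serve requests in bursts: reading one location also fetches the consecutive locations around it.</p>
<p>To exploit bursts, the hardware coalesces the accesses issued by the threads of a warp: if threads 0-31 of a warp execute the same memory instruction and target 32 consecutive locations in global memory, the instruction is served as a burst access. See the earlier notes for details.</p>
<h3 id="技巧五分块">Trick 5: tiling</h3>
<p>Tiling reduces repeated accesses to global memory by turning them into repeated accesses to on-chip memory.</p>
<p>As the figure below shows, in a naive matrix multiplication the same input elements (for example those needed for P(0,0) and its neighbours) are read again and again by several threads, and these reads do not follow a coalesced pattern either, which lowers efficiency considerably.<br>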
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601122255650.webp"></p>
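<p>Once the input matrices are split into 2x2 tiles, the tiles involved in the computation can be loaded into shared memory ahead of time with coalesced accesses and then reused.<br>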
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601122259611.webp"></p>
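<p>If an NxN matrix is tiled with tile size T, each element is loaded from global memory N times without tiling but only N/T times with tiling, a T-fold reduction in global memory traffic, which is substantial.</p>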
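<p>The N versus N/T argument can be confirmed by counting loads. The sketch below is a simple model, assuming square NxN matrices, a tile size that divides N, and that every output (or output tile) fetches its inputs from global memory exactly once; the function names are invented for this example.</p>

```python
def naive_global_loads(n):
    loads = 0
    for _ in range(n * n):          # one thread per output element
        loads += 2 * n              # a row of A and a column of B
    return loads

def tiled_global_loads(n, t):
    assert n % t == 0
    tiles = n // t
    loads = 0
    for _ in range(tiles * tiles):  # one block per output tile
        for _ in range(tiles):      # march along the shared dimension
            loads += 2 * t * t      # one tile of A and one tile of B into shared memory
    return loads

n, t = 8, 4
print(naive_global_loads(n), tiled_global_loads(n, t))
assert naive_global_loads(n) == t * tiled_global_loads(n, t)
```

<p>The ratio between the two counts is exactly T, which is the reduction factor claimed above.</p>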
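<p>In practice the tile size may not divide the tensor shape evenly, which wastes resources. The choice of tile size is constrained by several factors:</p>
<ul>
<li>Coalesced memory access</li>
<li>Shared memory size</li>
<li>The shape of the original tensor<br>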
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20260113091549.png"></li>
</ul>
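<p>Another challenge for tiling is memory alignment. If the input shape is awkward and the data a block needs after tiling straddles two burst sections, the number of memory accesses doubles compared with the aligned case. The usual fix is to pad the tensor to an aligned layout.<br>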
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20260113092038.png"></p>
<h2 id="矩乘性能图">Matmul performance plot</h2>
<p>The figure below shows how matmul throughput varies with matrix shape; this section explains it with the ideas above.<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20260113093125.png"></p>
<ul>
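<li>In the early part, as size grows, the bottleneck is memory access and throughput rises with arithmetic intensity; in the later wavy part the bottleneck is compute, i.e. the flat part of the roofline.</li>
<li>The gaps between the curves come from memory alignment and coalescing.</li>
<li>In the wavy region, taking K=2 as an example, throughput drops sharply from size 1792 to 1793. The GPU typically tiles the matmul into 256x128 blocks: size 1792 needs 7 x 14 = 98 tiles, while 1793 needs 8 x 15 = 120. An A100 has only 108 SMs, so the 120 blocks cannot all run in one round and a second round is required, hence the sharp drop.</li>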
</ul>
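<p>The tile counts in the last bullet follow from two ceiling divisions. The sketch below assumes 256x128 output tiles and 108 SMs, as stated above, and one resident block per SM per round (a simplification); the helper names are made up for this example.</p>

```python
import math

def tiles_needed(m, n, tile_m=256, tile_n=128):
    return math.ceil(m / tile_m) * math.ceil(n / tile_n)

def waves(num_tiles, num_sms=108):
    # Number of rounds needed if each SM runs one tile at a time.
    return math.ceil(num_tiles / num_sms)

for size in (1792, 1793):
    t = tiles_needed(size, size)
    print(size, t, waves(t))
```

<p>Going from 1792 to 1793 raises the tile count from 98 to 120 and the number of rounds from 1 to 2, even though the amount of useful work barely changes.</p>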
<h2 id="flashattention">FlashAttention</h2>
<p>By combining tiling and recomputation, FA speeds up attention significantly and cuts its HBM traffic:<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20260113094601.png"></p>
<h3 id="分块一对-kqv-矩乘分块">Tiling 1: the Q, K, V matmuls</h3>
<p>The first tiling is applied to the matrix multiplications among Q, K and V, which is the standard technique described above:<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20260113094636.png"></p>
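<h3 id="分块二softmax">Tiling 2: Softmax</h3>
<p>A standard Softmax needs at least three passes over the data:</p>
<ul>
<li>Pass 1: scan all the data to find the maximum</li>
<li>Pass 2: scan again to compute the denominator</li>
<li>Pass 3: scan a third time to compute the outputs</li>
</ul>
<p>The difficulty is that this cannot be tiled directly: the global maximum must be known before the final result can be computed.</p>
<p>Online Softmax gets around this by maintaining the Softmax statistics while scanning the data, without knowing the global maximum in advance. Suppose we have processed the first $j-1$ elements, the running maximum is $m_{j-1}$, and the running sum is $d_{j-1} = \sum_{k=1}^{j-1} e^{x_k - m_{j-1}}$. When a new element $x_j$ arrives, two updates are needed:</p>
<ul>
<li><strong>Update the maximum</strong>: the new maximum is $m_j = \max(m_{j-1}, x_j)$.</li>
<li><strong>Update the denominator</strong>: the old $d_{j-1}$ has to be rescaled to the new maximum $m_j$.<br>
Using the identity $e^{x - m_{new}} = e^{x - m_{old} + m_{old} - m_{new}} = e^{x - m_{old}} \times e^{m_{old} - m_{new}}$, the new sum $d_j$ is:</li>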
</ul>


<div>$$

d_j = \underbrace{d_{j-1} \times e^{m_{j-1} - m_j}}_{\text{old sum times correction factor}} &#43; \underbrace{e^{x_j - m_j}}_{\text{new term}}

$$</div>

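<p>With this identity we can compute a local maximum and a local sum within each tile and merge the partial results at the end by rescaling them to the global maximum.</p>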
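<p>The update rule translates almost line by line into code. The sketch below processes one element at a time rather than one tile at a time, which is a simplification of what FlashAttention does, and compares the result with the usual three-pass Softmax; the function names are chosen for this example.</p>

```python
import math

def online_softmax(xs):
    m = float("-inf")   # running maximum m_j
    d = 0.0             # running denominator d_j
    for x in xs:
        m_new = max(m, x)
        # Rescale the old sum to the new maximum, then add the new term.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return [math.exp(x - m) / d for x in xs]

def three_pass_softmax(xs):
    m = max(xs)                                 # pass 1: maximum
    d = sum(math.exp(x - m) for x in xs)        # pass 2: denominator
    return [math.exp(x - m) / d for x in xs]    # pass 3: outputs

xs = [1.0, 3.0, -2.0, 3.5]
assert all(math.isclose(a, b) for a, b in zip(online_softmax(xs), three_pass_softmax(xs)))
```

<p>The maximum and the denominator are obtained in a single scan, so only one further pass is needed to produce the outputs.</p>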
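<h3 id="fa">FA</h3>
<p>The figure below shows the FA computation: QK is computed as a tiled matmul, Online Softmax is applied to each tile as soon as it is ready, producing a partial output, and a final correction step yields the result. This is only the theory and skips many details; Lab 2 will presumably require writing FA by hand.</p>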
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20260113100104.png"></p>
]]></content:encoded>
    </item>
    <item>
      <title>CS336 Study Notes, Lecture 4: MoE</title>
      <link>https://www.zhouxin.space/notes/notes-on-cs336-lecture-4-moe/</link>
      <pubDate>Sun, 11 Jan 2026 11:53:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-cs336-lecture-4-moe/</guid>
      <description>&lt;blockquote&gt;
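&lt;p&gt;&lt;strong&gt;TL;DR&lt;/strong&gt; This is Lecture 4 of the CS336 notes series. It reviews how the MoE architecture uses sparse activation to scale parameters efficiently and, following the evolution of the DeepSeek models, explains how fine-grained experts, shared experts and auxiliary-loss-free load balancing address routing collapse and stability problems in large-scale training.&lt;/p&gt;</description>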
      <content:encoded><![CDATA[<blockquote>
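<p><strong>TL;DR</strong> This is Lecture 4 of the CS336 notes series. It reviews how the MoE architecture uses sparse activation to scale parameters efficiently and, following the evolution of the DeepSeek models, explains how fine-grained experts, shared experts and auxiliary-loss-free load balancing address routing collapse and stability problems in large-scale training.</p>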
</blockquote>
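<h2 id="引入">Introduction</h2>
<h3 id="什么是-moe">What is MoE</h3>
<p>An MoE model replaces the FFN (MLP) block of the Transformer with several sparsely activated FFN blocks (experts). Each forward pass activates only a few experts, so the parameter count grows without increasing the compute.<br>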
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111209985.webp"></p>
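<p>For the same compute budget an MoE has more parameters and ends up performing better, which is why mainstream models have been moving to MoE since 2025.</p>
<p>There is no free lunch, of course. The price of MoE is a larger model that needs more storage, plus considerably more system complexity in parallelization.</p>
<h3 id="为什么-moe-变成主流">Why MoE became mainstream</h3>
<ul>
<li>For the same compute, more parameters give better performance.<br>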
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111329700.webp"></li>
<li>MoE trains faster. Compared with a dense architecture, an MoE reaches the same loss in about 1/7 of the time.<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111330464.webp"></li>
<li>It parallelizes naturally. Placing different experts on different devices is an obvious thing to do.</li>
</ul>
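<h3 id="moe-的局限">Limitations of MoE</h3>
<ul>
<li>
<p>Dependence on infrastructure<br>
Despite the compute advantage, MoE models are large. The training team needs strong infrastructure to support training across many devices and large datasets with heavy parallelism, which is hard for small and medium-sized teams.</p>
</li>
<li>
<p>Training relies on experience and is unstable<br>
Routing in MoE is not differentiable, and the choice of routing strategy has a large effect on the result. Training is therefore more prone to instability than for dense models and demands more experience and skill from developers.</p>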
</li>
</ul>
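<h3 id="架构">Architecture</h3>
<p>MoE usually means the design on the left of the figure below, in which MoE replaces the MLP layer of the Transformer. Some work applies the same idea to replace MHA with an MoE version, but this is not mainstream, and such models are harder to train.<br>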
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111356143.webp"></p>
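<h2 id="moe-的变种">MoE variants</h2>
<p>MoE variants differ along three axes: routing algorithm, number of experts, and loss function.</p>
<h3 id="路由算法">Routing algorithms</h3>
<h4 id="总览">Overview</h4>
<p>The routing algorithm decides which expert each token goes to. There are three families:</p>
<ul>
<li>Token chooses expert: each token independently picks its best-matching experts</li>
<li>Expert chooses token: each expert picks its best-matching tokens from the whole input</li>
<li>Global routing via optimization: tokens are assigned to experts optimally from a global view, which usually involves a complex optimization problem<br>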
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111405713.webp"></li>
</ul>
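<h4 id="常见算法">Common algorithms</h4>
<p>The common routing algorithms are Top-k and hashing. The latter maps tokens to experts with a hash function and often serves as a baseline for evaluating other routing algorithms.<br>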
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111450688.webp"></p>
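<h4 id="top-k-算法">Top-K routing</h4>
<p>As the figure below shows (read bottom to top), Top-K routing has a learnable vector $e_i$ for each expert. The input u is first dotted with each $e_i$ to get a similarity, a Softmax over the similarities gives each expert a score, the top k experts are kept as the gate, and the outputs of those k experts are combined in a sum weighted by their scores.<br>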
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111459726.webp"></p>
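<p>The routing steps above can be sketched without any framework. In the toy below, experts are plain Python callables and vectors are lists; the names <code>top_k_route</code> and <code>moe_forward</code> are invented for this example, and details such as residual connections or renormalizing the selected scores are left out.</p>

```python
import math

def softmax(zs):
    m = max(zs)
    es = [math.exp(z - m) for z in zs]
    total = sum(es)
    return [e / total for e in es]

def top_k_route(u, expert_embeddings, k):
    # Similarity of the token with each expert vector, then Softmax scores.
    logits = [sum(ui * ei for ui, ei in zip(u, e)) for e in expert_embeddings]
    scores = softmax(logits)
    chosen = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return {i: scores[i] for i in chosen}      # expert index -> gate weight

def moe_forward(u, expert_embeddings, experts, k):
    gates = top_k_route(u, expert_embeddings, k)
    out = [0.0] * len(u)
    for i, g in gates.items():                 # only the selected experts run
        y = experts[i](u)
        out = [o + g * yi for o, yi in zip(out, y)]
    return out

embeddings = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [2.0, 0.0]]
experts = [lambda u, s=s: [s * v for v in u] for s in (1.0, 2.0, 3.0, 4.0)]
print(moe_forward([1.0, 0.0], embeddings, experts, k=2))
```

<p>Only k of the experts are evaluated per token, which is where the compute saving of MoE comes from.</p>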
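<h3 id="fine-grained-expert-segmentation-shared-expert-isolation">Fine-grained Expert Segmentation and Shared Expert Isolation</h3>
<p>Chinese LLMs show two trends in expert routing architecture. In the figure below, (a) is conventional Top-2 routing; (b) is fine-grained expert segmentation: each expert is halved in size while the total number of experts and the number of activated experts both double, so the total compute stays the same but the larger number of experts yields better performance; (c) is shared expert isolation, which adds an expert that is always active. From an interpretability angle, this expert addresses knowledge redundancy: it memorizes the knowledge that every expert would otherwise need.<br>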
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111529032.webp"></p>
<p>The ablations in the DeepSeek technical report show that these architectural changes do improve model performance.<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111534337.webp"></p>
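<h2 id="训练-moe">Training MoE</h2>
<blockquote>
<p>Gradients do flow through the selected experts. "Non-differentiable" here really means that only the selected experts receive gradients, so experts that are not selected can easily starve: the strong get stronger, the weak get weaker and are never updated.</p>
</blockquote>
<p>Expert routing is a discrete, non-differentiable decision. There are three approaches:</p>
<ul>
<li>Reinforcement learning</li>
<li>Stochastic perturbation</li>
<li>Heuristic load-balancing loss</li>
</ul>
<h3 id="强化学习">Reinforcement learning</h3>
<p>Reinforcement learning can model discrete decisions and in theory fits this problem well. In experiments, however, it shows no clear advantage over the alternatives or even over the baseline, while adding gradient variance and considerable complexity, so no large-scale training run currently uses it for MoE training.<br>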
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111814270.webp"></p>
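<h3 id="随机扰动">Stochastic perturbation</h3>
<p>This approach adds Gaussian noise to the routing scores, with a learnable linear layer controlling the noise scale. The randomness makes the model more robust, but the approach was later abandoned because of its potential instability, and the field moved to the better-performing heuristic methods.</p>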
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111820274.webp"></p>
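<h3 id="启发式负载均衡损失">Heuristic load-balancing loss</h3>
<p>This is the approach from Switch Transformer. It adds an auxiliary loss proportional to $\sum_i f_i P_i$ over the experts, where $f_i$ is the fraction of tokens in a batch actually dispatched to expert $i$ and $P_i$ is the router probability assigned to expert $i$, averaged over the tokens in the batch.</p>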
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111846484.webp"></p>
<blockquote>
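<p>Questions and answers quoted from Gemini:</p>
<p>Why not optimize $f_i$ directly toward a uniform distribution?<br>
Because $f_i$ is produced by the discrete routing decision and is not differentiable.</p>
<p>Why not optimize $P_i$ directly toward a uniform distribution?<br>
Because Top-K is sensitive only to the ranking of the scores. Optimizing $P_i$ alone cannot prevent the Matthew effect: the model only has to keep the strongest expert's score slightly above uniform.</p>
<p>In short, the loss is cleverly designed. It effectively says: if an expert is actually busy (large $f_i$), penalize its predicted probability ($P_i$) to reduce its chance of being selected in the future.</p>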
</blockquote>
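<p>A small numeric sketch makes the behaviour of this loss concrete. It assumes Top-1 routing, the scaling $\alpha N \sum_i f_i P_i$ with $N$ experts, and an illustrative $\alpha = 0.01$; the function name and the toy inputs are made up for this example.</p>

```python
def load_balancing_loss(router_probs, assignments, alpha=0.01):
    """router_probs[t][i]: router probability of expert i for token t.
    assignments[t]: index of the expert actually chosen for token t."""
    n_tokens = len(router_probs)
    n_experts = len(router_probs[0])
    f = [0.0] * n_experts   # fraction of tokens dispatched to each expert
    p = [0.0] * n_experts   # mean router probability of each expert
    for probs, chosen in zip(router_probs, assignments):
        f[chosen] += 1.0 / n_tokens
        for i, prob in enumerate(probs):
            p[i] += prob / n_tokens
    return alpha * n_experts * sum(fi * pi for fi, pi in zip(f, p))

balanced = load_balancing_loss([[0.25] * 4] * 4, [0, 1, 2, 3])
collapsed = load_balancing_loss([[1.0, 0.0, 0.0, 0.0]] * 4, [0, 0, 0, 0])
print(balanced, collapsed)
```

<p>The perfectly balanced router reaches the minimum value $\alpha$, while a router that sends every token to one expert is penalized $N$ times as much.</p>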
<p>DeepSeek V1 and V2 add an auxiliary loss of similar form to balance load across devices: the frequency statistics are simply collected per device instead of per expert:<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111853989.webp"></p>
<p>Without load balancing, model performance is worse than with it. In addition, without load balancing only two experts receive tokens and all the others starve.<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111916833.webp"></p>
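<h3 id="无需辅助损失的启发式负载均衡">Heuristic load balancing without an auxiliary loss</h3>
<p>DeepSeek V3 proposes a load-balancing method that needs no auxiliary loss. It adds a per-expert bias to the routing scores; the bias is negatively correlated with the expert's actual load, so overloaded experts become less likely to be selected and the load evens out.<br>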
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111910416.webp"></p>
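<p>One way to picture the bias mechanism is the sketch below. It is a simplified reading, not the exact DeepSeek V3 procedure: the step size <code>gamma</code>, the use of the mean load as the threshold, and the function names are assumptions made for this example. The bias only affects which experts are selected.</p>

```python
def update_biases(biases, tokens_per_expert, gamma=0.001):
    # Lower the bias of overloaded experts, raise it for underloaded ones.
    mean_load = sum(tokens_per_expert) / len(tokens_per_expert)
    updated = []
    for b, load in zip(biases, tokens_per_expert):
        if load > mean_load:
            updated.append(b - gamma)
        elif mean_load > load:
            updated.append(b + gamma)
        else:
            updated.append(b)
    return updated

def select_top_k(scores, biases, k):
    # Selection uses score + bias; gate weights would still use the raw scores.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i] + biases[i], reverse=True)
    return ranked[:k]

scores = [0.5, 0.4, 0.45]
biases = update_biases([0.0, 0.0, 0.0], [10, 5, 0], gamma=0.1)
print(select_top_k(scores, [0.0, 0.0, 0.0], 1), select_top_k(scores, biases, 1))
```

<p>After one update the overloaded expert 0 loses its lead and the idle expert 2 is selected instead, with no extra term in the training loss.</p>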
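<h2 id="moe-的问题">Problems with MoE</h2>
<h3 id="随机性">Randomness</h3>
<p>When the device hosting an expert cannot take any more tokens, the excess tokens are dropped. This introduces randomness in both training and inference, so the model's output can differ from batch to batch.</p>
<h3 id="稳定性">Stability</h3>
<p>The Softmax in the router is numerically fragile, especially in BF16, where a small perturbation or rounding error can change the final output a lot. The routing computation should therefore run in FP32, with the z-loss from the previous lecture added when necessary.</p>
<p>As shown below, z-loss effectively suppresses loss spikes during training.<br>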
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111943028.webp"></p>
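<h3 id="微调">Fine-tuning</h3>
<p>MoE models overfit easily during fine-tuning. One remedy is architectural: alternate MoE and dense layers and fine-tune only the dense part. The other is brute force: use more data.</p>
<h2 id="upcycling">Upcycling</h2>
<p>A cheap way to train an MoE is to train a dense model first, copy its MLP n times to build an MoE, and continue training from there. This yields an MoE model efficiently.<br>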
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111949424.webp"></p>
<h2 id="deepseek-演进路线">The DeepSeek evolution</h2>
<h3 id="deepseek-moe-v1">DeepSeek MoE V1</h3>
<p>DeepSeek V1 is a 16B model with 2.8B active parameters. It uses 2 shared experts + 64 fine-grained experts with Top-6 routing and the standard auxiliary load-balancing loss.<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111956303.webp"></p>
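<h3 id="deepseek-v2">DeepSeek V2</h3>
<p>V2 is a 236B model with 21B active parameters, using 2 shared experts and 160 fine-grained experts with Top-6 routing. The finer the experts and the more of them are activated, the higher the communication cost. V2 therefore routes in two steps: it first picks the Top-M devices and then routes among the experts on those devices, which bounds the overall communication cost.</p>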
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111958224.webp"></p>
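<h3 id="deepseek-v3">DeepSeek V3</h3>
<p>V3 is a 671B model with 37B active parameters, using 1 shared expert and 256 fine-grained experts with Top-8 routing. The router's scoring function changes from Softmax to Sigmoid. For balancing it uses the auxiliary-loss-free method described above, together with a sequence-level load-balancing loss, to ensure that tokens are dispatched evenly across devices at inference time.</p>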
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601112010380.webp"></p>
<h3 id="mla--mtp">MLA &amp; MTP</h3>
<p>The lecturer went through this part quickly, so I refer directly to these write-ups on Zhihu:<br>
<a href="https://zhuanlan.zhihu.com/p/16730036197">deepseek技术解读(1)-彻底理解MLA（Multi-Head Latent Attention）</a><br>
<a href="https://zhuanlan.zhihu.com/p/18056041194">deepseek技术解读(2)-MTP（Multi-Token Prediction）的前世今生</a></p>
]]></content:encoded>
    </item>
    <item>
      <title>CS336 Study Notes, Lecture 3: Architectures and Hyperparameters</title>
      <link>https://www.zhouxin.space/notes/notes-on-cs336-lecture-1-architectures-hyperparameters/</link>
      <pubDate>Sat, 10 Jan 2026 15:53:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-cs336-lecture-1-architectures-hyperparameters/</guid>
      <description>&lt;blockquote&gt;
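&lt;p&gt;&lt;strong&gt;TL;DR&lt;/strong&gt; This lecture reviews the de facto standard of modern LLM architecture (Pre-Norm + RMSNorm + RoPE) and, from a systems perspective, explains how GQA and sliding-window attention optimize KV Cache memory access to counter the drop in arithmetic intensity at inference time.&lt;/p&gt;</description>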
      <content:encoded><![CDATA[<blockquote>
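<p><strong>TL;DR</strong> This lecture surveys the de facto standard of modern LLM architecture design (Pre-Norm + RMSNorm + RoPE) and, from a systems perspective, explains how GQA and sliding window attention optimize KV Cache memory access to counter the degradation of arithmetic intensity at inference time.</p>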
</blockquote>
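<p>As the figure below shows, the architecture of dense LLMs has been evolving since 2017. After an early period of wide-ranging exploration, designs have now visibly converged. This lecture summarizes the lessons and tricks of current large models in architecture design and hyperparameter selection, so that we can build on prior work and avoid detours.</p>
<p>Specifically, this lecture covers:</p>
<ul>
<li>Architecture variants
<ul>
<li>Choice of activation and fully connected layers</li>
<li>Attention variants</li>
<li>Position embeddings</li>
</ul>
</li>
<li>Hyperparameter choices
<ul>
<li><code>ff_dim</code></li>
<li>Whether the sum of the per-head <code>head_dim</code> must equal the model dimension</li>
<li>Vocabulary size</li>
</ul>
</li>
<li>Tricks for improving training stability</li>
<li>Attention architectures<br>
<img alt="Evolution of LLM architectures" loading="lazy" src="https://pics.zhouxin.space/202601101628230.webp"></li>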
</ul>
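<h2 id="模型架构">Model Architecture</h2>
<h3 id="pre-vs-post-norm">Pre-vs-Post Norm</h3>
<p>In the standard Transformer architecture (left in the figure below), the Norm block is applied after each residual addition, i.e. $x_{t+1}=\text{LayerNorm}(x_t+F(x_t))$; this is Post-Norm. In Pre-Norm (right in the figure below), the Norm is placed inside the residual branch, before the sublayer, i.e. $x_{t+1}=x_t+F(\text{LayerNorm}(x_t))$.<br>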
<img alt="Pre-vs-Post norm" loading="lazy" src="https://pics.zhouxin.space/202601101702339.webp"></p>
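<p><strong>By 2024, using Pre-Norm instead of Post-Norm had become a broad industry consensus: Pre-Norm trains more stably than Post-Norm.</strong> Post-Norm needs stability tricks such as learning-rate warmup to keep training on track, whereas Pre-Norm is far less sensitive to them, and the final model quality is on par with Post-Norm.</p>
<p>Why is Pre-Norm better? The mainstream explanation is that the residual path in Pre-Norm contains no Norm module, so gradients can flow from the top layer straight to the bottom layer through the identity mapping, which avoids vanishing or exploding gradients. Models keep getting deeper and an unstable run is wasted money, so the industry-wide move to Pre-Norm is quite intuitive.</p>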
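<p>The difference between the two placements can be made concrete with a small sketch. It uses plain Python lists instead of tensors and a parameter-free <code>layer_norm</code>; the function names are illustrative assumptions, not course code.</p>

```python
import math


def layer_norm(x, eps=1e-5):
    """Parameter-free LayerNorm over a single feature vector."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def post_norm_block(x, sublayer):
    """Post-Norm: x_{t+1} = LayerNorm(x_t + F(x_t)). The Norm sits on the residual path."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


def pre_norm_block(x, sublayer):
    """Pre-Norm: x_{t+1} = x_t + F(LayerNorm(x_t)). The residual path is a pure identity."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]
```

<p>If the sublayer outputs zeros, <code>pre_norm_block</code> returns its input unchanged, while <code>post_norm_block</code> still rescales it. That untouched identity path is what the explanation above refers to.</p>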
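<p>Since Pre-Norm normalizes before the sublayer inside the residual branch, has anyone tried normalizing after the sublayer as well? Yes. The lecturer calls this Double Norm: a Norm is applied both before and after the sublayer inside the residual branch. The recent Gemma 2 uses this design.<br>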
<img alt="Double Norm" loading="lazy" src="https://pics.zhouxin.space/202601101831275.webp"></p>
<h3 id="layernorm-vs-rmsnorm">LayerNorm vs RMSNorm</h3>
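<p>The formulas for LayerNorm and RMSNorm are given below, where $\mu$ and $\sigma$ are the mean and standard deviation of $x$ over the feature dimension and $d$ is the size of that dimension. LayerNorm standardizes along the feature dimension and then scales and shifts with two learnable vectors $\gamma$ and $\beta$. RMSNorm drops the mean subtraction and the shift, and only rescales by the root mean square:</p>


<div>$$

\begin{align*}
\text{LayerNorm}(x) &amp;= \frac{x - \mu}{\sigma} \cdot \gamma &#43; \beta \\
\text{RMSNorm}(x) &amp;= \frac{x}{\sqrt{\frac{1}{d}||x||^2_2&#43;\epsilon}} \cdot \gamma 
\end{align*}

$$</div>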
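<p>A minimal pure-Python sketch of the two normalizations, operating on one feature vector. The function names and the default <code>eps</code> are assumptions for illustration; in <code>layer_norm</code> the small <code>eps</code> is added under the square root purely for numerical safety.</p>

```python
import math


def layer_norm(x, gamma, beta, eps=1e-5):
    """Standardize over the feature dimension, then scale by gamma and shift by beta."""
    d = len(x)
    mean = sum(x) / d
    var = sum((v - mean) ** 2 for v in x) / d
    return [(v - mean) / math.sqrt(var + eps) * g + b
            for v, g, b in zip(x, gamma, beta)]


def rms_norm(x, gamma, eps=1e-5):
    """Rescale by the root mean square only: no mean subtraction, no shift."""
    d = len(x)
    mean_square = sum(v * v for v in x) / d
    return [v / math.sqrt(mean_square + eps) * g for v, g in zip(x, gamma)]
```

<p>For an input whose mean is already zero, the two functions agree when <code>gamma</code> is all ones and <code>beta</code> is all zeros; RMSNorm simply skips the statistics and the parameter that LayerNorm spends on centering.</p>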

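<p><strong>In recent years the industry has largely moved from LayerNorm to RMSNorm.</strong> The reason is that RMSNorm needs less computation and fewer parameters. As the figure below shows, Norm modules account for only 0.17% of the theoretical FLOPs, but because they are dominated by memory traffic they take up as much as 25.5% of the actual runtime. Dropping the mean subtraction and shift reduces memory reads and writes, which speeds up real-world execution, and model quality does not degrade as a result.</p>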
<p><img alt="LLM 各模块计算量和运行时占比" loading="lazy" src="https://pics.zhouxin.space/202601101852940.webp"></p>
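<p><strong>Another trend around the Norm module is that modern models tend to drop bias terms.</strong> This covers not only the shift $\beta$ in the Norm layer but also the bias terms in all linear layers. The motivation is the same as for RMSNorm: less memory overhead. In addition, although the reason is not yet well understood, there is empirical evidence that removing bias terms makes training more stable.</p>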
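<h2 id="激活函数">Activation Functions</h2>
<p>There are many activation functions, and swapping them yields different MLP variants. The slides from the course are used directly here:</p>
<p>ReLU is the simplest activation function, but it has zero gradient for negative inputs, which can leave neurons permanently dead. GeLU keeps a small gradient in the negative region, so those neurons can keep training.<br>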
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601101918017.webp"><br>
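GLU-family MLP layers modify the first fully connected layer by additionally multiplying its output element-wise with $xV$, which implements the so-called gating mechanism.<br>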
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601101921585.webp"><br>
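Depending on the activation function, this yields the ReGLU, GeGLU and SwiGLU variants. Because gating adds a third weight matrix, the hidden dimension of the gated MLP is scaled to 2/3 of the non-gated one so that the total parameter count and compute stay the same.<br>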
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601101923241.webp"><br>
Experiments show that the GLU family does perform better, but it is not a requirement: GPT-3, for example, does not use a gated activation.</p>
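<p>As a rough sketch of the gating idea, here is a plain ReLU MLP next to a SwiGLU MLP for a single input vector. Weights are stored as lists of rows so that <code>matvec(W, x)</code> plays the role of $xW$ on the slides; the helper names and this storage layout are assumptions for illustration.</p>

```python
import math


def silu(v):
    """Swish / SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def matvec(W, x):
    """Multiply a matrix stored as a list of rows with a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def ffn_relu(x, W1, W2):
    """Plain MLP: W2 applied to ReLU(W1 x)."""
    hidden = [max(0.0, v) for v in matvec(W1, x)]
    return matvec(W2, hidden)


def ffn_swiglu(x, W, V, W2):
    """SwiGLU MLP: W2 applied to (Swish(W x) * (V x)), with an element-wise product."""
    gate = [silu(v) for v in matvec(W, x)]
    linear = matvec(V, x)
    hidden = [g * l for g, l in zip(gate, linear)]
    return matvec(W2, hidden)
```

<p>The only structural change is the extra projection <code>V</code> and the element-wise product, which is why a gated MLP carries three weight matrices instead of two.</p>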
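<h3 id="并行与串行架构">Parallel vs. Serial Layers</h3>
<p>The traditional Transformer block is serial: the input passes in turn through LayerNorm, Attention and a residual addition, and then through LayerNorm, the MLP and another residual addition. The parallel variant instead feeds the same normalized input to Attention and the MLP at the same time and adds both outputs to the residual stream. This design gives some training speedup at large scale.</p>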
<h3 id="位置编码">Position Embeddings</h3>
<p>Early LMs used a variety of position embeddings, including sinusoidal, absolute and relative position embeddings:<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601101944132.webp"><br>
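Given all of these, why is RoPE still needed? An ideal relative position embedding should satisfy the following definition: for a position embedding function $f(x, i)$, the dot product of two embedded vectors (that is, the attention score) should depend only on the relative position of the two vectors:</p>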


<div>$$

\langle f(x, i), f(y, j) \rangle = g(x, y, i-j)

$$</div>

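<p>This property can also be described as translation invariance.</p>
<p>To achieve translation invariance, RoPE rotates each token's vector by an angle determined by its position, so that the angle between two rotated vectors depends only on their relative position. A token's feature vector is high-dimensional while rotation by an angle is a two-dimensional notion, so RoPE splits the feature dimensions into pairs, rotates each pair as a 2D vector, and uses a different rotation frequency for each pair, which captures both high-frequency and low-frequency position information.</p>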
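<p>The pairwise rotation can be sketched in a few lines of plain Python. The base of 10000 and the pairing of adjacent dimensions follow the common convention and are assumptions here, as are the function names.</p>

```python
import math


def rope(x, pos, base=10000.0):
    """Rotate each adjacent pair (x[k], x[k+1]) by an angle proportional to pos.

    Pair k uses frequency base ** (-k / d), so early pairs rotate fast
    (high frequency) and late pairs rotate slowly (low frequency).
    """
    d = len(x)
    out = []
    for k in range(0, d, 2):
        theta = pos * base ** (-k / d)
        c, s = math.cos(theta), math.sin(theta)
        out += [x[k] * c - x[k + 1] * s, x[k] * s + x[k + 1] * c]
    return out


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))
```

<p>Shifting both positions by the same amount leaves the score unchanged: <code>dot(rope(q, 5), rope(k, 2))</code> and <code>dot(rope(q, 13), rope(k, 10))</code> agree up to floating-point error, because only the offset of 3 enters the result.</p>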
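<h3 id="小节">Summary</h3>
<ul>
<li>Pre-vs-Post Norm
<ul>
<li>Choose Pre-Norm: it is more stable</li>
</ul>
</li>
<li>LayerNorm vs RMSNorm
<ul>
<li>Choose RMSNorm: it is faster, and model quality is sometimes better</li>
</ul>
</li>
<li>Activation functions
<ul>
<li>The GLU family appears to be better, but the gap is not large</li>
</ul>
</li>
<li>Parallel vs. serial layers
<ul>
<li>No rigorous ablation shows that parallel layers are better</li>
<li>Parallel layers have some advantage in compute speed</li>
</ul>
</li>
<li>Position embeddings
<ul>
<li>Choose RoPE: it is translation invariant</li>
</ul>
</li>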
</ul>
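<h2 id="超参数">Hyperparameters</h2>
<h3 id="mlp-特征数">MLP Hidden Dimension</h3>
<p>Inside the MLP module, the input is projected from the model dimension $d_{\text{model}}$ up to the MLP hidden dimension $d_{\text{ff}}$. For non-GLU MLPs the ratio is usually 4; for GLU-family MLPs it is usually 8/3. This is the current consensus.</p>
<p>Some work has studied the relation between the up-projection ratio and loss, and found that ratios in the single digits work best. T5 once chose a ratio of 64, but that was short-lived: its successor T5 v1.1 switched to a more conventional 2.5.<br>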
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601102042349.webp"></p>
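<p>The 4 versus 8/3 convention follows from simple parameter counting, sketched below. The value <code>d_model = 768</code> is a hypothetical example, and biases are ignored.</p>

```python
def mlp_params(d_model, d_ff, gated):
    """Weight count of one MLP block without biases.

    A plain MLP has an up and a down projection (2 matrices);
    a GLU-family MLP adds the gate projection (3 matrices).
    """
    num_matrices = 3 if gated else 2
    return num_matrices * d_model * d_ff


d_model = 768  # hypothetical model dimension
plain = mlp_params(d_model, 4 * d_model, gated=False)
glu = mlp_params(d_model, 8 * d_model // 3, gated=True)
print(plain, glu)  # both equal 8 * d_model ** 2
```

<p>Both configurations come out at $8 d_{\text{model}}^2$ weights, which is why a GLU-family MLP with ratio 8/3 is treated as the same size as a plain MLP with ratio 4.</p>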
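<h3 id="注意力特征数">Attention Feature Dimension</h3>
<p>This section covers the choice of the hyperparameter pair <code>num_heads x head_dim</code>, which I refer to as the attention feature dimension.</p>
<p>The current consensus is that the attention feature dimension equals the model dimension, or slightly exceeds it.</p>
<p>Some work argues that once the total dimension is fixed, increasing the number of heads forces each head to be narrower, which weakens each head's representational capacity and hurts overall model quality. In practice the effect is not significant, and the industry still tends to set the attention feature dimension equal to the model dimension.<br>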
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601102039191.webp"></p>
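<h3 id="横纵比">Aspect Ratio</h3>
<p>The aspect ratio is the ratio between the model dimension and the model depth <code>num_layers</code>. Mainstream models land in the 1xx range (roughly 100 to 200), which also appears to be a consensus.</p>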
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601102046365.webp"></p>
<p>Related work also confirms that the 1xx range is a sweet spot, with no significant dependence on parameter count.<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601102057643.webp"></p>
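<p>The aspect ratio also affects the choice of parallelism strategy: deeper models suit pipeline parallelism, while wider models suit tensor parallelism.</p>
<h3 id="词汇表数量">Vocabulary Size</h3>
<p>The vocabulary size depends on how many languages are supported and on the production use case. Monolingual models are in the 30-50k range and multilingual models in the 100-250k range:<br>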
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601102101637.webp"></p>
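<h3 id="正则化">Regularization</h3>
<p>In theory, pre-training should not need regularization (Dropout, Weight Decay), because it makes only a single pass over the training data and is unlikely to overfit. In practice, however, people still tend to use some regularization:<br>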
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601102104319.webp"></p>
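<p>One paper explains this counterintuitive practice as follows: Weight Decay is not there to prevent overfitting. It is a technique that interacts with the learning rate. Shrinking the weights is, to some extent, equivalent to increasing the learning rate, which lets the model reach a lower loss late in training when the learning rate is small.</p>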
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601102110723.webp"></p>
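<h2 id="小结">Summary</h2>
<ul>
<li>MLP
<ul>
<li>GLU family: use 8/3 as the up-projection ratio</li>
<li>Non-GLU: use 4 as the up-projection ratio</li>
</ul>
</li>
<li>Attention feature dimension
<ul>
<li>Set the attention feature dimension equal to the model dimension</li>
</ul>
</li>
<li>Aspect ratio
<ul>
<li>Choose an aspect ratio in the 1xx range</li>
</ul>
</li>
<li>Regularization
<ul>
<li>Use regularization, but to reach a lower loss rather than to prevent overfitting</li>
</ul>
</li>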
</ul>
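<h2 id="技巧">Tricks</h2>
<p>This section introduces some tricks for improving stability during training.</p>
<h3 id="softmax">Softmax</h3>
<p>Softmax involves exp and a division, so it is numerically sensitive (overflow, division by zero and so on) and is a common place for training to blow up. A Transformer has two Softmax operations: one over the logits at the output of the model, and one inside the Attention block when computing attention scores. Stability tricks for each are described below.</p>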
<ul>
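<li>Output Softmax<br>
Let Z denote the denominator of the Softmax. To keep Z from growing too large, Z-loss adds a regularization term proportional to $\log^2 Z$ to the loss, which pushes Z toward 1 and avoids floating-point overflow and abnormal gradients from the exponentials in large-scale training.</li>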
</ul>
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601110859836.webp"></p>
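<p>A minimal sketch of cross-entropy with a Z-loss term for one position, in plain Python. The coefficient <code>z_coeff=1e-4</code> and the function names are assumptions for illustration; the penalty is the squared log of the Softmax denominator, so it vanishes exactly when Z is 1.</p>

```python
import math


def log_z(logits):
    """Numerically stable log of the Softmax denominator Z = sum(exp(logit))."""
    m = max(logits)
    return m + math.log(sum(math.exp(v - m) for v in logits))


def cross_entropy_with_z_loss(logits, target, z_coeff=1e-4):
    """Return (total loss, cross-entropy, z-loss) for a single position."""
    lz = log_z(logits)
    ce = lz - logits[target]       # -log softmax(logits)[target]
    z_loss = z_coeff * lz ** 2     # pushes log Z toward 0, i.e. Z toward 1
    return ce + z_loss, ce, z_loss
```

<p>Adding a constant to every logit leaves the cross-entropy unchanged but moves Z, so the extra term is what pins down the overall scale of the logits.</p>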
<ul>
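<li>Attention Score<br>
In MHA, some vision and multimodal models apply a LayerNorm to Q and to K separately before computing the attention scores. This technique is called QK Norm and improves training stability.<br>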
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601110902312.webp"></li>
</ul>
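<h2 id="attention-机制">Attention Mechanisms</h2>
<h3 id="gqa--mqa">GQA / MQA</h3>
<p>As the figure below shows, during LM training the Attention module has a fairly high arithmetic intensity (the k in the figure should be d). Training is therefore compute-bound, which lets the GPU's compute capability be fully used.<br>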
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601110942375.webp"></p>
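<p>At inference time, however, the tokens cannot be predicted in parallel and must be generated one by one. As the figure below shows, when generating the (t+1)-th token, the K and V values of the first t tokens can be cached, so the QKV projection only needs to be computed for the last token of the input.<br>
<img alt="KV Cache illustration. Source: https://medium.com/@joaolages/kv-caching-explained-276520203249" loading="lazy" src="https://pics.zhouxin.space/202601111005017.webp"></p>
<p>With cached MHA the total compute does not change (training computes all QKV at once, inference computes them over n steps), but total memory traffic increases significantly because the cache must be reloaded at every step, so arithmetic intensity degrades at inference time.<br>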
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111019356.webp"></p>
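<p>The arithmetic intensity degrades because the KV Cache is loaded repeatedly, so the remedy is to shrink the KV Cache. The KV Cache comes from the K and V projections, so a natural idea is to reduce the number of KV heads. Reducing it to 1 gives MQA (Multi-Query Attention), whose arithmetic intensity grows with the number of query heads. Reducing the number of KV heads to 1/n of the number of Q heads, so that each KV head serves n Q heads, gives GQA (Grouped-Query Attention), which also improves arithmetic intensity.<br>
<img alt="MQA illustration" loading="lazy" src="https://pics.zhouxin.space/202601111028410.webp"></p>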
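<p>To get a feel for how much MQA and GQA shrink the cache, the size can be estimated with a short calculation. The model configuration below (32 layers, 32 query heads, head dimension 128, 4096 tokens, 2 bytes per element) is a hypothetical example, not a specific model.</p>

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    """Size of the KV Cache: one K and one V vector per layer, KV head and token."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem


# Hypothetical configuration with 32 query heads.
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096)
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, head_dim=128, seq_len=4096)
mqa = kv_cache_bytes(n_layers=32, n_kv_heads=1, head_dim=128, seq_len=4096)

for name, size in [("MHA", mha), ("GQA", gqa), ("MQA", mqa)]:
    print(name, size / 2 ** 20, "MiB")
```

<p>The cache, and with it the memory traffic per generated token, scales linearly with the number of KV heads, while the number of query heads and hence the attention compute stays the same.</p>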
<h3 id="sparsesliding-window-attention">Sparse/Sliding Window Attention</h3>
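<p>In standard Attention every token attends to all previous tokens, so attention compute grows quadratically with the sequence length N while the KV Cache grows linearly.</p>
<p>As the figure below shows, in sliding window attention each token attends only to a fixed number of nearby tokens, so total compute grows linearly with N.<br>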
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111038347.webp"></p>
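<p>The two attention patterns differ only in their masks, as this small sketch shows. The function names are illustrative, and the window is assumed to include the current token.</p>

```python
def causal_mask(n):
    """mask[i][j] is True when token i may attend to token j (all earlier tokens)."""
    return [[j <= i for j in range(n)] for i in range(n)]


def sliding_window_mask(n, window):
    """Token i attends only to the last `window` tokens, itself included."""
    return [[i - window < j <= i for j in range(n)] for i in range(n)]


def attended_pairs(mask):
    """Number of (query, key) pairs that are actually computed."""
    return sum(sum(row) for row in mask)
```

<p>For 8 tokens the causal mask keeps 36 pairs, growing as N(N+1)/2, while a window of 3 keeps 21, growing as roughly 3N.</p>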
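<p>The lecturer did not say much about sparse attention. Gemini summarizes it as follows: to address the short-sightedness of Sliding Window attention, Sparse Attention introduces richer connectivity patterns, including: a sliding window, kept to capture local syntactic relations; global tokens, a few designated tokens (such as [CLS] or specific prompt tokens) that can attend to every token and be attended to by every token; and dilation, which, like dilated convolution, looks at every $k$-th token to enlarge the receptive field without attending to everything.</p>
<h3 id="trick">Trick</h3>
<p>To get both the global receptive field of dense Attention and the reduced compute of sparse Attention, some recent architectures use both kinds. With 4 blocks as one period, the first three use sparse (sliding window) attention and the fourth uses standard full attention. In some implementations, the RoPE position embedding is applied only in the sliding window layers, and the full attention layers use no position embedding.</p>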
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202601111049145.webp"></p>
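<p>The interleaving described above amounts to a per-layer plan like the following. The function name and the dictionary keys are hypothetical; the period of 4 and the RoPE choice mirror the description rather than any particular codebase.</p>

```python
def hybrid_layer_plan(num_layers, period=4):
    """Every `period`-th layer uses full attention without RoPE;
    all other layers use sliding window attention with RoPE."""
    plan = []
    for layer in range(num_layers):
        is_full = (layer + 1) % period == 0
        plan.append({
            "attention": "full" if is_full else "sliding_window",
            "use_rope": not is_full,
        })
    return plan
```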
]]></content:encoded>
    </item>
    <item>
      <title>CS336 Lab 1 Notes</title>
      <link>https://www.zhouxin.space/notes/notes-on-cs336-lab-1/</link>
      <pubDate>Fri, 24 Oct 2025 10:08:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-cs336-lab-1/</guid>
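      <description>&lt;p&gt;These are my notes on CS336 Lab 1. The lab is a lot of work: it covers implementing a BPE tokenizer, a set of operators and a Transformer-based language model from scratch, followed by extensive tuning and ablation experiments. Finishing it gives a fine-grained understanding of how a tokenizer is implemented and a hands-on feel for each component of a Transformer.&lt;/p&gt;</description>
      <content:encoded><![CDATA[<p>These are my notes on CS336 Lab 1. The lab is a lot of work: it covers implementing a BPE tokenizer, a set of operators and a Transformer-based language model from scratch, followed by extensive tuning and ablation experiments. Finishing it gives a fine-grained understanding of how a tokenizer is implemented and a hands-on feel for each component of a Transformer.</p>
<p>My implementation is at <a href="https://github.com/LittleHeroZZZX/cs366/tree/assignment1">GitHub - LittleHeroZZZX/cs366 at assignment1</a>. Discussion and corrections are welcome.</p>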
<h2 id="tips">Tips</h2>
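<p>Some suggestions for working efficiently and avoiding pitfalls:</p>
<ul>
<li>Turn off AI completion (it significantly <del>raises</del> lowers productivity, but you learn far more of the details)</li>
<li>Use Python type annotations and set Pylance type checking to standard, so that most type and argument mismatches are caught statically</li>
<li>Before implementing BPE, first get the data structures and steps straight, e.g. corpus, pre-token, token and pre-tokenization, and only then start coding</li>
<li>Use a logger and tqdm to print progress everywhere, to get a feel for how long each component takes and where the bottleneck is</li>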
</ul>
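<h2 id="bpe-分词器">BPE Tokenizer</h2>
<p>Building the tokenizer consists of:</p>
<ul>
<li>Vocabulary initialization: build the initial vocabulary from the 256 possible byte values plus the special tokens</li>
<li>Pre-tokenize: split the corpus into pre-tokens with the given regular expression and count their frequencies</li>
<li>Compute BPE merges: given the pre-tokens, merge the most frequent pair of adjacent tokens into a new token, and repeat until the vocabulary reaches the target size</li>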
</ul>
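<p>I organize the tokenizer into two classes:</p>
<ul>
<li>Pre-Tokenizer: splits the given corpus into pre-tokens and counts the frequency of each pre-token</li>
<li>TokenizerTrainer: runs the BPE algorithm on the frequency counts and produces the merge pairs and the vocabulary</li>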
</ul>
<h3 id="pre-tokenizer">Pre-Tokenizer</h3>
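<p>The interface of the Pre-Tokenizer class is shown below. It exposes a <code>pre_tokenize</code> method that yields an iterator of pre-tokens for a given byte string, and a <code>__call__</code> method that turns a corpus file into a dictionary of pre-token frequencies.</p>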
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">PreTokenizer</span><span class="p">(</span><span class="n">ABC</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="nd">@staticmethod</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">_merge_pre_token_counts</span><span class="p">(</span><span class="o">*</span><span class="n">pre_token_counts</span><span class="p">:</span> <span class="n">PreTokenCount</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">PreTokenCount</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;&#34;&#34;Merge multiple PreTokenCount dictionaries into one.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">        Returns:
</span></span></span><span class="line"><span class="cl"><span class="s2">            PreTokenCount: The merged PreTokenCount.
</span></span></span><span class="line"><span class="cl"><span class="s2">        &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">_process_chunk</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">chunk</span><span class="p">:</span> <span class="n">Chunk</span><span class="p">,</span> <span class="n">special_tokens</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Token</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">PreTokenCount</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">        Process a single chunk of text and return the pre-token counts.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">        Args:
</span></span></span><span class="line"><span class="cl"><span class="s2">            chunk (Chunk): The chunk of text to process.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">        Returns:
</span></span></span><span class="line"><span class="cl"><span class="s2">            PreTokenCount: A dictionary-like object mapping pre-tokens to their counts.
</span></span></span><span class="line"><span class="cl"><span class="s2">        &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">pre_tokenize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">str_bytes</span><span class="p">:</span> <span class="nb">bytes</span><span class="p">,</span> <span class="n">special_token_list</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Token</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Iterator</span><span class="p">[</span><span class="n">Token</span><span class="p">]:</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">        Pre-tokenize the given bytes string.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">        Args:
</span></span></span><span class="line"><span class="cl"><span class="s2">            str_bytes (bytes): The input bytes string to pre-tokenize.
</span></span></span><span class="line"><span class="cl"><span class="s2">            special_token_list (list[Token]): The list of special tokens.
</span></span></span><span class="line"><span class="cl"><span class="s2">        Returns:
</span></span></span><span class="line"><span class="cl"><span class="s2">            Iterator[Token]: An iterator over the pre-tokens.
</span></span></span><span class="line"><span class="cl"><span class="s2">        &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="nd">@abstractmethod</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">corpos_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">split_special_token</span><span class="p">:</span> <span class="n">Token</span><span class="p">,</span> <span class="n">special_tokens</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Token</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">PreTokenCount</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">        Pre-tokenize the given corpus.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">        Args:
</span></span></span><span class="line"><span class="cl"><span class="s2">            corpos_path (str): Path to the corpus file.
</span></span></span><span class="line"><span class="cl"><span class="s2">            split_special_token (token): The special token used to split the corpus.
</span></span></span><span class="line"><span class="cl"><span class="s2">            special_tokens (list[Token]): List of special tokens.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">        Returns:
</span></span></span><span class="line"><span class="cl"><span class="s2">            PreTokenCount: A dictionary-like object mapping pre-tokens to their counts.
</span></span></span><span class="line"><span class="cl"><span class="s2">        &#34;&#34;&#34;</span>
</span></span></code></pre></td></tr></table>
</div>
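</div><p>In the implementation, to use multiple processes for throughput, <code>_process_chunk</code> counts the pre-token frequencies within one chunk. As shown below, the corpus is split at special-token boundaries into multiple chunks, each consisting of several documents separated by special tokens.</p>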
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20251212160711.png"></p>
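<p>In the multi-process implementation, each process handles one chunk at a time. To count the pre-tokens in a chunk, <code>_process_chunk</code> first splits the chunk into documents at the special tokens, then uses <code>Counter</code> with the given pre-token regular expression to count the occurrences of each pre-token in every document.</p>
<p>With <code>_process_chunk</code> in place, the multi-process version is simple: split the corpus into the desired number of chunks, dispatch them through a process pool, and finally reduce the per-chunk counts into one.</p>
<p>Reference implementation:</p>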
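<p>The per-chunk counting step could look roughly like the sketch below. It is not the code from my repository: <code>process_chunk</code> is a standalone function here, and <code>PRE_TOKEN_PATTERN</code> is a simplified ASCII-only stand-in for the pre-tokenization regular expression given in the assignment, chosen so that the example runs with the standard <code>re</code> module.</p>

```python
import re
from collections import Counter

# Simplified stand-in pattern (assumption): words, numbers and punctuation runs,
# each with an optional leading space, plus runs of whitespace.
PRE_TOKEN_PATTERN = re.compile(rb" ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+")


def process_chunk(chunk: bytes, special_tokens: list[bytes]) -> Counter:
    """Count pre-tokens in one chunk without letting any pre-token span a special token."""
    if special_tokens:
        # Longest first, so a special token that contains another one wins.
        ordered = sorted(special_tokens, key=len, reverse=True)
        split_pattern = b"|".join(re.escape(token) for token in ordered)
        documents = re.split(split_pattern, chunk)
    else:
        documents = [chunk]

    counts = Counter()
    for document in documents:
        counts.update(match.group() for match in PRE_TOKEN_PATTERN.finditer(document))
    return counts
```

<p>Because the result is a <code>Counter</code>, the reduction across chunks is just a sum of the per-chunk results.</p>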
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span><span class="lnt">50
</span><span class="lnt">51
</span><span class="lnt">52
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">MultiProcessPreTokenizer</span><span class="p">(</span><span class="n">PreTokenizer</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">_process_chunk_with_boundry</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="p">,</span> <span class="n">corpos_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">start</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">end</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">special_tokens</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Token</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">PreTokenCount</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">corpos_path</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&#34;br&#34;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">f</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="n">start</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">chunk</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">end</span> <span class="o">-</span> <span class="n">start</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">pre_token_count</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_process_chunk</span><span class="p">(</span><span class="n">chunk</span><span class="p">,</span> <span class="n">special_tokens</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">pre_token_count</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">corpos_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">split_special_token</span><span class="p">:</span> <span class="n">Token</span><span class="p">,</span> <span class="n">special_tokens</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Token</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">PreTokenCount</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">final_pre_token_count</span><span class="p">:</span> <span class="n">PreTokenCount</span> <span class="o">=</span> <span class="n">defaultdict</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="n">start_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">file_size</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">getsize</span><span class="p">(</span><span class="n">corpos_path</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">num_cpus</span> <span class="o">=</span> <span class="n">cpu_count</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="n">desired_chunks</span> <span class="o">=</span> <span class="n">num_cpus</span> <span class="o">*</span> <span class="mi">100</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="n">chunk_boundaries</span> <span class="o">=</span> <span class="n">find_chunk_boundaries</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="n">file_path</span><span class="o">=</span><span class="n">corpos_path</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">desired_num_chunks</span><span class="o">=</span><span class="n">desired_chunks</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">split_special_token</span><span class="o">=</span><span class="n">split_special_token</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="n">chunks_args</span> <span class="o">=</span> <span class="p">[]</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">chunk_boundaries</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="n">start</span> <span class="o">=</span> <span class="n">chunk_boundaries</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">            <span class="n">end</span> <span class="o">=</span> <span class="n">chunk_boundaries</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">            <span class="n">chunks_args</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">corpos_path</span><span class="p">,</span> <span class="n">start</span><span class="p">,</span> <span class="n">end</span><span class="p">,</span> <span class="n">special_tokens</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">&#34;Splitting task into </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">chunks_args</span><span class="p">)</span><span class="si">}</span><span class="s2"> chunks.&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="k">with</span> <span class="n">Pool</span><span class="p">(</span><span class="n">processes</span><span class="o">=</span><span class="n">num_cpus</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span> <span class="k">as</span> <span class="n">pool</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">chunk_iter</span> <span class="o">=</span> <span class="n">pool</span><span class="o">.</span><span class="n">imap_unordered</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_worker_wrapper</span><span class="p">,</span> <span class="n">chunks_args</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="n">chunk_result</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="o">.</span><span class="n">tqdm</span><span class="p">(</span><span class="n">chunk_iter</span><span class="p">,</span> <span class="n">total</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">chunks_args</span><span class="p">),</span> <span class="n">desc</span><span class="o">=</span><span class="s2">&#34;Pre-tokenizing&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                <span class="k">for</span> <span class="n">token</span><span class="p">,</span> <span class="n">count</span> <span class="ow">in</span> <span class="n">chunk_result</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
</span></span><span class="line"><span class="cl">                    <span class="n">final_pre_token_count</span><span class="p">[</span><span class="n">token</span><span class="p">]</span> <span class="o">+=</span> <span class="n">count</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="n">end_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="s2">&#34;Takes </span><span class="si">{:.2f}</span><span class="s2"> seconds to pre-tokenize, speed: </span><span class="si">{:.2f}</span><span class="s2"> bytes/second&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">end_time</span> <span class="o">-</span> <span class="n">start_time</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">file_size</span> <span class="o">/</span> <span class="p">(</span><span class="n">end_time</span> <span class="o">-</span> <span class="n">start_time</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">        <span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">final_pre_token_count</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="nd">@staticmethod</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">_worker_wrapper</span><span class="p">(</span><span class="n">args</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">tokenizer_instance</span> <span class="o">=</span> <span class="n">MultiProcessPreTokenizer</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">tokenizer_instance</span><span class="o">.</span><span class="n">_process_chunk_with_boundry</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h3 id="tokenizertrainer">TokenizerTrainer</h3>
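<p>TokenizerTrainer runs the BPE algorithm on the pre-token frequency counts. In short, BPE repeatedly finds the most frequent pair of adjacent tokens and merges it. To implement BPE efficiently, a few questions need answering first:</p>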
<ol>
<li>
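<p>How are the pair frequencies initialized?<br>
The Pre-Tokenizer provides the frequency of each pre-token, so we can walk over every adjacent pair inside each pre-token and accumulate that pre-token's frequency. For example, the pre-token <code>Hello </code> contributes to five pairs: <code>(H, e), (e, l), (l, l), (l, o), (o, space)</code>.</p>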
</li>
<li>
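<p>How are the pair frequencies maintained after a merge?<br>
Whenever a pair is merged, the frequencies of the affected pairs must be updated. Taking the merge of <code>el</code> in <code>Hello </code> as an example, the updates are:</p>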
</li>
</ol>
<ul>
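<li>New pairs: <code>el</code> is now a single token, so it forms the new pairs <code>(H, el)</code> and <code>(el, l)</code> with its two neighbors, and both must be added.</li>
<li>Pairs whose frequency decreases: once <code>(e, l)</code> is merged, the pair no longer exists and its frequency is set to 0. In addition, the pairs formed between the merged tokens and their neighbors must decrease accordingly: the frequencies of <code>(H, e)</code> and <code>(l, l)</code> each drop by the number of occurrences of the pre-token <code>Hello </code>.</li>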
</ul>
<ol start="3">
<li>
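<p>How do we quickly locate the pre-tokens affected by a merge?<br>
The previous question covered how to update frequencies once the affected pre-token is known. But how do we find those pre-tokens quickly? The naive approach scans every pre-token, and repeating that scan on every merge is far too slow. Instead, we maintain a mapping from each pair to the pre-tokens that contain it. The mapping is built while the pair frequencies are initialized and is updated on every merge, so the affected pre-tokens can be looked up directly during merging.</p>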
</li>
<li>
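<p>How is the current state of each pre-token recorded?<br>
The state of a pre-token is the sequence of tokens it currently consists of. For example, when BPE starts, the pre-token <code>Hello </code> consists of the six tokens <code>(H, e, l, l, o, space)</code>; later in the run it might consist of the three tokens <code>(Hel, l, o+space)</code>, where <code>o+space</code> is the single token made of <code>o</code> followed by a space. If we then merge <code>(l, o+space)</code>, updating the pair frequencies requires knowing the token that precedes <code>l</code>, not merely the character that precedes it, so we also need a dictionary that tracks the current state of every pre-token.</p>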
</li>
<li>
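<p>How do we get the most frequent pair?<br>
The naive approach scans the whole pair-frequency table every time, which costs O(n). A max-heap reduces the cost of fetching and maintaining the maximum to O(log n) per operation.</p>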
</li>
</ol>
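<p>As a rough illustration of the max-heap idea from the last question, here is a minimal sketch built on Python's <code>heapq</code>, which is a min-heap, so counts are stored negated. It assumes a "lazy deletion" strategy: instead of updating an entry in place, a fresh entry is pushed and stale ones are skipped on pop. The helper names <code>push_pair</code> and <code>pop_most_frequent</code> are hypothetical and not taken from the referenced repository, and ties here resolve toward the lexicographically smaller pair, which may differ from what the assignment requires.</p>

```python
import heapq


def push_pair(heap, pair, count):
    """Push a (pair, count) entry; counts are negated because heapq is a min-heap."""
    heapq.heappush(heap, (-count, pair))


def pop_most_frequent(heap, pair_counts):
    """Pop entries until one matches the live count stored in pair_counts."""
    while heap:
        neg_count, pair = heapq.heappop(heap)
        # An entry is stale when its recorded count differs from the table.
        if neg_count != 0 and pair_counts.get(pair, 0) == -neg_count:
            return pair
    return None


pair_counts = {("l", "l"): 3, ("e", "l"): 5}
heap = []
for pair, count in pair_counts.items():
    push_pair(heap, pair, count)

# Simulate a merge that lowered the count of ("e", "l"): update the table and
# push a fresh entry instead of searching the heap for the old one.
pair_counts[("e", "l")] = 2
push_pair(heap, ("e", "l"), 2)

best = pop_most_frequent(heap, pair_counts)
```

<p>The stale entry recorded with count 5 is discarded on pop, so <code>best</code> is the pair whose live count is currently the largest.</p>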
<p>With these questions answered, the data structures that TokenizerTrainer maintains during training follow naturally:</p>
<ul>
<li><code>pair_counts</code>: the frequency of each pair (a pair is a tuple of two tokens, <code>tuple(token, token)</code>)</li>
<li><code>pre_token_states</code>: the current state (composition) of each pre-token, i.e. how it is currently expressed as a sequence of tokens</li>
<li><code>pair_to_pretokens</code>: a dictionary mapping each pair to the list of pre-tokens that contain it</li>
<li><code>pair_heap</code>: a list used to implement a max-heap in Python</li>
</ul>
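<p>To make the initialization concrete, here is a minimal sketch that builds three of these structures from pre-token counts. It is an assumption-laden simplification: tokens are single characters of a <code>str</code> rather than bytes, <code>pair_to_pretokens</code> holds sets instead of lists, the heap is omitted, and the function name <code>init_pair_tables</code> is hypothetical.</p>

```python
from collections import defaultdict


def init_pair_tables(pre_token_counts):
    """Build pre_token_states, pair_counts and pair_to_pretokens in one pass."""
    pre_token_states = {}
    pair_counts = defaultdict(int)
    pair_to_pretokens = defaultdict(set)
    for pre_token, freq in pre_token_counts.items():
        # Initial state: one token per character (the real trainer works on bytes).
        state = tuple(pre_token)
        pre_token_states[pre_token] = state
        # Every adjacent pair inherits the frequency of its pre-token.
        for pair in zip(state, state[1:]):
            pair_counts[pair] += freq
            pair_to_pretokens[pair].add(pre_token)
    return pre_token_states, pair_counts, pair_to_pretokens


states, counts, index = init_pair_tables({"Hello ": 3, "Hell": 2})
```

<p>Here <code>(H, e)</code> appears in both pre-tokens, so its count is the sum of their frequencies, while <code>(l, o)</code> only appears in <code>Hello </code>.</p>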
<h4 id="train">train</h4>
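<p>The train method first calls <code>_init</code> to initialize the four data structures <code>pre_token_states</code>, <code>pair_counts</code>, <code>pair_to_pretokens</code>, and <code>pair_heap</code>.</p>
<p>The main loop then takes the most frequent pair as a new token, merges it, and updates the data structures above, until the vocabulary reaches the target size or no mergeable pair remains.</p>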
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">tuple</span><span class="p">[</span><span class="n">Vocab</span><span class="p">,</span> <span class="nb">list</span><span class="p">[</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">bytes</span><span class="p">,</span> <span class="nb">bytes</span><span class="p">]]]:</span>
</span></span><span class="line"><span class="cl">	<span class="bp">self</span><span class="o">.</span><span class="n">_init</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">	<span class="n">num_merges_needed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_vocab_size</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span> <span class="n">num_merges_needed</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">		<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">merges</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">	<span class="k">while</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_vocab_size</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">		<span class="n">merge_pair</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_determine_merge_pair</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">		<span class="k">if</span> <span class="n">merge_pair</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">			<span class="k">break</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">		<span class="bp">self</span><span class="o">.</span><span class="n">merges</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">merge_pair</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">		<span class="bp">self</span><span class="o">.</span><span class="n">_merge_pair</span><span class="p">(</span><span class="n">merge_pair</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">merges</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h4 id="_merge_pair">_merge_pair</h4>
<p>The merge logic is the core of TokenizerTrainer. The flowchart below shows the algorithm: first look up every pre-token that contains the pair being merged; then, for each of those pre-tokens, walk through its state (its token sequence) and, wherever the pair occurs, replace <code>(..., A, B, ...)</code> with <code>(..., AB, ...)</code>, decrease the frequencies of <code>(PrevA, A)</code> and <code>(B, NextB)</code>, and increase the frequencies of <code>(PrevA, AB)</code> and <code>(AB, NextB)</code>. Finally, set the frequency of <code>(A, B)</code> to 0.</p>
<p>Reference source: <a href="https://github.com/LittleHeroZZZX/cs366/blob/8afb4e7e1973a3da6f0cede46d8d35a0959f2982/assignment1-basics/src/tokenization/tokenizer_trainer.py#L69-L206">cs366/assignment1-basics/src/tokenization/tokenizer_trainer.py at 8afb4e7e1973a3da6f0cede46d8d35a0959f2982 · LittleHeroZZZX/cs366 · GitHub</a></p>
<p><img alt="_merge_pair flowchart" loading="lazy" src="https://pics.zhouxin.space/20251218185517.png"></p>
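<p>The per-pre-token part of the merge can be sketched as follows. This is only an illustration under simplifying assumptions: tokens are <code>str</code> values rather than bytes, the updates to <code>pair_to_pretokens</code> and the heap are left out, and the function name <code>merge_in_pre_token</code> is hypothetical rather than taken from the referenced source.</p>

```python
from collections import defaultdict


def merge_in_pre_token(state, pair, freq, pair_counts):
    """Merge every occurrence of `pair` in one pre-token and update pair counts."""
    a, b = pair
    merged = a + b
    new_state = []
    i = 0
    while i < len(state):
        if i + 1 < len(state) and state[i] == a and state[i + 1] == b:
            if new_state:  # a left neighbour exists
                prev = new_state[-1]
                pair_counts[(prev, a)] -= freq
                pair_counts[(prev, merged)] += freq
            if i + 2 < len(state):  # a right neighbour exists
                nxt = state[i + 2]
                pair_counts[(b, nxt)] -= freq
                pair_counts[(merged, nxt)] += freq
            new_state.append(merged)
            i += 2
        else:
            new_state.append(state[i])
            i += 1
    return tuple(new_state)


pair_counts = defaultdict(int, {
    ("H", "e"): 3, ("e", "l"): 3, ("l", "l"): 3, ("l", "o"): 3, ("o", " "): 3,
})
state = ("H", "e", "l", "l", "o", " ")
new_state = merge_in_pre_token(state, ("e", "l"), 3, pair_counts)
# After all affected pre-tokens are processed, the merged pair itself is zeroed.
pair_counts[("e", "l")] = 0
```

<p>Reading the left neighbour from <code>new_state</code> rather than from the old state means that back-to-back occurrences such as <code>(A, B, A, B)</code> end up counted as <code>(AB, AB)</code> instead of leaving a stale <code>(AB, A)</code> entry.</p>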
<h3 id="tokenizertrainerc">TokenizerTrainerC</h3>
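<p>The Python version above has a performance bottleneck, so I reimplemented it in C++ as TokenizerTrainerC. The core logic is identical to the Python version, and it runs more than ten times faster. See the code directly:</p>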
<ul>
<li>C++ implementation: <a href="https://github.com/LittleHeroZZZX/cs366/blob/2b4fb5a84a382aeafbd16775fcc23cba021f9831/assignment1-basics/csrc/tokenizer_trainer.cc">cs366/assignment1-basics/csrc/tokenizer_trainer.cc at 2b4fb5a84a382aeafbd16775fcc23cba021f9831 · LittleHeroZZZX/cs366 · GitHub</a></li>
<li>Python interface: <a href="https://github.com/LittleHeroZZZX/cs366/blob/8afb4e7e1973a3da6f0cede46d8d35a0959f2982/assignment1-basics/src/tokenization/tokenizer_trainer.py#L208-L245">cs366/assignment1-basics/src/tokenization/tokenizer_trainer.py at 8afb4e7e1973a3da6f0cede46d8d35a0959f2982 · LittleHeroZZZX/cs366 · GitHub</a></li>
</ul>
<h2 id="构建-transformer-语言模型组件">Building the Transformer Language Model Components</h2>
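<p>In this part, we build each Transformer component from basic Tensor operations. Most components are straightforward and can be written directly from their definitions. Implementation notes for each component:</p>
<ul>
<li>Linear: a plain matrix multiplication; note that the weight needs to be transposed</li>
<li>Embedding: make good use of advanced indexing</li>
<li>RMSNorm: follow the formula; for numerical stability, cast to float first and cast back afterwards</li>
<li>SwiGLU: follow the formula</li>
<li>RoPE: this one is harder; I suggest asking Gemini to help you understand it, and I strongly recommend my tutoring conversation: <a href="https://gemini.google.com/share/22af0e33a7b6">Gemini shared conversation</a></li>
<li>SDPA: I have recently been reworking this API in Paddle, so I know it fairly well; newcomers may need some time to work through the shape transformations involved</li>
<li>MHA: introduces the mask; again, get comfortable with the shape transformations</li>
<li>Transformer LM: assemble the building blocks; note that no softmax is needed after the LM Head</li>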
</ul>
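<p>As an example of the "cast to float, then cast back" note for RMSNorm, here is a minimal functional sketch. The function name <code>rms_norm</code> and the default <code>eps=1e-5</code> are assumptions for illustration; the referenced implementation may be organized differently, for instance as an <code>nn.Module</code>.</p>

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """RMSNorm over the last dimension, computed in float32 for stability."""
    in_dtype = x.dtype
    # Upcast first: squaring bfloat16/float16 values loses precision or overflows.
    x = x.to(torch.float32)
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = x / rms * weight.to(torch.float32)
    # Cast back so the caller sees the dtype it passed in.
    return out.to(in_dtype)


x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
y = rms_norm(x, torch.ones(8, dtype=torch.bfloat16))
```

<p>The output keeps the input's shape and dtype; only the intermediate computation runs in float32.</p>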
<p>Reference implementations:</p>
<ul>
<li>Basic components: <a href="https://github.com/LittleHeroZZZX/cs366/blob/8d800aeb4942e710ac835b1be6f89aecc0bae483/assignment1-basics/src/nn/basic.py">cs366/assignment1-basics/src/nn/basic.py at 8d800aeb4942e710ac835b1be6f89aecc0bae483 · LittleHeroZZZX/cs366 · GitHub</a></li>
<li>Forward functions: <a href="https://github.com/LittleHeroZZZX/cs366/blob/8d800aeb4942e710ac835b1be6f89aecc0bae483/assignment1-basics/src/nn/functional.py">cs366/assignment1-basics/src/nn/functional.py at 8d800aeb4942e710ac835b1be6f89aecc0bae483 · LittleHeroZZZX/cs366 · GitHub</a></li>
<li>Network structure: <a href="https://github.com/LittleHeroZZZX/cs366/blob/d0d8413fd2e8e048450d893bc648baf70cbd4258/assignment1-basics/src/nn/networks.py">cs366/assignment1-basics/src/nn/networks.py at d0d8413fd2e8e048450d893bc648baf70cbd4258 · LittleHeroZZZX/cs366 · GitHub</a></li>
</ul>
<h2 id="训练组件">Training Components</h2>
<p>This part implements the loss function, the AdamW optimizer, the lr schedule, and gradient clipping. Again, a few points to watch:</p>
<ul>
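<li>Cross-entropy loss
<ul>
<li>cast logits to float to avoid precision loss and overflow</li>
<li>cast target (label) to int64 and use gather for efficiency</li>
</ul>
</li>
<li>SGD: follow the formula</li>
<li>AdamW: follow the formula</li>
<li>lr schedule: follow the formula</li>
<li>Gradient clipping: the L2 norm in the formula is the norm of all gradients concatenated together. If clipping is triggered, the gradients of all parameters are scaled by the same factor, rather than deciding parameter by parameter whether that parameter's gradient should be clipped</li>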
</ul>
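<p>The point about the global norm in gradient clipping can be sketched like this. The function name <code>clip_grad_global_norm</code> and the small <code>eps</code> added to the denominator are assumptions for illustration, not necessarily what the referenced code does.</p>

```python
import torch


def clip_grad_global_norm(parameters, max_l2_norm: float, eps: float = 1e-6) -> float:
    """Scale all gradients in place so that their combined L2 norm is at most max_l2_norm."""
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    # One norm over every gradient, as if they were concatenated into a single vector.
    total_norm = torch.sqrt(sum(g.detach().float().pow(2).sum() for g in grads))
    if total_norm > max_l2_norm:
        scale = max_l2_norm / (total_norm + eps)
        # The same factor is applied to every parameter's gradient.
        for g in grads:
            g.mul_(scale)
    return float(total_norm)


a = torch.nn.Parameter(torch.zeros(2))
b = torch.nn.Parameter(torch.zeros(1))
a.grad = torch.tensor([3.0, 0.0])
b.grad = torch.tensor([4.0])
norm_before = clip_grad_global_norm([a, b], max_l2_norm=1.0)
```

<p>Both gradients are scaled by the same factor, so the direction of the combined gradient is preserved; clipping each parameter separately would change that direction.</p>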
<p>Reference implementations:</p>
<ul>
<li>Cross-entropy loss: <a href="https://github.com/LittleHeroZZZX/cs366/blob/8afb4e7e1973a3da6f0cede46d8d35a0959f2982/assignment1-basics/src/nn/functional.py#L28-L33">cs366/assignment1-basics/src/nn/functional.py at 8afb4e7e1973a3da6f0cede46d8d35a0959f2982 · LittleHeroZZZX/cs366 · GitHub</a></li>
<li>Optimizer: <a href="https://github.com/LittleHeroZZZX/cs366/blob/66c38c225aa1424feb0fc95bb30153391e5ae638/assignment1-basics/src/nn/optimizer.py">cs366/assignment1-basics/src/nn/optimizer.py at 66c38c225aa1424feb0fc95bb30153391e5ae638 · LittleHeroZZZX/cs366 · GitHub</a></li>
<li>lr schedule and gradient clipping: <a href="https://github.com/LittleHeroZZZX/cs366/blob/8afb4e7e1973a3da6f0cede46d8d35a0959f2982/assignment1-basics/src/nn/utils.py#L24-L46">cs366/assignment1-basics/src/nn/utils.py at 8afb4e7e1973a3da6f0cede46d8d35a0959f2982 · LittleHeroZZZX/cs366 · GitHub</a></li>
</ul>
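<p>For the lr schedule, a common form is linear warmup followed by cosine annealing down to a minimum learning rate. The sketch below assumes that form, with <code>warmup_iters</code> strictly smaller than <code>cosine_cycle_iters</code>; the function name <code>cosine_lr</code> and the <code>min_lr</code> value used in the example are hypothetical.</p>

```python
import math


def cosine_lr(it: int, max_lr: float, min_lr: float,
              warmup_iters: int, cosine_cycle_iters: int) -> float:
    """Linear warmup, then cosine decay from max_lr to min_lr, then constant min_lr."""
    if it < warmup_iters:
        # Warmup: ramp linearly from 0 up to max_lr.
        return max_lr * it / warmup_iters
    if it > cosine_cycle_iters:
        # After the cosine cycle ends, stay at the floor.
        return min_lr
    progress = (it - warmup_iters) / (cosine_cycle_iters - warmup_iters)
    return min_lr + 0.5 * (1.0 + math.cos(math.pi * progress)) * (max_lr - min_lr)


# Hypothetical values: only max_lr and the iteration counts mirror a typical config.
lrs = [cosine_lr(it, 3.0e-4, 3.0e-5, 1000, 20000) for it in (0, 500, 1000, 20000, 25000)]
```

<p>The schedule is continuous at both boundaries: it reaches <code>max_lr</code> exactly when warmup ends and <code>min_lr</code> exactly when the cosine cycle ends.</p>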
<h2 id="训练循环">Training Loop</h2>
<p>This part implements the Data Loader, Checkpoint, and the training loop.</p>
<h3 id="data-loader">Data Loader</h3>
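<p>The load batch function randomly samples <code>batch_size</code> windows of length <code>context_size</code> from the given dataset; the targets are the same windows shifted right by one token:</p>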
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">load_batch</span><span class="p">(</span><span class="n">dataset</span><span class="p">:</span> <span class="n">npt</span><span class="o">.</span><span class="n">NDArray</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">context_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">max_start_index</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">context_size</span> <span class="o">+</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">    <span class="n">start_indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">max_start_index</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">x_batch</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">dataset</span><span class="p">[</span><span class="n">i</span> <span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="n">context_size</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">start_indices</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">    <span class="n">y_batch</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">dataset</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="n">context_size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">start_indices</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">    <span class="n">x_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">x_batch</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">y_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">y_batch</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">x_tensor</span><span class="p">,</span> <span class="n">y_tensor</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h3 id="checkpoint">Checkpoint</h3>
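<p>Checkpointing can be implemented with <code>torch.save</code>. The state to save comprises the model state, the optimizer state, and the iteration count; pack them into a dictionary and let torch save it:</p>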
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">save_checkpoint</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="n">model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Optimizer</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">iteration</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="p">:</span> <span class="nb">str</span> <span class="o">|</span> <span class="n">os</span><span class="o">.</span><span class="n">PathLike</span> <span class="o">|</span> <span class="n">typing</span><span class="o">.</span><span class="n">BinaryIO</span> <span class="o">|</span> <span class="n">typing</span><span class="o">.</span><span class="n">IO</span><span class="p">[</span><span class="nb">bytes</span><span class="p">],</span>
</span></span><span class="line"><span class="cl"><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">states</span> <span class="o">=</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;model_state_dict&#34;</span><span class="p">:</span> <span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;optimizer_state_dict&#34;</span><span class="p">:</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;iteration&#34;</span><span class="p">:</span> <span class="n">iteration</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="n">out</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">load_checkpoint</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="n">checkpoint</span><span class="p">:</span> <span class="nb">str</span> <span class="o">|</span> <span class="n">os</span><span class="o">.</span><span class="n">PathLike</span> <span class="o">|</span> <span class="n">typing</span><span class="o">.</span><span class="n">BinaryIO</span> <span class="o">|</span> <span class="n">typing</span><span class="o">.</span><span class="n">IO</span><span class="p">[</span><span class="nb">bytes</span><span class="p">],</span>
</span></span><span class="line"><span class="cl">    <span class="n">model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Optimizer</span> <span class="o">|</span> <span class="kc">None</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="n">states</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">checkpoint</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">states</span><span class="p">[</span><span class="s2">&#34;model_state_dict&#34;</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="n">optimizer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">optimizer</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">states</span><span class="p">[</span><span class="s2">&#34;optimizer_state_dict&#34;</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">    <span class="n">iteration</span> <span class="o">=</span> <span class="n">states</span><span class="p">[</span><span class="s2">&#34;iteration&#34;</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">iteration</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h3 id="训练循环-1">Training Loop</h3>
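<p>The flowchart below shows the whole training process. The training configuration is parsed from the config file, and the model and optimizer are initialized. The output directory is then scanned: if a checkpoint exists, the model and optimizer states are restored from it. After that, the dataset is loaded with mmap. Each training step computes the learning rate, loads a batch, computes the logits and loss, backpropagates, clips the gradients, and finally steps the optimizer. Every fixed number of steps, the training state is saved to disk and the training metrics are logged with wandb.<br>
<img alt="Training loop flowchart" loading="lazy" src="https://pics.zhouxin.space/20251219131047.png"></p>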
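<p>The mmap loading step can be sketched with <code>np.memmap</code>. The file name, the <code>uint16</code> token dtype (enough for a 10000-entry vocabulary), and the fake data written here are assumptions made to keep the example self-contained; the real training script may store tokens differently.</p>

```python
import os
import tempfile

import numpy as np

# Write a tiny fake token file so the sketch runs on its own.
path = os.path.join(tempfile.mkdtemp(), "tokens.bin")
np.arange(1000, dtype=np.uint16).tofile(path)

# mode="r" maps the file read-only; data is paged in from disk on access,
# so a dataset larger than RAM can still be sampled.
dataset = np.memmap(path, dtype=np.uint16, mode="r")

context_size = 8
start = 100
# Slicing a memmap only touches the pages that back the requested window.
x = np.asarray(dataset[start : start + context_size])
y = np.asarray(dataset[start + 1 : start + context_size + 1])
```

<p>Because a memmap behaves like an ordinary NumPy array, it can be passed to a batch-sampling function that expects an <code>NDArray</code> without further changes.</p>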
<ul>
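<li>Config file<br>
A YAML config file is a convenient way to record and manage the training configuration, which helps when running many comparison and ablation experiments later. A reference config:</li>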
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span><span class="lnt">50
</span><span class="lnt">51
</span><span class="lnt">52
</span><span class="lnt">53
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-yaml" data-lang="yaml"><span class="line"><span class="cl"><span class="nt">model</span><span class="p">:</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">vocab_size</span><span class="p">:</span><span class="w"> </span><span class="m">10000</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># Context length (maximum sequence length)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">context_length</span><span class="p">:</span><span class="w"> </span><span class="m">256</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># Number of Transformer layers</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">num_layers</span><span class="p">:</span><span class="w"> </span><span class="m">4</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># Hidden dimension (model size)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">hidden_dim</span><span class="p">:</span><span class="w"> </span><span class="m">512</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># Inner dimension of the FFN (here about 8/3 * hidden_dim, rounded to a multiple of 64)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">inner_dim</span><span class="p">:</span><span class="w"> </span><span class="m">1344</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># Number of attention heads</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">num_heads</span><span class="p">:</span><span class="w"> </span><span class="m">16</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># RoPE (Rotary Position Embedding) base (typically 10000)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">theta</span><span class="p">:</span><span class="w"> </span><span class="m">10000.0</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># Device for training: &#39;cuda&#39;, &#39;mps&#39;, or &#39;cpu&#39;</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">device</span><span class="p">:</span><span class="w"> </span><span class="s2">&#34;cuda&#34;</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># Data type of the model parameters: &#39;float32&#39; or &#39;bfloat16&#39;</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">dtype</span><span class="p">:</span><span class="w"> </span><span class="s2">&#34;bfloat16&#34;</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="nt">optimizer</span><span class="p">:</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># Learning rate</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">learning_rate</span><span class="p">:</span><span class="w"> </span><span class="m">3.0e-4</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># Weight decay for AdamW</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">weight_decay</span><span class="p">:</span><span class="w"> </span><span class="m">1e-2</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># AdamW first-moment decay rate (Beta1)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">beta1</span><span class="p">:</span><span class="w"> </span><span class="m">0.9</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># AdamW second-moment decay rate (Beta2)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">beta2</span><span class="p">:</span><span class="w"> </span><span class="m">0.999</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># AdamW numerical-stability term (Epsilon)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">eps</span><span class="p">:</span><span class="w"> </span><span class="m">1.0e-8</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="nt">training</span><span class="p">:</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># 批次大小 (Batch size)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">batch_size</span><span class="p">:</span><span class="w"> </span><span class="m">96</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># 总训练迭代次数 (Total training iterations)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">total_iterations</span><span class="p">:</span><span class="w"> </span><span class="m">20000</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># 学习率预热的迭代次数 (Warmup iterations)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">warmup_iterations</span><span class="p">:</span><span class="w"> </span><span class="m">1000</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># 学习率余弦退火周期迭代次数 (Cosine cycle iterations)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">cosine_cycle_iterations</span><span class="p">:</span><span class="w"> </span><span class="m">20000</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># 梯度裁剪的最大 L2 范数 (Max L2 norm for gradient clipping)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">max_l2_norm</span><span class="p">:</span><span class="w"> </span><span class="m">1.0</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># 模型保存的间隔（迭代次数） (Checkpoint saving interval)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">checkpoint_interval</span><span class="p">:</span><span class="w"> </span><span class="m">1000</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># 模型和配置的输出目录 (Output directory)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">output_dir</span><span class="p">:</span><span class="w"> </span><span class="s2">&#34;./output_full/&#34;</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># 训练数据文件路径 (Path to training data .npy file)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">train_data</span><span class="p">:</span><span class="w"> </span><span class="s2">&#34;save/data/TS-train.bin&#34;</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># 验证数据文件路径 (Path to validation data .npy file)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">val_data</span><span class="p">:</span><span class="w"> </span><span class="s2">&#34;save/data/owt_train.bin&#34;</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="c"># Checkpoint-saving step (used in the `train` function to decide whether to save a checkpoint)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">save_step</span><span class="p">:</span><span class="w"> </span><span class="m">500</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">log_step</span><span class="p">:</span><span class="w"> </span><span class="m">50</span><span class="w">
</span></span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>WandB<br>
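Using WandB to log experiment configurations and metrics automatically is recommended: it removes the need to collect results by hand and speeds up hyperparameter searches and ablation studies.</li>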
</ul>
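<p>A reference implementation of the train function: <a href="https://github.com/LittleHeroZZZX/cs366/blob/d0d8413fd2e8e048450d893bc648baf70cbd4258/assignment1-basics/src/nn/train.py">cs366/assignment1-basics/src/nn/train.py at d0d8413fd2e8e048450d893bc648baf70cbd4258 · LittleHeroZZZX/cs366 · GitHub</a></p>
<h2 id="生成文本">Generating Text</h2>
<p>The text-generation flow is shown in the diagram below. The model weights and the tokenizer are loaded from a given directory. The prompt is converted to token ids and fed to the model, and the prediction at the last token position is taken. Temperature and softmax are applied, the candidates are sorted by probability in descending order, and the shortest prefix whose cumulative probability is at least p is kept. One token is then sampled from that prefix in proportion to its probability. For the next step, the sampled token is appended to the input sequence, and the process repeats until the model emits endoftext or the user-specified token limit is reached.<br>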
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20251219132427.png"><br>
A reference implementation of decode: <a href="https://github.com/LittleHeroZZZX/cs366/blob/d78dc879e554f47094824deadcdc2801551d0158/assignment1-basics/src/nn/decode.py">cs366/assignment1-basics/src/nn/decode.py at d78dc879e554f47094824deadcdc2801551d0158 · LittleHeroZZZX/cs366 · GitHub</a></p>
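<p>The sampling step described above can be sketched in plain Python as follows. This is a minimal illustration, not the referenced decode implementation: <code>sample_top_p</code>, <code>generate</code>, and the <code>next_logits</code> callback standing in for the model are hypothetical names, and the code works on Python lists rather than tensors.</p>

```python
import math
import random


def sample_top_p(logits, temperature=1.0, top_p=0.9, rng=random):
    """Sample one token id using temperature scaling and top-p filtering."""
    # Temperature-scaled softmax (subtract the max for numerical stability).
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Sort token ids by probability, highest first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Keep the shortest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Sample among the kept tokens in proportion to their probabilities.
    weights = [probs[i] for i in kept]
    return rng.choices(kept, weights=weights, k=1)[0]


def generate(next_logits, prompt_ids, eos_id, max_new_tokens, **sampling):
    """Autoregressive loop; next_logits(ids) is a hypothetical stand-in for the model."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        token = sample_top_p(next_logits(ids), **sampling)
        ids.append(token)
        if token == eos_id:
            break
    return ids
```

<p>A lower <code>top_p</code> or temperature concentrates sampling on the most likely tokens; setting <code>top_p</code> to 1.0 keeps the whole distribution.</p>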
]]></content:encoded>
    </item>
    <item>
      <title>CS336 Study Notes, Lecture 2: PyTorch and Resource Accounting</title>
      <link>https://www.zhouxin.space/notes/notes-on-cs336-lecture-1-pytorch-and-resource-accounting/</link>
      <pubDate>Sun, 19 Oct 2025 17:51:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-cs336-lecture-1-pytorch-and-resource-accounting/</guid>
      <description>&lt;blockquote&gt;
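&lt;p&gt;This is the second lecture in my CS336 notes series: PyTorch and resource accounting. The lecture is fairly light. It starts from two questions, works through how to count compute, and introduces some basic PyTorch concepts along the way. The emphasis is not on PyTorch basics but on building a &#34;resource accounting&#34; mindset.&lt;/p&gt;</description>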
      <content:encoded><![CDATA[<blockquote>
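<p>This is the second lecture in my CS336 notes series: PyTorch and resource accounting. The lecture is fairly light. It starts from two questions, works through how to count compute, and introduces some basic PyTorch concepts along the way. The emphasis is not on PyTorch basics but on building a "resource accounting" mindset.</p>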
</blockquote>
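<p>The goal of this lecture is to answer two questions:</p>
<ul>
<li>How long does it take to train a 70B model on a 15T-token dataset with 1024 H100s?</li>
<li>What is the largest model that can be trained on 8 H100s with the AdamW optimizer?</li>
</ul>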
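<h2 id="内存计算">Memory Accounting</h2>
<h3 id="tensor-入门">Tensor Basics</h3>
<p>Everything in deep learning is stored as tensors: parameters, gradients, optimizer state, data, and activations.</p>
<p>That concludes the basics. (runs away)</p>
<h3 id="tensor-内存">Tensor Memory</h3>
<p>Almost all tensors are stored as floating-point numbers, and the default format is usually FP32.</p>
<p>The memory a tensor occupies is determined by its data type and its number of elements.</p>
<p>To fit more parameters, the industry has also introduced the FP16, BF16, and FP8 formats.</p>
<p>In summary:</p>
<ul>
<li>FP32 is more than sufficient for training, but it needs a lot of memory</li>
<li>Training in FP8 / FP16 / BF16 is risky and can be unstable</li>
<li>Solution: mixed-precision training</li>
</ul>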
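<p>The rule that memory equals element count times bytes per element can be checked with a few lines of Python. The helper name <code>tensor_bytes</code> and the example shape are assumptions made for illustration; the byte widths are the standard sizes of these formats.</p>

```python
# Bytes per element for the floating-point formats mentioned above.
BYTES_PER_ELEMENT = {"fp32": 4, "fp16": 2, "bf16": 2, "fp8": 1}


def tensor_bytes(shape, dtype="fp32"):
    """Memory footprint of a dense tensor: number of elements times bytes per element."""
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * BYTES_PER_ELEMENT[dtype]


# A 4096 x 4096 weight matrix in each format (64 MiB in FP32).
for dtype in BYTES_PER_ELEMENT:
    print(dtype, tensor_bytes((4096, 4096), dtype) / 2**20, "MiB")
```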
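<h3 id="tensor-操作">Tensor Operations</h3>
<p>The course covers some basic tensor operations, which I won&rsquo;t repeat here. The stride mechanism is worth a mention; look it up if it is unfamiliar.</p>
<h3 id="tensor-操作-flops-计算">Counting FLOPs for Tensor Operations</h3>
<p>One floating-point operation (FLOP) is a single floating-point addition or multiplication. Two terms need to be distinguished:</p>
<ul>
<li>FLOPs: floating-point operations, the amount of computation, i.e. how many floating-point operations are needed</li>
<li>FLOP/s: FLOPs per second, which measures the compute speed of hardware</li>
</ul>
<p>Some numbers for intuition:</p>
<ul>
<li>Training GPT-3 (2020) took 3.14e23 FLOPs</li>
<li>Training GPT-4 (2023) is speculated to have taken 2e25 FLOPs</li>
<li>A100 peak throughput: 3.12e14 FLOP/s</li>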
<li>H100 peak throughput: 1.979e15 FLOP/s with sparsity, about half of that for dense matrices</li>
</ul>
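<p>Suppose the input has shape <code>[B, D]</code> and a linear layer maps <code>D</code> to <code>K</code>, so the weight matrix has shape <code>[D, K]</code>. Each element of the output is the dot product of two length-<code>D</code> vectors, which costs D multiplications and about D additions, i.e. 2D FLOPs. There are <code>B x K</code> output elements, so one matrix multiplication costs <code>2 x B x D x K</code> FLOPs.</p>
<p>Model FLOPs utilization is defined as <code>mfu = actual model FLOP/s / theoretical peak hardware FLOP/s</code>. In practice, an MFU above 0.5 is already quite good.</p>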
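<p>The two formulas above translate directly into code. This is only a sketch: the function names are made up, and the elapsed time is a hypothetical measurement rather than a real benchmark result.</p>

```python
def matmul_flops(b, d, k):
    """FLOPs for [b, d] @ [d, k]: each of the b*k outputs needs d multiplies and d adds."""
    return 2 * b * d * k


def mfu(achieved_flop_per_s, peak_flop_per_s):
    """Model FLOPs utilization: achieved throughput over the hardware's theoretical peak."""
    return achieved_flop_per_s / peak_flop_per_s


flops = matmul_flops(b=1024, d=4096, k=4096)
elapsed_s = 0.0005   # hypothetical wall-clock time for one matmul
a100_peak = 3.12e14  # A100 peak FLOP/s quoted above
print(f"{flops:.3e} FLOPs, MFU = {mfu(flops / elapsed_s, a100_peak):.2f}")
```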
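<h2 id="模型">Models</h2>
<h3 id="参数初始化">Parameter Initialization</h3>
<p>In short, parameter initialization is a subject in its own right: a poor choice leads to exploding or vanishing gradients. This course does not go into detail. CMU 10-414 derives it mathematically; see my notes from that course: <a href="https://www.zhouxin.space/notes/notes-on-cmu-10-414-deep-learning-system/#%E5%88%9D%E5%A7%8B%E5%8C%96">《CMU 10-414 deep learning system》学习笔记 | 周鑫的个人博客</a>.</p>
<h3 id="其它内容">Other Topics</h3>
<p>This part also covers how to train a model, optimizers, checkpointing, and so on. It is all quite basic, so I won&rsquo;t go over it here.</p>
<h3 id="混合精度训练">Mixed-Precision Training</h3>
<p>Low precision reduces memory use and speeds up computation, but it can make training unstable. To get the benefit without the instability, the strategy is to default to FP32 and use BF16 or FP8 wherever possible: concretely, low precision in the forward pass and high precision in the backward pass.</p>
<h2 id="小结">Summary</h2>
<p>As an introductory PyTorch lecture, this one covers very basic concepts. The PyTorch usage and the training procedure are only a lead-in; the real point is to build the habit of thinking in terms of resource accounting. In the era of large models, efficiency is money.</p>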
]]></content:encoded>
    </item>
    <item>
      <title>CS336 Study Notes, Lecture 1: Overview and Tokenization</title>
      <link>https://www.zhouxin.space/notes/notes-on-cs336-lecture-1-overview-and-tokenization/</link>
      <pubDate>Sun, 19 Oct 2025 12:17:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-cs336-lecture-1-overview-and-tokenization/</guid>
      <description>&lt;blockquote&gt;
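&lt;p&gt;These are my notes on the first lecture of CS336. They cover the background and motivation for the course, give an overview of its main content, and introduce the basic principles, strengths, and weaknesses of several tokenizers.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2 id=&#34;引入&#34;&gt;Introduction&lt;/h2&gt;
&lt;h3 id=&#34;为什么要学习这门课程&#34;&gt;Why Take This Course&lt;/h3&gt;
&lt;p&gt;The observation: researchers are drifting further and further from the underlying technology. Eight years ago they implemented and trained their own models; six years ago they downloaded a model and fine-tuned it; today they only edit a model&#39;s prompt.&lt;/p&gt;</description>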
      <content:encoded><![CDATA[<blockquote>
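<p>These are my notes on the first lecture of CS336. They cover the background and motivation for the course, give an overview of its main content, and introduce the basic principles, strengths, and weaknesses of several tokenizers.</p>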
</blockquote>
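<h2 id="引入">Introduction</h2>
<h3 id="为什么要学习这门课程">Why Take This Course</h3>
<p>The observation: researchers are drifting further and further from the underlying technology. Eight years ago they implemented and trained their own models; six years ago they downloaded a model and fine-tuned it; today they only edit a model&rsquo;s prompt.</p>
<p>Of course, this is not a bad thing: ever-higher levels of abstraction are what let researchers work more efficiently. But by the <a href="https://zh.wikipedia.org/wiki/%E6%8A%BD%E8%B1%A1%E6%BC%8F%E6%B4%9E%E5%AE%9A%E5%BE%8B">law of leaky abstractions</a>, no abstraction can hide the underlying details perfectly. And a great deal of foundational work remains to be done, which requires deep insight into the lower levels.</p>
<p>This course builds a language model from scratch so that you come to understand the techniques at every level of the stack.</p>
<h3 id="制约因素">Constraints</h3>
<p>There are two major constraints on studying modern large language models:</p>
<ol>
<li>In the era of industrialized large models, training one is extremely expensive in hardware, time, and money.</li>
<li>The most advanced models (the GPT series) publish almost no detailed technical reports, so learning material is scarce.</li>
</ol>
<h3 id="小模型不具备代表性">Small Models Are Not Representative</h3>
<p>Training small models (fewer than 1B parameters) can work around these constraints, but lessons from small models do not carry over fully to large ones.</p>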
<ul>
<li>
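<p>Case 1: as scale grows, the ratio of compute between the attention layers and the feed-forward layers changes</p>
<ul>
<li>At small scale the FFN is not dominant</li>
<li>At large scale the FFN dominates<br>
<img alt="Compute share of attention and feed-forward layers at different scales" loading="lazy" src="https://pics.zhouxin.space/202510191259457.webp"></li>
</ul>
</li>
<li>
<p>Case 2: model capability improves markedly as the parameter count grows<br>
<img alt="Model capability scores at different parameter counts" loading="lazy" src="https://pics.zhouxin.space/202510191302227.webp"></p>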
</li>
</ul>
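<h3 id="从这门课能学到什么">What Can You Learn from This Course?</h3>
<p>Three kinds of knowledge:</p>
<ul>
<li>Mechanics: how each component of a large model works (for example, what a Transformer is, and how to build models that run in parallel on GPUs)</li>
<li>Mindset: squeezing the most out of the hardware and taking scale seriously (for example, scaling laws)</li>
<li>Intuitions: judging from experience which data and modeling decisions lead to good accuracy.</li>
</ul>
<h3 id="惨痛的教训">The Bitter Lesson</h3>
<p>A common misreading: algorithms do not matter, only scale does.<br>
The right reading: algorithms that scale are what matter most.</p>
<p>Accuracy = efficiency x resources. At large scale, algorithmic efficiency matters all the more, because trial and error becomes unaffordable. <a href="https://arxiv.org/abs/2005.04305">[2005.04305] Measuring the Algorithmic Efficiency of Neural Networks</a> shows that on ImageNet, algorithmic improvements delivered a 44x efficiency gain, far more than the 11x hardware gain attributable to Moore&rsquo;s law.</p>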
<h2 id="研究现状">Current Landscape</h2>
<ul>
<li>Foundational components
<ul>
<li>Sequence-to-sequence modeling (2014)</li>
<li>Adam optimizer (2014)</li>
<li>Attention mechanism (2014)</li>
<li>Transformer architecture (2017)</li>
<li>MoE (2017)</li>
<li>Model parallelism (2018-2019)</li>
</ul>
</li>
<li>Early foundation models
<ul>
<li>ELMo: pretraining with LSTMs, fine-tuning on downstream tasks (2018)</li>
<li>BERT: pretraining with a Transformer, fine-tuning on downstream tasks (2018)</li>
<li>Google&rsquo;s T5: casts every task as a text-to-text task (2019)</li>
</ul>
</li>
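<li>Scaling up (closed models)
<ul>
<li>OpenAI&rsquo;s GPT-2 (1.5B): generates fluent text, shows early zero-shot ability, and was released in stages (2019)</li>
<li>OpenAI&rsquo;s GPT-3 (175B): shows strong in-context learning, but the model is closed (2020)</li>
<li>Google&rsquo;s PaLM (540B): trained at even larger scale, but later judged to be undertrained (2022)</li>
<li>DeepMind&rsquo;s Chinchilla (70B): proposes compute-optimal scaling laws, arguing that for the same compute a smaller model should be trained on more data (2022)</li>
</ul>
</li>
<li>Open models
<ul>
<li>EleutherAI (The Pile &amp; GPT-J): released the large open dataset The Pile and the open model GPT-J, advancing the open ecosystem (2020-2021)</li>
<li>Meta&rsquo;s OPT (175B): an attempt to replicate GPT-3 that ran into many hardware problems (2022)</li>
<li>Hugging Face / BigScience&rsquo;s BLOOM: a large multilingual open model; the project focused on data sourcing and governance (2022)</li>
<li>Meta&rsquo;s Llama series: several generations of Llama models that had a huge impact on the open community and became the basis for many later models (2023-2024)</li>
<li>Alibaba&rsquo;s Qwen (通义千问) series: a family of capable open models from Alibaba (2024)</li>
<li>DeepSeek&rsquo;s models (深度求索): a series of high-performance open models released by DeepSeek (2024)</li>
<li>AI2&rsquo;s OLMo 2: a fully open model from the Allen Institute for AI (AI2), including training data and code (2024)</li>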
</ul>
</li>
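<li>Degrees of openness
<ul>
<li>API access only; the model and weights are not released</li>
<li>Model, weights, and technical report are released, but not the dataset</li>
<li>Model, weights, and dataset are all released</li>
</ul>
</li>
<li>Frontier models today (2025)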
<ul>
<li>OpenAI o3</li>
<li>Anthropic Claude Sonnet 3.7</li>
<li>xAI Grok 3</li>
<li>Google Gemini 2.5</li>
<li>Meta Llama 3.3</li>
<li>DeepSeek r1</li>
<li>Alibaba Qwen 2.5</li>
<li>Tencent Hunyuan-T1</li>
</ul>
</li>
</ul>
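<h2 id="课程总览">Course Overview</h2>
<h3 id="基础">Basics</h3>
<p>Goal: get a rough picture of the full model pipeline.</p>
<p>This unit covers tokenizers, model architecture, and training. In the matching Assignment 1 we implement a BPE tokenizer, a Transformer, cross-entropy, the AdamW optimizer, and a training loop, and train on a given dataset.</p>
<h3 id="系统">Systems</h3>
<p>Goal: squeeze every bit of performance out of the hardware.</p>
<p>This unit covers kernels, parallelism, and inference. In the matching Assignment 2 we implement a fused RMSNorm, distributed parallel training, and optimizer state sharding, then benchmark and profile the implementation.</p>
<h3 id="缩放定律">Scaling Laws</h3>
<p>Goal: run experiments at small scale and use them to predict hyperparameters and loss at large scale.</p>
<p>Question: given a fixed compute (FLOPs) budget, should we use a larger model or train on more tokens?</p>
<p>This unit covers compute-optimal scaling laws. In the matching Assignment 3 we define an API that predicts the loss of a model under specified hyperparameters from the results of earlier runs, fit a curve to the loss at different compute budgets, and give the hyperparameter configuration that minimizes loss under a specified budget.</p>
<h3 id="数据">Data</h3>
<p>Question: what capabilities do we want the model to have? Multimodality, code, math? The answer determines what data we need.</p>
<p>In the matching Assignment 4 we convert crawled HTML files to text, train classifiers for text quality and harmful content, use MinHash to detect duplicate data, and minimize perplexity under a specified budget.</p>
<h3 id="对齐">Alignment</h3>
<p>At this point we have a base model that is good at predicting the next token. It still needs alignment to perform well on specific tasks.</p>
<p>In the matching Assignment 5 we implement SFT, DPO, and GRPO.</p>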
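<h2 id="分词器">Tokenizers</h2>
<h3 id="引入-1">Introduction</h3>
<p>Raw text is a string of Unicode characters, whereas a language model outputs a probability distribution over a vocabulary of tokens.</p>
<p>So we need a tokenizer that encodes strings into tokens and decodes tokens back into strings. The vocabulary size is the number of distinct possible tokens.</p>
<h3 id="字符分词器-character-tokenizer">Character Tokenizer</h3>
<p>The simplest and most intuitive tokenizer is the character tokenizer. Every character has a code point in Unicode, so a string can be encoded into tokens by a table lookup.</p>
<p>The problems with this scheme:</p>
<ul>
<li>The vocabulary is very large</li>
<li>Many characters are almost never used, so the vocabulary is used inefficiently</li>
</ul>
<h3 id="字节分词器-byte-tokenizer">Byte Tokenizer</h3>
<p>Unicode text can be represented as a stream of bytes. In UTF-8, each character is encoded as 1 to 4 bytes. Under this scheme every string becomes a sequence of tokens whose values lie in 0 to 255, so the vocabulary has only 256 entries.</p>
<p>The problems with this scheme:</p>
<ul>
<li>Token sequences are too long</li>
<li>Each token carries only 1 byte of information</li>
</ul>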
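<p>A byte tokenizer needs nothing beyond Python&rsquo;s built-in UTF-8 codec. The function names below are illustrative only.</p>

```python
def byte_encode(text):
    """Encode a string as a list of token ids in 0..255 (its UTF-8 bytes)."""
    return list(text.encode("utf-8"))


def byte_decode(ids):
    """Invert byte_encode."""
    return bytes(ids).decode("utf-8")


ids = byte_encode("Hi, 你好")
print(ids, len(ids))  # 6 characters become 10 tokens: each CJK character takes 3 bytes
```

<p>The example shows the sequence-length problem directly: text outside ASCII expands to several tokens per character.</p>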
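<h3 id="按词分类器-word-tokenizer">Word Tokenizer</h3>
<p>The idea is to split the string into words, for example <code>&quot;Hello world&quot;</code> becomes <code>[&quot;Hello&quot;, &quot; &quot;, &quot;world&quot;]</code>, and then map each word in the resulting set to an integer.</p>
<p>The problems with this scheme:</p>
<ul>
<li>The vocabulary is finite, so words never seen during training cannot be encoded</li>
<li>The vocabulary can also be very large</li>
</ul>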
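<p>A minimal word tokenizer makes the unseen-word problem concrete. The regular expression and the <code>[UNK]</code> fallback are assumptions chosen for this sketch; real tokenizers use more elaborate splitting rules.</p>

```python
import re


def word_split(text):
    """Split text into runs of word characters and single non-word characters."""
    return re.findall(r"\w+|\W", text)


class WordTokenizer:
    def __init__(self, corpus):
        # Build the vocabulary from the training corpus; id 0 is reserved for unknown words.
        self.vocab = {"[UNK]": 0}
        for piece in word_split(corpus):
            self.vocab.setdefault(piece, len(self.vocab))

    def encode(self, text):
        return [self.vocab.get(piece, 0) for piece in word_split(text)]


tok = WordTokenizer("Hello world")
print(word_split("Hello world"))  # ['Hello', ' ', 'world']
print(tok.encode("Hello there"))  # [1, 2, 0]: 'there' was never seen, so it maps to 0
```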
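<h3 id="字节对编码分类器-byte-pair-encoding-bpe">Byte Pair Encoding (BPE) Tokenizer</h3>
<p>BPE was proposed in 1994 for data compression, applied to machine translation in NLP in 2015, and adopted by GPT-2 in 2019.</p>
<p>The basic motivation is to represent common sequences with a single token and rare sequences with several. GPT-2 first uses a word-level tokenizer to split the raw text into segments and then runs BPE within each segment.</p>
<p>The BPE algorithm: treat the byte stream as the initial tokens, then repeatedly merge the most frequent pair of adjacent tokens.</p>
<p>For example:</p>
<ul>
<li>Initial byte stream: <code>[22, 134, 245, 22, 134, 22, 245]</code></li>
<li>Count adjacent token pairs; the most frequent is <code>(22, 134)</code>, which occurs twice</li>
<li>Merge the most frequent pair <code>(22, 134)</code> into a new token 256 and substitute it, giving <code>[256, 245, 256, 22, 245]</code></li>
<li>Repeat until the specified vocabulary size is reached or no adjacent pair occurs more than once</li>
</ul>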
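<p>The merge loop above can be sketched as follows and run on the same byte stream. This is a simplified illustration (no pre-tokenization, ties broken by first occurrence), and <code>merge</code> and <code>train_bpe</code> are names chosen here, not taken from the assignment.</p>

```python
from collections import Counter


def merge(ids, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` in `ids` with `new_id`."""
    out, i = [], 0
    while len(ids) > i:
        if len(ids) > i + 1 and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out


def train_bpe(ids, vocab_size):
    """Greedy BPE training on a byte sequence; returns the merged ids and the merge table."""
    ids = list(ids)
    merges = {}
    next_id = 256  # ids 0..255 are the raw bytes
    while vocab_size > next_id:
        counts = Counter(zip(ids, ids[1:]))
        if not counts:
            break
        pair, freq = counts.most_common(1)[0]
        if freq == 1:
            break  # no adjacent pair repeats any more
        ids = merge(ids, pair, next_id)
        merges[pair] = next_id
        next_id += 1
    return ids, merges


ids, merges = train_bpe([22, 134, 245, 22, 134, 22, 245], vocab_size=300)
print(ids)     # [256, 245, 256, 22, 245]
print(merges)  # {(22, 134): 256}
```

<p>Training stops after one merge here because no adjacent pair occurs more than once in the merged sequence.</p>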
]]></content:encoded>
    </item>
    <item>
      <title>About</title>
      <link>https://www.zhouxin.space/about-me/</link>
      <pubDate>Sun, 08 Jun 2025 13:46:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/about-me/</guid>
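      <description>&lt;h2 id=&#34;-你好我是周鑫&#34;&gt;👋 Hello, I&#39;m 周鑫!&lt;/h2&gt;
&lt;p&gt;Welcome to my personal tech blog! I&#39;m glad to share what I learn, think about, and build here.&lt;/p&gt;
&lt;hr&gt;
&lt;h3 id=&#34;-我是谁&#34;&gt;🧑‍💻 Who Am I?&lt;/h3&gt;
&lt;p&gt;I&#39;m an INTJ who enjoys exploring cutting-edge technology and wants to create real value through code.&lt;/p&gt;</description>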
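      <content:encoded><![CDATA[<h2 id="-你好我是周鑫">👋 Hello, I&rsquo;m 周鑫!</h2>
<p>Welcome to my personal tech blog! I&rsquo;m glad to share what I learn, think about, and build here.</p>
<hr>
<h3 id="-我是谁">🧑‍💻 Who Am I?</h3>
<p>I&rsquo;m an INTJ who enjoys exploring cutting-edge technology and wants to create real value through code.</p>
<p>I&rsquo;m currently a first-year graduate student at the School of Software Engineering, University of Science and Technology of China, and a remote intern at Baidu PaddlePaddle through the <a href="https://github.com/PaddlePaddle/Paddle/issues/71313">8th PaddlePaddle Hackathon</a>.</p>
<hr>
<h3 id="-这个博客是关于什么的">🚀 What Is This Blog About?</h3>
<p>On this blog you will find content on the following topics:</p>
<ul>
<li><strong>[AI framework internals and practice]</strong>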
<ul>
<li><a href="https://www.zhouxin.space/notes/baidu-paddlepaddle-starter-plan-summary/">百度飞桨「启航计划」小结——CINN后端Pass改造 | 周鑫的个人博客</a></li>
<li><a href="https://www.zhouxin.space/notes/note-on-programming-massively-parallel-processors-a-hands-on-approach-4th-edition-part-1/">Programming Massively Parallel Processors A Hands-on Approach 4th Edition 学习笔记 Part 1 | 周鑫的个人博客</a></li>
<li><a href="https://www.zhouxin.space/notes/note-on-programming-massively-parallel-processors-a-hands-on-approach-4th-edition-part-2/">Programming Massively Parallel Processors A Hands-on Approach 4th Edition 学习笔记 Part 2 | 周鑫的个人博客</a></li>
</ul>
</li>
<li><strong>[C++ study notes]</strong>
<ul>
<li><a href="https://www.zhouxin.space/notes/notes-on-effective-cpp-3rd-ed/">Effective Cpp 第三版学习笔记 | 周鑫的个人博客</a></li>
<li><a href="https://www.zhouxin.space/notes/%E5%A6%82%E4%BD%95%E5%9C%A8vscode%E4%B8%AD%E4%BC%98%E9%9B%85%E5%9C%B0%E9%85%8D%E7%BD%AEcmake--%E4%BB%A5paddlepaddle%E4%B8%BA%E4%BE%8B/">如何在VSCode中“优雅”地配置CMake —— 以PaddlePaddle为例 | 周鑫的个人博客</a></li>
<li><a href="https://www.zhouxin.space/notes/joint-debgugging-of-cuda-and-python-in-vscode/">在VSCode中对CUDA和Python代码进行联合调试 | 周鑫的个人博客</a></li>
</ul>
</li>
<li><strong>[Notes on open courses from universities abroad]</strong>
<ul>
<li><a href="https://www.zhouxin.space/notes/notes-on-cmu-10-414-deep-learning-system/">《CMU 10-414 deep learning system》学习笔记 | 周鑫的个人博客</a></li>
<li><a href="https://www.zhouxin.space/notes/cs144-winter-2024-labs/">CS144 Lab 实验笔记 | 周鑫的个人博客</a></li>
<li><a href="https://www.zhouxin.space/tags/efficientml/">EfficientML | 周鑫的个人博客</a></li>
</ul>
</li>
<li><strong>[Solving real-world problems]</strong>
<ul>
<li><a href="https://www.zhouxin.space/notes/setup-zerotier-moon-server/">搭建ZeroTier MOON服务器 | 周鑫的个人博客</a></li>
<li><a href="https://www.zhouxin.space/notes/install-and-switch-to-specific-version-of-gcc/">安装并切换指定gcc或者g++版本 | 周鑫的个人博客</a></li>
<li><a href="https://www.zhouxin.space/logs/blog-setup-logs/">博客搭建日志 | 周鑫的个人博客</a></li>
<li><a href="https://www.zhouxin.space/notes/using-katex-to-render-math-in-hugo/">在Hugo中使用KATEX渲染数学公式 | 周鑫的个人博客</a></li>
<li><a href="https://www.zhouxin.space/logs/introduce-side-toc-and-reading-percentage-to-papermod/">在PaperMod中引入侧边目录和阅读进度显示 | 周鑫的个人博客</a></li>
</ul>
</li>
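<li>And any other technical tidbits I find interesting or worth sharing.</li>
</ul>
<p>I started this blog to record my own learning and to connect with more people who share an interest in technology. I hope my writing helps others who run into similar problems while studying or working, and I look forward to improving together.</p>
<hr>
<p>Thanks for reading! I hope you find something useful here.</p>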
]]></content:encoded>
    </item>
    <item>
      <title>MIT 6.5940 EfficientML Lab 2 Notes</title>
      <link>https://www.zhouxin.space/notes/notes-on-mit-efficientml-lab-2/</link>
      <pubDate>Tue, 25 Feb 2025 15:35:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-mit-efficientml-lab-2/</guid>
      <description>&lt;blockquote&gt;
&lt;p&gt;These are my notes on EfficientML Lab 2, which covers K-means quantization, K-means QAT, and linear quantization. It is not difficult but covers a lot of ground.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h1 id=&#34;part-1-k-means-quantization&#34;&gt;Part 1: K-Means Quantization&lt;/h1&gt;
&lt;h2 id=&#34;qustion-1&#34;&gt;Question 1&lt;/h2&gt;
&lt;p&gt;The first question is to implement the core of K-means quantization; K-means itself comes from a library.&lt;/p&gt;</description>
      <content:encoded><![CDATA[<blockquote>
<p>These are my notes on EfficientML Lab 2, which covers K-means quantization, K-means QAT, and linear quantization. It is not difficult but covers a lot of ground.</p>
</blockquote>
<h1 id="part-1-k-means-quantization">Part 1: K-Means Quantization</h1>
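<h2 id="qustion-1">Question 1</h2>
<p>The first question is to implement the core of K-means quantization; K-means itself comes from a library.</p>
<p>The first part asks for the number of clusters: n bits can represent <code>2^n</code> clusters. The second part reconstructs the quantized tensor from an existing codebook, which is just tensor indexing. Note that the codebook may be passed in by the caller rather than computed here, so <code>centroids</code> and <code>labels</code> must be read from the codebook&rsquo;s members; otherwise later code will raise an error.</p>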
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl">    <span class="k">if</span> <span class="n">codebook</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">############### YOUR CODE STARTS HERE ###############</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># get number of clusters based on the quantization precision</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># hint: one line of code</span>
</span></span><span class="line"><span class="cl">        <span class="n">n_clusters</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">**</span> <span class="n">bitwidth</span>
</span></span><span class="line"><span class="cl">        <span class="c1">############### YOUR CODE ENDS HERE #################</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># use k-means to get the quantization centroids</span>
</span></span><span class="line"><span class="cl">        <span class="n">kmeans</span> <span class="o">=</span> <span class="n">KMeans</span><span class="p">(</span><span class="n">n_clusters</span><span class="o">=</span><span class="n">n_clusters</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;euclidean&#39;</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">labels</span> <span class="o">=</span> <span class="n">kmeans</span><span class="o">.</span><span class="n">fit_predict</span><span class="p">(</span><span class="n">fp32_tensor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">centroids</span> <span class="o">=</span> <span class="n">kmeans</span><span class="o">.</span><span class="n">centroids</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">codebook</span> <span class="o">=</span> <span class="n">Codebook</span><span class="p">(</span><span class="n">centroids</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">############### YOUR CODE STARTS HERE ###############</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># decode the codebook into k-means quantized tensor for inference</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># hint: one line of code</span>
</span></span><span class="line"><span class="cl">    <span class="n">quantized_tensor</span> <span class="o">=</span> <span class="n">codebook</span><span class="o">.</span><span class="n">centroids</span><span class="p">[</span><span class="n">codebook</span><span class="o">.</span><span class="n">labels</span><span class="p">]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">fp32_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">############### YOUR CODE ENDS HERE #################</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="question-2">Question 2</h2>
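<p>Omitted.</p>
<h1 id="part-2-trained-k-means-quantization">Part 2: Trained K-Means Quantization</h1>
<h2 id="question-3">Question 3</h2>
<p>Accuracy drops sharply after low-bit quantization, so QAT is needed. The gradient with respect to the quantized weights (the centroids) is derived as:</p>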


<div>$$

\frac{\partial \mathcal{L} }{\partial C_k} = \sum_{j} \frac{\partial \mathcal{L} }{\partial W_{j}} \frac{\partial W_{j} }{\partial C_k} = \sum_{j} \frac{\partial \mathcal{L} }{\partial W_{j}} \mathbf{1}(I_{j}=k)

$$</div>
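<p>As a small illustration of the gradient formula above, the centroid gradient is just the sum of the weight gradients within each cluster. The function name and the numbers are made up for this sketch.</p>

```python
def centroid_gradients(weight_grads, labels, n_clusters):
    """dL/dC_k = sum of dL/dW_j over all weights j assigned to cluster k."""
    grads = [0.0] * n_clusters
    for g, k in zip(weight_grads, labels):
        grads[k] += g
    return grads


weight_grads = [0.5, -1.0, 0.25, 2.0]  # dL/dW_j for four weights
labels = [0, 1, 0, 1]                  # I_j: cluster index of each weight
print(centroid_gradients(weight_grads, labels, n_clusters=2))  # [0.75, 1.0]
```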

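<p>In this lab, however, for simplicity each centroid is updated to the mean of the original weights assigned to its cluster. The implementation is a single line:</p>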
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">codebook</span><span class="o">.</span><span class="n">centroids</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">fp32_tensor</span><span class="p">[</span><span class="n">codebook</span><span class="o">.</span><span class="n">labels</span> <span class="o">==</span> <span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
</span></span></code></pre></td></tr></table>
</div>
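</div><p>The results for different bit widths are shown in the table below. The drop at 2 bits is dramatic, and even after QAT the original accuracy is hard to fully recover.</p>
<table>
  <thead>
      <tr>
          <th>Bit width</th>
          <th>Accuracy drop</th>
          <th>Accuracy drop after fine-tuning</th>
          <th>Fine-tuning epochs</th>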
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>8 bits</td>
          <td>0.17%</td>
          <td>0.17%</td>
          <td>0</td>
      </tr>
      <tr>
          <td>4 bits</td>
          <td>13.87%</td>
          <td>0.49%</td>
          <td>1</td>
      </tr>
      <tr>
          <td>2 bits</td>
          <td>82.95%</td>
          <td>1.75%</td>
          <td>5</td>
      </tr>
  </tbody>
</table>
<h1 id="part-3-linear-quantization">Part 3: Linear Quantization</h1>
<h2 id="question-4">Question 4</h2>
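<p>This question implements the core function of linear quantization: given a tensor, a bit width, a scale, and a zero point, compute the quantized tensor using the linear quantization formula:</p>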


<div>$$

q = \mathrm{int}(\mathrm{round}(r/S)) &#43; Z

$$</div>

<p>The implementation follows it directly:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="c1">############### YOUR CODE STARTS HERE ###############</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># Step 1: scale the fp_tensor</span>
</span></span><span class="line"><span class="cl">    <span class="n">scaled_tensor</span> <span class="o">=</span> <span class="n">fp_tensor</span> <span class="o">/</span> <span class="n">scale</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># Step 2: round the floating value to integer value</span>
</span></span><span class="line"><span class="cl">    <span class="n">rounded_tensor</span> <span class="o">=</span> <span class="n">scaled_tensor</span><span class="o">.</span><span class="n">round</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="c1">############### YOUR CODE ENDS HERE #################</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">rounded_tensor</span> <span class="o">=</span> <span class="n">rounded_tensor</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c1">############### YOUR CODE STARTS HERE ###############</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># Step 3: shift the rounded_tensor to make zero_point 0</span>
</span></span><span class="line"><span class="cl">    <span class="n">shifted_tensor</span> <span class="o">=</span> <span class="n">rounded_tensor</span> <span class="o">+</span> <span class="n">zero_point</span>
</span></span><span class="line"><span class="cl">    <span class="c1">############### YOUR CODE ENDS HERE #################</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Note that step 4 clamps any out-of-range results back into the representable range.</p>
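<p>All four steps, including the clamp, fit into a short pure-Python sketch. It assumes a signed integer range and operates on a list of floats instead of a tensor, so it is an illustration of the idea rather than the lab&rsquo;s function.</p>

```python
def linear_quantize(values, bitwidth, scale, zero_point):
    """Quantize floats to signed integers: q = clamp(round(r / S) + Z)."""
    q_min = -(2 ** (bitwidth - 1))
    q_max = 2 ** (bitwidth - 1) - 1
    out = []
    for r in values:
        q = round(r / scale) + zero_point      # steps 1-3: scale, round, shift
        out.append(max(q_min, min(q_max, q)))  # step 4: clamp into the representable range
    return out


print(linear_quantize([0.0, 1.3, -2.0, 100.0], bitwidth=8, scale=0.5, zero_point=0))
# [0, 3, -4, 127]: 100.0 / 0.5 = 200 overflows int8 and is clamped to 127
```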
<h2 id="question-5">Question 5</h2>
<p>The scale and zero point are computed as:</p>


<div>$$

\begin{align*}
S&amp;=(r_{\mathrm{max}} - r_{\mathrm{min}}) / (q_{\mathrm{max}} - q_{\mathrm{min}})\\
Z &amp;= \mathrm{int}(\mathrm{round}(q_{\mathrm{min}} - r_{\mathrm{min}} / S))
\end{align*}

$$</div>

<p>The code is a direct transcription of these formulas.</p>
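<p>A plain-Python version of the two formulas might look like the following. The signed integer range and the final clamp of the zero point are assumptions of this sketch, and the function name is illustrative.</p>

```python
def get_scale_and_zero_point(r_min, r_max, bitwidth):
    """Asymmetric linear quantization parameters for a signed integer range."""
    q_min = -(2 ** (bitwidth - 1))
    q_max = 2 ** (bitwidth - 1) - 1
    scale = (r_max - r_min) / (q_max - q_min)
    zero_point = int(round(q_min - r_min / scale))
    # Keep the zero point itself representable (an extra safeguard in this sketch).
    zero_point = max(q_min, min(q_max, zero_point))
    return scale, zero_point


scale, zero_point = get_scale_and_zero_point(r_min=-1.0, r_max=3.0, bitwidth=2)
print(scale, zero_point)  # scale = 4/3, zero point = -1
```

<p>With these parameters the real value 0 maps exactly to the integer -1, which is what the zero point is for.</p>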
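<p>Weights have a special property: their distribution is usually symmetric about 0, so the zero point for weight quantization can simply be set to 0.</p>
<p>In addition, experience shows that quantizing convolution kernels per output channel gives better results.</p>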
<h2 id="question-6-8">Question 6-8</h2>
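<p>Before these questions, the lab handout derives the expressions for fully connected and convolutional layers under linear quantization. The derivation goes through a series of substitutions, assumptions, and simplifications, chiefly:</p>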


<div>$$

\begin{align*}
Z_{\mathrm{weight}}&amp;=0\\
r_{\mathrm{weight}} &amp;= S_{\mathrm{weight}}q_{\mathrm{weight}}\\
Z_{\mathrm{bias}} &amp;= 0\\
S_{\mathrm{bias}} &amp;= S_{\mathrm{input}} \cdot S_{\mathrm{weight}}
\end{align*}

$$</div>

<p>The final result is:</p>


<div>$$

\begin{align*}
q_{\mathrm{output}} &amp;= (\mathrm{CONV}[q_{\mathrm{input}}, q_{\mathrm{weight}}] &#43; Q_{\mathrm{bias}}) \cdot (S_{\mathrm{input}}S_{\mathrm{weight}} / S_{\mathrm{output}}) &#43; Z_{\mathrm{output}}\\
q_{\mathrm{output}} &amp;= (\mathrm{Linear}[q_{\mathrm{input}}, q_{\mathrm{weight}}] &#43; Q_{\mathrm{bias}})\cdot (S_{\mathrm{input}} \cdot S_{\mathrm{weight}} / S_{\mathrm{output}}) &#43; Z_{\mathrm{output}}\\
\text{where }Q_{\mathrm{bias}} &amp;= q_{\mathrm{bias}} - \mathrm{Linear}[Z_{\mathrm{input}}, q_{\mathrm{weight}}]
\end{align*}

$$</div>
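<p>The role of the folded bias can be checked with integer arithmetic: subtracting the input zero point before the linear layer gives the same accumulator as running the layer on the raw quantized input and adding the precomputed folded bias. The values below are arbitrary and the helper <code>linear</code> is a hypothetical list-based stand-in for the real operator.</p>

```python
def linear(x, weight):
    """Integer matrix-vector product; `weight` is a list of rows, one per output feature."""
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in weight]


q_input = [12, -7, 3]
q_weight = [[1, -2, 4], [0, 5, -3]]
q_bias = [10, -20]
z_input = 2  # input zero point

# Left side: subtract the zero point from the input before the matmul.
shifted = [q - z_input for q in q_input]
lhs = [acc + b for acc, b in zip(linear(shifted, q_weight), q_bias)]

# Right side: fold the zero-point term into the bias once, ahead of time.
folded_bias = [b - c for b, c in zip(q_bias, linear([z_input] * len(q_input), q_weight))]
rhs = [acc + b for acc, b in zip(linear(q_input, q_weight), folded_bias)]

print(lhs, rhs)  # [42, -68] [42, -68]
```

<p>Because the folded bias depends only on the weights and the input zero point, it can be computed once offline rather than at every inference step.</p>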

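<p>Q7 and Q8 share the same implementation. One thing to watch when scaling the output: because quantization is per channel, the weight scale is a tensor, and its shape must be arranged so that it broadcasts correctly.</p>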
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl">    <span class="c1">############### YOUR CODE STARTS HERE ###############</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># Step 2: scale the output</span>
</span></span><span class="line"><span class="cl">    <span class="c1">#         hint: 1. scales are floating numbers, we need to convert output to float as well</span>
</span></span><span class="line"><span class="cl">    <span class="c1">#               2. the shape of weight scale is [oc, 1, 1, 1] while the shape of output is [batch_size, oc]</span>
</span></span><span class="line"><span class="cl">    <span class="n">output</span> <span class="o">=</span> <span class="n">output</span> <span class="o">*</span> <span class="p">(</span><span class="n">input_scale</span> <span class="o">*</span> <span class="n">weight_scale</span> <span class="o">/</span> <span class="n">output_scale</span><span class="p">)</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c1"># Step 3: shift output by output_zero_point</span>
</span></span><span class="line"><span class="cl">    <span class="c1">#         hint: one line of code</span>
</span></span><span class="line"><span class="cl">    <span class="n">output</span> <span class="o">=</span> <span class="n">output</span> <span class="o">+</span> <span class="n">output_zero_point</span>
</span></span><span class="line"><span class="cl">    <span class="c1">############### YOUR CODE ENDS HERE #################</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="question-9">Question 9</h2>
<p>要干的活文档基本都干好了，只剩下一个对输入进行量化的活需要我们完成。照猫画虎，使用 <code>get_quantization_scale_and_zero_point</code> 计算缩放系数和零点，使用 <code>linear_quantize</code> 进行线性量化即可。</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="c1">############### YOUR CODE STARTS HERE ###############</span>
</span></span><span class="line"><span class="cl"><span class="n">x_scale</span><span class="p">,</span> <span class="n">x_zero_point</span> <span class="o">=</span> <span class="n">get_quantization_scale_and_zero_point</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">8</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="k">return</span> <span class="n">linear_quantize</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="n">x_scale</span><span class="p">,</span> <span class="n">x_zero_point</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="c1">############### YOUR CODE ENDS HERE #################</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>量化后的模型精度为 92.21% 几乎没掉点。</p>
<p><strong>为什么在线性量化中没有 ReLU 层</strong>：ReLU 层被融合到前一层中网络中，可以减少数据的搬运次数。</p>
<h2 id="question-10">Question 10</h2>
<p>The answer below comes from DeepSeek:</p>
<table>
  <thead>
      <tr>
          <th><strong>Quantization method</strong></th>
          <th><strong>Main strength</strong></th>
          <th><strong>Main weakness</strong></th>
          <th><strong>Where it fits</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
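          <td><strong>K-means quantization</strong></td>
          <td>High accuracy on non-uniformly distributed data</td>
          <td>Computationally complex, poor hardware support</td>
          <td>Complex data distributions, accuracy-sensitive workloads, specialized hardware</td>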
      </tr>
      <tr>
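          <td><strong>Linear quantization</strong></td>
          <td>Low latency, hardware-friendly, easy to deploy</td>
          <td>Lower accuracy on non-uniformly distributed data</td>
          <td>Real-time inference, general-purpose hardware, stable dynamic range</td>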
      </tr>
  </tbody>
</table>
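<h1 id="小结">Summary</h1>
<p>The lab document is large and rich in material, but most of the code is already provided: each answer needs only one or two lines, and the surrounding text gives plenty of hints. As a result the lab is not much of a challenge and feels somewhat underwhelming.</p>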
]]></content:encoded>
    </item>
    <item>
      <title>Notes on MIT 6.5940 EfficientML, Lecture 6</title>
      <link>https://www.zhouxin.space/notes/notes-on-mit-efficientml-6th-lecture/</link>
      <pubDate>Mon, 17 Feb 2025 20:30:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-mit-efficientml-6th-lecture/</guid>
      <description>&lt;blockquote&gt;
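&lt;p&gt;This post covers post-training quantization (PTQ) and quantization-aware training (QAT). PTQ assigns quantization parameters at different granularities (per-tensor, per-channel, per-vector) and combines dynamic range clipping (calibration-set statistics or KL-divergence optimization) with AdaRound's learned rounding to balance accuracy and efficiency. QAT simulates quantization in the forward pass and uses the straight-through estimator (STE) to get around the zero gradient of rounding. Binary/ternary quantization introduces learnable scale factors to reduce the accuracy loss at extremely low bit widths (1-2 bits), compressing the model while enabling hardware acceleration and memory savings.&lt;/p&gt;</description>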
      <content:encoded><![CDATA[<blockquote>
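<p>This post covers post-training quantization (PTQ) and quantization-aware training (QAT). PTQ assigns quantization parameters at different granularities (per-tensor, per-channel, per-vector) and combines dynamic range clipping (calibration-set statistics or KL-divergence optimization) with AdaRound's learned rounding to balance accuracy and efficiency. QAT simulates quantization in the forward pass and uses the straight-through estimator (STE) to get around the zero gradient of rounding. Binary/ternary quantization introduces learnable scale factors to reduce the accuracy loss at extremely low bit widths (1-2 bits), compressing the model while enabling hardware acceleration and memory savings.</p>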
</blockquote>
<h1 id="训练后量化-post-training-quantization-ptq">Post-training Quantization (PTQ)</h1>
<h2 id="量化粒度">Quantization Granularity</h2>
<ul>
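<li>Per-Tensor<br>
The whole tensor shares a single scale factor. For some small models this costs a lot of accuracy. A box plot of the values in each channel of a tensor shows that the ranges of different channels can differ greatly, so one set of quantization parameters for the whole tensor is a poor fit.</li>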
</ul>
<p><img alt="不同通道的箱式图" loading="lazy" src="https://pics.zhouxin.space/202502172125733.webp"></p>
<ul>
<li>
<p>Per-Channel<br>
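Per-channel quantization computes an independent scale factor for each channel. After dequantization this yields a more varied set of weight values, and the quantization error is correspondingly lower.<br>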
<img alt="2 bit Per-Channel 线性量化示意图" loading="lazy" src="https://pics.zhouxin.space/202502172130025.webp"></p>
</li>
<li>
<p>Group wise</p>
</li>
</ul>
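<p><strong>Per-Vector</strong> quantization builds on conventional per-tensor quantization by adding an extra integer scale factor for each sub-vector, so the dequantization formula becomes:</p>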


<div>$$

r=\gamma \cdot S_q \cdot (q-Z)

$$</div>

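<p>The whole tensor still shares one high-precision scale factor $\gamma$, while each sub-vector carries its own low-bit integer scale $S_q$; this is a compromise between accuracy and hardware efficiency.</p>
<p>Take 4-bit quantization with a 4-bit per-vector scale for every 16 elements as an example: the effective bit width is 4 + 4/16 = 4.25 bits.</p>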
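<p>The arithmetic above, together with the two-level dequantization formula, can be sketched in a few lines. This is only an illustration: the function names and the flat-list layout (one integer scale per consecutive group of <code>vector_size</code> elements) are assumptions made for this note.</p>

```python
def effective_bits(elem_bits, scale_bits, vector_size):
    """Bits per element once the per-vector scale is amortized over the vector."""
    return elem_bits + scale_bits / vector_size


def dequantize_per_vector(codes, int_scales, gamma, zero_point=0, vector_size=16):
    """r = gamma * S_q * (q - Z), with one integer scale S_q per sub-vector.

    `gamma` is the single high-precision scale shared by the whole tensor.
    """
    out = []
    for i, q in enumerate(codes):
        s_q = int_scales[i // vector_size]  # scale of the sub-vector holding element i
        out.append(gamma * s_q * (q - zero_point))
    return out
```

<p>For instance, <code>effective_bits(4, 4, 16)</code> gives 4.25, matching the figure in the text.</p>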
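<p>Per-vector scaling is essentially a multi-level scaling scheme. Representing the scale factors in different data types gives formats such as MX4/6/9, which reach different effective bit widths.<br>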
<img alt="不同缩放方法汇总表" loading="lazy" src="https://pics.zhouxin.space/202502172206284.webp"></p>
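<p>Returning to the per-tensor versus per-channel comparison at the start of this section, the effect of a shared scale on a narrow-range channel can be seen in a toy example. The sketch below uses symmetric quantization on plain Python lists; the helper names and the sample weights are made up for illustration.</p>

```python
def sym_scale(values, bitwidth):
    """Symmetric scale: the largest magnitude maps to the largest integer code."""
    return max(abs(v) for v in values) / (2 ** (bitwidth - 1) - 1)


def fake_quant(values, bitwidth, scale):
    """Quantize to signed integers, then dequantize back to floats."""
    q_max = 2 ** (bitwidth - 1) - 1
    return [scale * max(-q_max, min(q_max, round(v / scale))) for v in values]


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def compare_granularity(channels, bitwidth):
    """Return (per-tensor MSE, per-channel MSE) for every channel."""
    flat = [v for ch in channels for v in ch]
    tensor_scale = sym_scale(flat, bitwidth)
    rows = []
    for ch in channels:
        per_tensor = mse(ch, fake_quant(ch, bitwidth, tensor_scale))
        per_channel = mse(ch, fake_quant(ch, bitwidth, sym_scale(ch, bitwidth)))
        rows.append((per_tensor, per_channel))
    return rows
```

<p>With one channel in the range of ±0.05 and another in the range of ±8, the shared 4-bit scale rounds every value of the small channel to zero, while a per-channel scale keeps them distinguishable.</p>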
<h2 id="dynamic-range-clipping">Dynamic Range Clipping</h2>
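<p>Weights are fixed once training ends, so their range is fixed as well. Activations, however, change with the input and can span a wide range, so the techniques for determining their range deserve a separate discussion.</p>
<p>The first step is to collect activation statistics, which can be done either during training or after training.</p>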
<ul>
<li>During training<br>
Track the maximum and minimum activation values over the course of training with an exponential moving average (EMA):</li>
</ul>


<div>$$

\hat{r}^{(t)}_{max,min} = \alpha \cdot r^{(t)}_{max,min}&#43;(1-\alpha)\cdot \hat{r}^{(t-1)}_{max,min}

$$</div>
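<p>A minimal sketch of this EMA update on plain Python lists follows. Seeding the running estimate with the first batch's statistics is an assumption of this sketch; the formula itself does not say how $\hat{r}^{(0)}$ is initialized.</p>

```python
def ema_update(previous, observed, alpha):
    """One EMA step: alpha * r_t + (1 - alpha) * r_hat_{t-1}."""
    if previous is None:  # assumption: seed the estimate with the first observation
        return observed
    return alpha * observed + (1 - alpha) * previous


def track_range(batches, alpha=0.1):
    """Running (min, max) estimate over a sequence of activation batches."""
    r_min = r_max = None
    for batch in batches:
        r_min = ema_update(r_min, min(batch), alpha)
        r_max = ema_update(r_max, max(batch), alpha)
    return r_min, r_max
```

<p>A small <code>alpha</code> makes the estimate react slowly, so a single batch with an extreme activation shifts the range only slightly.</p>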

<ul>
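<li>After training<br>
If the training set is unavailable, run a calibration set through the model instead. The activations may be distributed as in the figure below: the long tails on both sides are outliers, and keeping them inside the quantization range wastes resolution (shrinking the represented range improves precision within a fixed bit width).<br>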
<img alt="激活层在校准集上的分布" loading="lazy" src="https://pics.zhouxin.space/202502181053857.webp"><br>
One option is to choose the clipping threshold that minimizes the mean squared quantization error:</li>
</ul>


<div>$$

\min_{|r|_{\text{max}}} \mathbb{E} \left[ (X - Q(X))^2 \right]

$$</div>

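<p>For known distributions such as Gaussian or Laplace, the objective above can be solved numerically. For less common distributions, the objective becomes the KL divergence, which measures the information lost when the quantized distribution is used to approximate the activation distribution:</p>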


<div>$$

\min_{Q} D_{KL}(P || Q) = \min_{Q} \sum_{i=1}^{N} P(x_i) \log \frac{P(x_i)}{Q(x_i)}

$$</div>

<p><img alt="最小化KL散度量化示意图" loading="lazy" src="https://pics.zhouxin.space/202502181125682.webp"></p>
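<p>For the MSE objective earlier in this section, a brute-force way to pick the clipping threshold is to try a grid of candidates and keep the one with the lowest error. The sketch below does exactly that on a plain list; the grid search, the symmetric quantizer and all names are assumptions of this note, not the method used by any particular framework.</p>

```python
def fake_quant_clipped(values, clip, bitwidth):
    """Symmetric quantization where magnitudes above `clip` saturate."""
    q_max = 2 ** (bitwidth - 1) - 1
    scale = clip / q_max
    return [scale * max(-q_max, min(q_max, round(v / scale))) for v in values]


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def search_clip(values, bitwidth, steps=100):
    """Grid-search the clipping threshold |r|_max that minimizes the MSE."""
    abs_max = max(abs(v) for v in values)
    best_clip = abs_max
    best_err = mse(values, fake_quant_clipped(values, abs_max, bitwidth))
    for i in range(1, steps):
        clip = abs_max * i / steps
        err = mse(values, fake_quant_clipped(values, clip, bitwidth))
        if best_err > err:
            best_clip, best_err = clip, err
    return best_clip, best_err
```

<p>On data with many small values and one outlier, the search settles on a threshold below the absolute maximum: saturating the outlier costs less than quantizing everything else coarsely.</p>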
<h2 id="舍入">Rounding</h2>
<ul>
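<li>AdaRound<br>
The usual rounding strategy is round-to-nearest, but that is often not the best choice. AdaRound is a learning-based rounding method: it adds a learnable offset $\delta$ to the floored weight and then rounds to nearest, i.e.:</li>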
</ul>


<div>$$

\tilde{w} = \mathrm{round}(\lfloor w \rfloor&#43;\delta),\ \delta\in[0, 1]

$$</div>

<p>The optimization objective for determining this parameter is:</p>


<div>$$

\operatorname{argmin}_{V}\|Wx - \mathrm{round}(\lfloor W \rfloor&#43;h(V))x\|_F^2 &#43; \lambda f_{reg}(V)

$$</div>

<p>Here $V$ is the parameter to be learned, $h()$ is a function that maps its input into $[0, 1]$, and $f_{reg}$ is a regularization term that pushes $h(V)$ toward 0 or 1.</p>
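<p>The claim that round-to-nearest is not always optimal can be checked on a tiny example. The sketch below is not AdaRound itself: instead of learning $V$ by gradient descent, it simply enumerates every floor/ceil choice ($\delta \in \{0, 1\}$ per weight) and keeps the one with the smallest output error. All names are hypothetical.</p>

```python
import itertools
import math


def output_error(w, w_q, inputs):
    """Sum over inputs of (w . x - w_q . x)^2."""
    return sum(
        sum((a - b) * x for a, b, x in zip(w, w_q, xs)) ** 2 for xs in inputs
    )


def round_to_nearest(w):
    return [round(v) for v in w]


def best_floor_ceil(w, inputs):
    """Exhaustively pick floor or ceil per weight to minimize the output error."""
    floors = [math.floor(v) for v in w]
    best = None
    for deltas in itertools.product((0, 1), repeat=len(w)):
        candidate = [f + d for f, d in zip(floors, deltas)]
        err = output_error(w, candidate, inputs)
        if best is None or best[1] > err:
            best = (candidate, err)
    return best
```

<p>For weights <code>[0.4, 0.4, 0.4]</code> and the input <code>[1.0, 1.0, 1.0]</code>, round-to-nearest sends every weight to 0, whereas rounding exactly one weight up gives a much smaller output error. The exhaustive search grows exponentially with the number of weights, which is why a learned relaxation is used in practice.</p>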
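<h1 id="量化感知训练-quantization-aware-train-qat">Quantization-Aware Training (QAT)</h1>
<p>In K-means quantization, the gradients within each cluster are summed and used as that cluster's gradient to update the weights. Fine-tuning under linear quantization is considerably more involved. As the name Quantization-Aware Training suggests, this section is about taking quantization into account during training.</p>
<h2 id="伪量化-fake-quantization">Fake Quantization</h2>
<p>Fake quantization is illustrated below. In the forward pass, weights and activations go through simulated quantization (quantize to low precision, then dequantize back to high precision); in the backward pass, the weights are updated in high precision.<br>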
<img alt="伪量化示意图" loading="lazy" src="https://pics.zhouxin.space/202502211206153.webp"></p>
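<h2 id="直通估计器-straight-through-estimator-ste">Straight-Through Estimator (STE)</h2>
<p>This raises a new problem: quantization is a rounding operation, so the quantizer is a piecewise-constant step function whose gradient is 0 almost everywhere. The STE bypasses this non-differentiable operation by treating quantization as if it had no effect on the gradient:</p>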


<div>$$

\frac{\partial L}{\partial Q(x)} = \frac{\partial L}{\partial x}

$$</div>

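<p>With STE added, the fake quantization diagram becomes the one below: the gradients of the loss with respect to the quantized W and Y are used directly as estimates of the gradients with respect to W and Y before quantization.<br>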
<img alt="带有STE的伪量化示意图" loading="lazy" src="https://pics.zhouxin.space/202502211220898.webp"></p>
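<p>Fake quantization and the STE can be combined in a toy one-weight example with the gradient written out by hand. This is a sketch under simplifying assumptions (a scalar model $y = Q(w)\,x$, squared-error loss, a fixed scale); the names are made up for this note, and a real framework would do the same thing through autograd.</p>

```python
def fake_quantize(w, scale):
    """Quantize then dequantize: the forward pass sees the rounded value."""
    return round(w / scale) * scale


def train_step(w, x, target, scale, lr):
    """One SGD step on loss = (Q(w) * x - target)^2 using the STE."""
    w_q = fake_quantize(w, scale)
    y = w_q * x
    grad_w_q = 2 * (y - target) * x  # dL/dQ(w)
    grad_w = grad_w_q                # STE: treat dQ(w)/dw as 1
    return w - lr * grad_w           # the full-precision weight is updated


def train(w, x, target, scale, lr, steps):
    for _ in range(steps):
        w = train_step(w, x, target, scale, lr)
    return w
```

<p>Starting from <code>w = 0.1</code> with <code>scale = 0.5</code>, the latent weight keeps moving even while its quantized value is stuck on the same step, and eventually the quantized value jumps to the target. With the true gradient of the rounding function (zero almost everywhere) the weight would never move.</p>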
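<h1 id="二元三元量化-binarytenary-quantization">Binary/Ternary Quantization</h1>
<p>Binary quantization maps weights to ±1 and ternary quantization maps them to 0 and ±1, which greatly reduces memory usage and speeds up computation.</p>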
<h2 id="二元化">二元化</h2>
<ul>
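<li>Deterministic binarization<br>
Deterministic binarization uses a threshold (commonly 0): weights below the threshold are quantized to -1 and weights above it to +1. As one would expect, this costs a lot of accuracy.</li>
<li>Stochastic binarization<br>
One stochastic scheme is BinaryConnect: a weight $r$ is quantized to +1 with probability $\sigma(r)$ and to -1 otherwise, where $\sigma(r)=\min (\max ((r+1)/2, 0), 1)$ is the hard sigmoid plotted below. Put simply, a weight greater than 1 or less than -1 is quantized to +1 or -1 respectively; in between, the probability of being quantized to +1 grows linearly with the weight.<br>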
<img alt="hard sigmoid 函数图像" loading="lazy" src="https://pics.zhouxin.space/202502211245412.webp"><br>
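The drawback of this scheme is that it is not hardware-friendly: it requires the hardware to generate random numbers during quantization.</li>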
</ul>
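<p>As one would expect, the model loses a lot of accuracy after this kind of quantization; in the figure below the drop reaches 21.2%. However, if the quantized weights are multiplied by a 32-bit scale factor that keeps the mean magnitude of the weights the same before and after quantization, the accuracy drop is only 0.2%, even though the quantization error is still large.<br>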
<img alt="二元量化示意图" loading="lazy" src="https://pics.zhouxin.space/202502211253433.webp"></p>
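<p>The three variants discussed in this section can be sketched on plain lists. Taking the scale factor as the mean absolute weight, $\alpha = \frac{1}{n}\sum_i |w_i|$, is an assumption of this sketch, as are the function names.</p>

```python
import random


def hard_sigmoid(r):
    """Clip (r + 1) / 2 into [0, 1]."""
    return min(max((r + 1) / 2, 0.0), 1.0)


def binarize_deterministic(weights):
    """Sign function with threshold 0."""
    return [1 if w >= 0 else -1 for w in weights]


def binarize_stochastic(weights, rng):
    """+1 with probability hard_sigmoid(w), otherwise -1."""
    return [1 if hard_sigmoid(w) > rng.random() else -1 for w in weights]


def binarize_scaled(weights):
    """Deterministic binarization with one full-precision scale (assumed: mean |w|)."""
    alpha = sum(abs(w) for w in weights) / len(weights)
    return alpha, [alpha * b for b in binarize_deterministic(weights)]


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
```

<p>For small weights such as <code>[0.2, -0.1, 0.05, -0.3]</code>, plain ±1 codes are far from the original values, while the scaled version stays close to them. Calling <code>binarize_stochastic(weights, random.Random(0))</code> shows the random behaviour for weights inside (-1, 1).</p>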
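<h2 id="同时二元量化权重和激活层">Binarizing Both Weights and Activations</h2>
<p>This one is a bit too mind-blowing; I will study it later when I actually need it.</p>
<h2 id="三元量化">Ternary Quantization</h2>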
<ul>
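<li>Ternary Weight Networks, TWN<br>
Ternary quantization maps weights within a threshold to 0 and those outside it to ±1. The threshold is usually 0.7 times the mean absolute value of the weights. A scale factor is needed here as well.</li>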
</ul>
<table>
  <thead>
      <tr>
          <th>ImageNet Top-1 Acc.</th>
          <th>Full Precision</th>
          <th>1 bit (BWN)</th>
          <th>2 bit (TWN)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ResNet-18</td>
          <td>69.6</td>
          <td>60.8</td>
          <td>65.3</td>
      </tr>
  </tbody>
</table>
<p>As expected, ternary quantization loses less accuracy than binary quantization.</p>
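<p>The TWN rule described above can be sketched as follows. The threshold is taken from the text; computing the scale factor as the mean magnitude of the weights that survive the threshold is an assumption of this sketch, and the function name is hypothetical.</p>

```python
def ternarize(weights):
    """Return (threshold, scale, codes) with codes in {-1, 0, +1}."""
    delta = 0.7 * sum(abs(w) for w in weights) / len(weights)
    codes = [0 if abs(w) <= delta else (1 if w > 0 else -1) for w in weights]
    # Assumed scale: average magnitude of the weights that were not zeroed.
    kept = [abs(w) for w, c in zip(weights, codes) if c != 0]
    alpha = sum(kept) / len(kept) if kept else 0.0
    return delta, alpha, codes
```

<p>The dequantized weight is then <code>alpha * code</code>: small weights collapse to exactly 0 and the rest share one magnitude.</p>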
<ul>
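<li>Trained Ternary Quantization, TTQ<br>
To reduce the quantization loss further, TTQ replaces the single scale factor with two learnable parameters, the scales for +1 and -1 respectively, and then searches for their optimal values.</li>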
</ul>
]]></content:encoded>
    </item>
    <item>
      <title>Baidu PaddlePaddle Starter Plan (启航计划) recap: refactoring CINN backend passes</title>
      <link>https://www.zhouxin.space/notes/baidu-paddlepaddle-starter-plan-summary/</link>
      <pubDate>Wed, 08 Jan 2025 00:07:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/baidu-paddlepaddle-starter-plan-summary/</guid>
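      <description>&lt;p&gt;Over the past eight weeks I took part in the &lt;a href=&#34;https://github.com/PaddlePaddle/Paddle/issues/69152&#34;&gt;PaddlePaddle Starter Plan training camp (4th cohort)&lt;/a&gt; run by the PaddlePaddle open-source community, and claimed and completed the &lt;a href=&#34;https://github.com/PaddlePaddle/Paddle/issues/69639&#34;&gt;[Open-source task] CINN compiler backend Pass refactoring&lt;/a&gt; task series. Since I am preparing for finals and anything other than revising feels fun, here is a proper recap of what I got out of the camp. (runs away 🤐&lt;/p&gt;
&lt;h1 id=&#34;why-启航&#34;&gt;Why the Starter Plan?&lt;/h1&gt;
&lt;p&gt;Why did I choose the Starter Plan? Before answering, some background: I had just gone through the CMU 10414 DLSys course and wanted to learn TVM or MLIR, but I lacked the fundamentals and could not find a way in. While surfing the web 🏄‍♀️ I stumbled upon the Starter Plan and learned that it is very beginner-friendly: no interview screening, fairly simple tasks, and dedicated engineers who answer questions. The 3rd cohort was under way at the time, so I subscribed to its ISSUE and waited for the 4th.&lt;/p&gt;</description>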
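      <content:encoded><![CDATA[<p>Over the past eight weeks I took part in the <a href="https://github.com/PaddlePaddle/Paddle/issues/69152">PaddlePaddle Starter Plan training camp (4th cohort)</a> run by the PaddlePaddle open-source community, and claimed and completed the <a href="https://github.com/PaddlePaddle/Paddle/issues/69639">[Open-source task] CINN compiler backend Pass refactoring</a> task series. Since I am preparing for finals and anything other than revising feels fun, here is a proper recap of what I got out of the camp. (runs away 🤐</p>
<h1 id="why-启航">Why the Starter Plan?</h1>
<p>Why did I choose the Starter Plan? Before answering, some background: I had just gone through the CMU 10414 DLSys course and wanted to learn TVM or MLIR, but I lacked the fundamentals and could not find a way in. While surfing the web 🏄‍♀️ I stumbled upon the Starter Plan and learned that it is very beginner-friendly: no interview screening, fairly simple tasks, and dedicated engineers who answer questions. The 3rd cohort was under way at the time, so I subscribed to its ISSUE and waited for the 4th.</p>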
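<h1 id="启航计划安排">How the program is organized</h1>
<p>It starts with three check-in tasks: building Paddle, getting Paddle Mix to run, and contributing documentation to Paddle. The first gets you familiar with building Paddle locally and running unit tests; the third gets you familiar with the GitHub workflow.</p>
<p>In theory, finishing these three tasks meets the minimum requirement for graduating from the camp, but none of us joined for the certificate; we joined to improve. These three tasks do little for that. The next step is to pick a few special-topic teams and try some low-star tasks.</p>
<p>Because the Starter Plan targets beginners, the tasks are fairly simple. Low-star tasks are mostly pattern-following: you can basically finish them by imitating the sample, and they help you understand the overall goal of the team. High-star tasks are generalizations of those, or involve more complex logic, but they rarely require creating something from scratch; in essence they are still imitation.</p>
<p>Of course, simple does not mean effortless. For zero-experience developers like us, it will very likely take days just to understand that "1+1=2", only to find later that the understanding was incomplete or flat-out wrong 😭. Along the way, reread the task documents, rewatch the task walkthrough videos, and talk to the mentors a lot; a single remark from them often clears everything up. Special thanks to <a href="https://github.com/Hongqing-work">Hongqing-work</a>: almost all of my CINN Pass questions went to her, and she patiently answered them even on weekends and in the evenings. I was truly touched 😭.</p>
<p>During the camp, a progress report is due every two weeks. It is a good chance to review the past two weeks' output and plan ahead, and also to see how others are progressing so that you do not fall behind. In my experience: weeks 1-2 go to the check-in tasks and the first low-star attempts; weeks 3-4 go to continuing one team's tasks, by which point you can start on some high-star ones; in weeks 5-8 things fall into place, you can handle that team's tasks comfortably, and you can try tasks from other teams.</p>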
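<h1 id="cinn-后端-pass-改造">Refactoring CINN backend passes</h1>
<p>In this round of the Starter Plan I completed 7 CINN backend Pass refactoring tasks in total. Here is what I took away from this team.</p>
<h2 id="背景">Background</h2>
<p>The background of the task is that CINN upgraded its backend IR. The Expr level of the original Func-Expr hierarchy was refined and re-split into Func-Block-Stmt-Expr, which gives a clearer IR hierarchy.</p>
<p>Correspondingly, backend passes are divided into four levels, FuncPass, BlockPass, StmtPass and ExprPass, and are applied through the matching PassManager. The hierarchy looks like this:<br>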
<img alt="新 IR 层次结构  图源：https://github.com/PaddlePaddle/Paddle/issues/69639" loading="lazy" src="https://pics.zhouxin.space/202501081247817.webp"></p>
<p>IR access utilities are provided as well:</p>
<ol>
<li>Type-agnostic visit/mutate functions at the Stmt and Block level, which invoke user-supplied callbacks before and after traversing each Stmt:</li>
</ol>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="c1">// Visitors
</span></span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">Visit</span><span class="p">(</span><span class="k">const</span> <span class="n">BlockRef</span> <span class="o">&amp;</span><span class="n">block</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">           <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">(</span><span class="k">const</span> <span class="n">StmtRef</span> <span class="o">&amp;</span><span class="p">)</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">pre_callback</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">           <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">(</span><span class="k">const</span> <span class="n">StmtRef</span> <span class="o">&amp;</span><span class="p">)</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">post_callback</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">Visit</span><span class="p">(</span><span class="k">const</span> <span class="n">StmtRef</span> <span class="o">&amp;</span><span class="n">stmt</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">           <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">(</span><span class="k">const</span> <span class="n">StmtRef</span> <span class="o">&amp;</span><span class="p">)</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">pre_callback</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">           <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">(</span><span class="k">const</span> <span class="n">StmtRef</span> <span class="o">&amp;</span><span class="p">)</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">post_callback</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="c1">// Mutators
</span></span></span><span class="line"><span class="cl"><span class="c1">// ...
</span></span></span></code></pre></td></tr></table>
</div>
</div><ol start="2">
<li>Type-aware visitor template classes for Stmt and Block; users override <code>virtual StmtRetTy VisitStmt(const StmtRef &amp;stmt, Args... args)</code> to customize how each kind of Stmt is visited:</li>
</ol>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">StmtRetTy</span> <span class="o">=</span> <span class="kt">void</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">          <span class="k">typename</span> <span class="n">BlockRetTy</span> <span class="o">=</span> <span class="kt">void</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">          <span class="k">typename</span><span class="p">...</span> <span class="n">Args</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">StmtVisitor</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl"> <span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">  <span class="k">virtual</span> <span class="n">StmtRetTy</span> <span class="n">VisitStmt</span><span class="p">(</span><span class="k">const</span> <span class="n">StmtRef</span> <span class="o">&amp;</span><span class="n">stmt</span><span class="p">,</span> <span class="n">Args</span><span class="p">...</span> <span class="n">args</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CINN_CHECK_STMT_DEFINED</span><span class="p">(</span><span class="n">stmt</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">switch</span> <span class="p">(</span><span class="n">stmt</span><span class="o">-&gt;</span><span class="n">stmt_type</span><span class="p">())</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="cp">#define __(stmt__)                                \
</span></span></span><span class="line"><span class="cl"><span class="cp">  case ir::StmtNodeTy::stmt__:                    \
</span></span></span><span class="line"><span class="cl"><span class="cp">    return VisitStmt(stmt.as&lt;stmt__&gt;(), args...); \
</span></span></span><span class="line"><span class="cl"><span class="cp">    break;
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">      <span class="n">NODETY_FORALL_STMT</span><span class="p">(</span><span class="n">__</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">      <span class="k">default</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">PADDLE_THROW</span><span class="p">(</span><span class="o">::</span><span class="n">common</span><span class="o">::</span><span class="n">errors</span><span class="o">::</span><span class="n">InvalidArgument</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="s">&#34;Deadcode, not supported StmtNodeTy&#34;</span><span class="p">));</span>
</span></span><span class="line"><span class="cl"><span class="cp">#undef __
</span></span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span></code></pre></td></tr></table>
</div>
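</div><h2 id="为什么要升级-ir">Why upgrade the IR?</h2>
<p>From the backend Pass perspective, the IR upgrade has two main benefits: 1. passes become clearer and more uniform to write; 2. passes become easier to manage.</p>
<p>Under the old IR, most passes inherit from IRMutator/Visitor and do their work by modifying nodes while traversing the entire IR, even though they only need to handle one specific kind of Stmt/Block. For developer convenience, the old IRMutator provides a default traversal for every kind of Expr/Stmt/Block; for example, the default implementation for IfThenElse visits the condition and both branches:</p>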
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">IRMutator</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;::</span><span class="n">Visit</span><span class="p">(</span><span class="k">const</span> <span class="n">IfThenElse</span> <span class="o">*</span><span class="n">expr</span><span class="p">,</span> <span class="n">T</span> <span class="n">op</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="k">auto</span> <span class="o">*</span><span class="n">node</span> <span class="o">=</span> <span class="n">op</span><span class="o">-&gt;</span><span class="k">template</span> <span class="n">As</span><span class="o">&lt;</span><span class="n">IfThenElse</span><span class="o">&gt;</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">  <span class="n">IRVisitorRequireReImpl</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">,</span> <span class="n">T</span><span class="o">&gt;::</span><span class="n">Visit</span><span class="p">(</span><span class="o">&amp;</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">condition</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">condition</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">IRVisitorRequireReImpl</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">,</span> <span class="n">T</span><span class="o">&gt;::</span><span class="n">Visit</span><span class="p">(</span><span class="o">&amp;</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">true_case</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">true_case</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span> <span class="p">(</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">false_case</span><span class="p">.</span><span class="n">defined</span><span class="p">())</span>
</span></span><span class="line"><span class="cl">    <span class="n">IRVisitorRequireReImpl</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">,</span> <span class="n">T</span><span class="o">&gt;::</span><span class="n">Visit</span><span class="p">(</span><span class="o">&amp;</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">false_case</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                           <span class="o">&amp;</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">false_case</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
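</div><p>This default is unnecessary in many cases. When merging two identical Ifs, for instance, there is clearly no need to apply the pass to the condition, nor to visit Expr-level expressions.</p>
<p>In theory, developers can override the corresponding Visit method to cut the traversal short. But that bloats the pass code, and since no such convention existed when the passes were written, it has become a legacy problem.</p>
<p>After this IR and Pass refactoring, the original IRMutator keeps only the Expr-level visiting logic, and Stmt- and Block-level traversal is handled by the PassManager. For example, StmtPassManager traverses the whole function and invokes each StmtPass it manages once per Stmt, while the StmtPass itself only has to handle the logic it targets.</p>
<p>In addition, the new StmtVisitor provides no default <code>VisitStmt</code> implementation for the individual statement types, which forces developers to write the traversal logic themselves and cut off unnecessary traversal early.</p>
<h2 id="pass-编写范式">Pass-writing pattern</h2>
<p>Under the upgraded IR, a pass is typically written in three steps: 1. inherit from the Pass base class of the appropriate level; 2. implement the core logic in an inner class that traverses the Func/Block/Stmt, either by inheriting StmtMutator/IRMutator or by calling the Visit/Mutate functions; 3. return success.</p>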
<ol>
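<li>Inherit from the Pass base class of the appropriate level<br>
The first step is to work out which level the original pass belongs to. The key is to identify what level of information the pass needs and what level of modification it makes. For example:</li>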
</ol>
<ul>
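<li><a href="https://github.com/PaddlePaddle/Paddle/pull/69611">IfFusionPass</a> merges two or more Ifs that share the same condition. It has to recognize and delete multiple Ifs, and only the Block containing them makes it possible to recognize several statements and delete one of them, so this is a Block-level pass;</li>
<li><a href="https://github.com/PaddlePaddle/Paddle/pull/70437">RearrangeLoadInstruction</a> moves the Loads in a function to the front to increase instruction-level parallelism. To guarantee that the Loads are the first statements executed <strong>within a function</strong>, and to replace <strong>all</strong> identical Loads <strong>in the function</strong> with a local variable, it has to be a Func-level pass;</li>
<li><a href="https://github.com/PaddlePaddle/Paddle/pull/70619/files">EliminateCommonFactorOfLocalIndex</a> needs the <strong>nesting information</strong> of the current For, so the pass itself must drive the IR traversal; otherwise it cannot know the nesting level of the current For. It is therefore a Func-level pass.</li>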
</ul>
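<p>To see why fusing Ifs needs the whole Block rather than a single Stmt, here is a toy Python model. It is not CINN code: the <code>If</code> class, the string statements and <code>fuse_adjacent_ifs</code> are all made up for illustration, and it ignores the legality checks a real pass needs (for example, that the first body does not change the condition).</p>

```python
from dataclasses import dataclass, field


@dataclass
class If:
    """Toy stand-in for an if statement: a condition string and a body."""
    condition: str
    body: list = field(default_factory=list)


def fuse_adjacent_ifs(block):
    """Merge neighbouring Ifs with equal conditions.

    The function needs the list of sibling statements (the block): it
    compares each statement with its predecessor and drops the second If.
    """
    fused = []
    for stmt in block:
        prev = fused[-1] if fused else None
        if isinstance(stmt, If) and isinstance(prev, If) and prev.condition == stmt.condition:
            prev.body.extend(stmt.body)  # absorb the body, drop the second If
        elif isinstance(stmt, If):
            fused.append(If(stmt.condition, list(stmt.body)))  # copy, keep input intact
        else:
            fused.append(stmt)
    return fused
```

<p>A callback that only ever sees one statement at a time could neither compare it with its neighbour nor remove it from the parent, which is the reasoning behind classifying this kind of pass as Block-level.</p>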
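<p>My rule of thumb: if a pass only needs information inside the current Stmt, does not delete or replace the current Stmt, and has no requirement on its nesting level (for example, it does not require the current For to be the outermost or innermost one), it is a Stmt-level pass. If a pass needs information across statements, or needs to delete, replace or add a Stmt, it is a BlockPass. If a pass needs to control the IR traversal itself, or needs the current nesting context, it is a Func-level pass.</p>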
<ol start="2">
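<li>Write the implementation class<br>
Some simpler passes are just an implementation class that inherits IRMutator. Such a pass usually only needs to additionally inherit StmtMutator (and drop the IRMutator base if it does not touch the Expr level), then port the original logic to the new IR. See <a href="https://github.com/PaddlePaddle/Paddle/pull/70334">RemoveScheduleBlock</a> for reference. Under the new IR many members have become private and must be read and written through getters and setters.</li>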
</ol>
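<p>More complex passes may use several Mutators that visit the IR multiple times; typically the first visit collects global information and the later ones make the changes. Read the source until you understand it, then adapt it following the same pattern.</p>
<p>Harder still are passes that call old-IR helpers such as <code>ir::ir_utils::CollectIRNodesWithoutTensor</code>. In that case, check whether the argument is an Expr: if it is, the helper can still be called (Expr is closed, i.e. an Expr never contains a Stmt or Block); otherwise the helper's logic has to be re-implemented on the new IR.</p>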
<ol start="3">
<li>Return success<br>
Nothing much to say here: just return <code>LogicalResult::success()</code>.</li>
</ol>
<h2 id="tips">Tips</h2>
<ol>
<li>
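<p>Passes should be stateless<br>
Stateless means that a pass should not depend on earlier information or keep persistent records. For example, a pass that processes Fors should not record the names of the Fors it has seen in order to avoid duplicates. If you need to avoid repeated visits, implement it as a FuncPass and handle the traversal manually.</p>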
</li>
<li>
<p>PassManager traverses in DFS post-order<br>
This order guarantees that the innermost statements are visited first, and passes may rely on this behavior during refactoring.</p>
</li>
<li>
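<p>There is no communication mechanism between passes<br>
Some passes need to check whether a transformation is legal before applying it. Since passes cannot communicate, the checking pass can become an internal part of the transformation pass: the transformation pass instantiates a PassManager and applies the checking pass through it.</p>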
</li>
<li>
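<p>Consult the <a href="https://halide-lang.org/docs/">Halide documentation</a><br>
CINN borrows many design ideas from Halide and TVM. If you are unsure about something, such as what a particular Stmt does, look to these two communities, whose documentation is richer; you will often be pleasantly surprised.</p>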
</li>
</ol>
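<p>The post-order guarantee from tip 2 is easy to see on a toy tree. The sketch below is plain Python, not the CINN API: <code>Stmt</code>, <code>visit</code> and the statement names are hypothetical, and only the pre/post callback shape is borrowed from the Visit functions shown earlier.</p>

```python
from dataclasses import dataclass, field


@dataclass
class Stmt:
    """Toy statement node: a name plus nested child statements."""
    name: str
    children: list = field(default_factory=list)


def visit(stmt, pre_callback, post_callback):
    """Depth-first traversal with a callback before and after the children."""
    pre_callback(stmt)
    for child in stmt.children:
        visit(child, pre_callback, post_callback)
    post_callback(stmt)


def post_order_names(root):
    """Names in the order a post-callback would see them."""
    order = []
    visit(root, lambda s: None, lambda s: order.append(s.name))
    return order
```

<p>For a loop nest <code>for_i { for_j { store }; if_k }</code> the post-order is <code>store, for_j, if_k, for_i</code>: every statement is handled only after everything nested inside it.</p>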
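<h1 id="后记">Afterword</h1>
<p>As my first open-source experience, I feel I gained a lot. As the saying goes, what you learn on paper always feels shallow; real understanding takes practice. Many skills I had never used hands-on got exercised here: the Git and GitHub workflow, using VSCode together with CMake, GLOG, and so on. The experience of refactoring CINN passes is also a great entry point into AI Sys.</p>
<p>I encourage anyone who has not tried this kind of program to join; it will be well worth it!</p>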
]]></content:encoded>
    </item>
    <item>
      <title>Baidu PaddlePaddle Starter Plan (启航计划) recap: refactoring CINN backend passes</title>
      <link>https://www.zhouxin.space/thoughts/baidu-paddlepaddle-starter-plan-summary/</link>
      <pubDate>Wed, 08 Jan 2025 00:07:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/thoughts/baidu-paddlepaddle-starter-plan-summary/</guid>
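      <description>&lt;p&gt;Over the past eight weeks I took part in the &lt;a href=&#34;https://github.com/PaddlePaddle/Paddle/issues/69152&#34;&gt;PaddlePaddle Starter Plan training camp (4th cohort)&lt;/a&gt; run by the PaddlePaddle open-source community, and claimed and completed the &lt;a href=&#34;https://github.com/PaddlePaddle/Paddle/issues/69639&#34;&gt;[Open-source task] CINN compiler backend Pass refactoring&lt;/a&gt; task series. Since I am preparing for finals and anything other than revising feels fun, here is a proper recap of what I got out of the camp. (runs away 🤐&lt;/p&gt;
&lt;h1 id=&#34;why-启航&#34;&gt;Why the Starter Plan?&lt;/h1&gt;
&lt;p&gt;Why did I choose the Starter Plan? Before answering, some background: I had just gone through the CMU 10414 DLSys course and wanted to learn TVM or MLIR, but I lacked the fundamentals and could not find a way in. While surfing the web 🏄‍♀️ I stumbled upon the Starter Plan and learned that it is very beginner-friendly: no interview screening, fairly simple tasks, and dedicated engineers who answer questions. The 3rd cohort was under way at the time, so I subscribed to its ISSUE and waited for the 4th.&lt;/p&gt;</description>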
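      <content:encoded><![CDATA[<p>Over the past eight weeks I took part in the <a href="https://github.com/PaddlePaddle/Paddle/issues/69152">PaddlePaddle Starter Plan training camp (4th cohort)</a> run by the PaddlePaddle open-source community, and claimed and completed the <a href="https://github.com/PaddlePaddle/Paddle/issues/69639">[Open-source task] CINN compiler backend Pass refactoring</a> task series. Since I am preparing for finals and anything other than revising feels fun, here is a proper recap of what I got out of the camp. (runs away 🤐</p>
<h1 id="why-启航">Why the Starter Plan?</h1>
<p>Why did I choose the Starter Plan? Before answering, some background: I had just gone through the CMU 10414 DLSys course and wanted to learn TVM or MLIR, but I lacked the fundamentals and could not find a way in. While surfing the web 🏄‍♀️ I stumbled upon the Starter Plan and learned that it is very beginner-friendly: no interview screening, fairly simple tasks, and dedicated engineers who answer questions. The 3rd cohort was under way at the time, so I subscribed to its ISSUE and waited for the 4th.</p>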
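<h1 id="启航计划安排">How the program is organized</h1>
<p>It starts with three check-in tasks: building Paddle, getting Paddle Mix to run, and contributing documentation to Paddle. The first gets you familiar with building Paddle locally and running unit tests; the third gets you familiar with the GitHub workflow.</p>
<p>In theory, finishing these three tasks meets the minimum requirement for graduating from the camp, but none of us joined for the certificate; we joined to improve. These three tasks do little for that. The next step is to pick a few special-topic teams and try some low-star tasks.</p>
<p>Because the Starter Plan targets beginners, the tasks are fairly simple. Low-star tasks are mostly pattern-following: you can basically finish them by imitating the sample, and they help you understand the overall goal of the team. High-star tasks are generalizations of those, or involve more complex logic, but they rarely require creating something from scratch; in essence they are still imitation.</p>
<p>Of course, simple does not mean effortless. For zero-experience developers like us, it will very likely take days just to understand that "1+1=2", only to find later that the understanding was incomplete or flat-out wrong 😭. Along the way, reread the task documents, rewatch the task walkthrough videos, and talk to the mentors a lot; a single remark from them often clears everything up. Special thanks to <a href="https://github.com/Hongqing-work">Hongqing-work</a>: almost all of my CINN Pass questions went to her, and she patiently answered them even on weekends and in the evenings. I was truly touched 😭.</p>
<p>During the camp, a progress report is due every two weeks. It is a good chance to review the past two weeks' output and plan ahead, and also to see how others are progressing so that you do not fall behind. In my experience: weeks 1-2 go to the check-in tasks and the first low-star attempts; weeks 3-4 go to continuing one team's tasks, by which point you can start on some high-star ones; in weeks 5-8 things fall into place, you can handle that team's tasks comfortably, and you can try tasks from other teams.</p>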
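<h1 id="cinn-后端-pass-改造">Refactoring CINN backend passes</h1>
<p>In this round of the Starter Plan I completed 7 CINN backend Pass refactoring tasks in total. Here is what I took away from this team.</p>
<h2 id="背景">Background</h2>
<p>The background of the task is that CINN upgraded its backend IR. The Expr level of the original Func-Expr hierarchy was refined and re-split into Func-Block-Stmt-Expr, which gives a clearer IR hierarchy.</p>
<p>Correspondingly, backend passes are divided into four levels, FuncPass, BlockPass, StmtPass and ExprPass, and are applied through the matching PassManager. The hierarchy looks like this:<br>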
<img alt="新 IR 层次结构  图源：https://github.com/PaddlePaddle/Paddle/issues/69639" loading="lazy" src="https://pics.zhouxin.space/202501081247817.webp"></p>
<p>IR access utilities are provided as well:</p>
<ol>
<li>Type-agnostic visit/mutate functions at the Stmt and Block level, which invoke user-supplied callbacks before and after traversing each Stmt:</li>
</ol>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="c1">// Visitors
</span></span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">Visit</span><span class="p">(</span><span class="k">const</span> <span class="n">BlockRef</span> <span class="o">&amp;</span><span class="n">block</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">           <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">(</span><span class="k">const</span> <span class="n">StmtRef</span> <span class="o">&amp;</span><span class="p">)</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">pre_callback</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">           <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">(</span><span class="k">const</span> <span class="n">StmtRef</span> <span class="o">&amp;</span><span class="p">)</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">post_callback</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">Visit</span><span class="p">(</span><span class="k">const</span> <span class="n">StmtRef</span> <span class="o">&amp;</span><span class="n">stmt</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">           <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">(</span><span class="k">const</span> <span class="n">StmtRef</span> <span class="o">&amp;</span><span class="p">)</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">pre_callback</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">           <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">(</span><span class="k">const</span> <span class="n">StmtRef</span> <span class="o">&amp;</span><span class="p">)</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">post_callback</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="c1">// Mutators
</span></span></span><span class="line"><span class="cl"><span class="c1">// ...
</span></span></span></code></pre></td></tr></table>
</div>
</div><ol start="2">
<li>Type-aware visitor template classes for Stmt and Block; users override <code>virtual StmtRetTy VisitStmt(const StmtRef &amp;stmt, Args... args)</code> to customize how each kind of Stmt is visited:</li>
</ol>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">StmtRetTy</span> <span class="o">=</span> <span class="kt">void</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">          <span class="k">typename</span> <span class="n">BlockRetTy</span> <span class="o">=</span> <span class="kt">void</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">          <span class="k">typename</span><span class="p">...</span> <span class="n">Args</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">StmtVisitor</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl"> <span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">  <span class="k">virtual</span> <span class="n">StmtRetTy</span> <span class="n">VisitStmt</span><span class="p">(</span><span class="k">const</span> <span class="n">StmtRef</span> <span class="o">&amp;</span><span class="n">stmt</span><span class="p">,</span> <span class="n">Args</span><span class="p">...</span> <span class="n">args</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CINN_CHECK_STMT_DEFINED</span><span class="p">(</span><span class="n">stmt</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">switch</span> <span class="p">(</span><span class="n">stmt</span><span class="o">-&gt;</span><span class="n">stmt_type</span><span class="p">())</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="cp">#define __(stmt__)                                \
</span></span></span><span class="line"><span class="cl"><span class="cp">  case ir::StmtNodeTy::stmt__:                    \
</span></span></span><span class="line"><span class="cl"><span class="cp">    return VisitStmt(stmt.as&lt;stmt__&gt;(), args...); \
</span></span></span><span class="line"><span class="cl"><span class="cp">    break;
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">      <span class="n">NODETY_FORALL_STMT</span><span class="p">(</span><span class="n">__</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">      <span class="k">default</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">PADDLE_THROW</span><span class="p">(</span><span class="o">::</span><span class="n">common</span><span class="o">::</span><span class="n">errors</span><span class="o">::</span><span class="n">InvalidArgument</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="s">&#34;Deadcode, not supported StmtNodeTy&#34;</span><span class="p">));</span>
</span></span><span class="line"><span class="cl"><span class="cp">#undef __
</span></span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span></code></pre></td></tr></table>
</div>
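</div><h2 id="为什么要升级-ir">Why upgrade the IR?</h2>
<p>From the perspective of backend Passes, the IR upgrade brings two main benefits: (1) Passes become clearer and more uniform to write; (2) Passes become easier to manage.</p>
<p>Under the old IR, most Passes were implemented by inheriting from IRMutator/Visitor and making their modifications while traversing the entire IR, even though each Pass only needed to handle one particular kind of Stmt/Block. For the developer's convenience, the old IRMutator provided a default traversal for every kind of Expr/Stmt/Block. For example, the default implementation for IfThenElse visits the condition and both branches:</p>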
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">IRMutator</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;::</span><span class="n">Visit</span><span class="p">(</span><span class="k">const</span> <span class="n">IfThenElse</span> <span class="o">*</span><span class="n">expr</span><span class="p">,</span> <span class="n">T</span> <span class="n">op</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="k">auto</span> <span class="o">*</span><span class="n">node</span> <span class="o">=</span> <span class="n">op</span><span class="o">-&gt;</span><span class="k">template</span> <span class="n">As</span><span class="o">&lt;</span><span class="n">IfThenElse</span><span class="o">&gt;</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">  <span class="n">IRVisitorRequireReImpl</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">,</span> <span class="n">T</span><span class="o">&gt;::</span><span class="n">Visit</span><span class="p">(</span><span class="o">&amp;</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">condition</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">condition</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">IRVisitorRequireReImpl</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">,</span> <span class="n">T</span><span class="o">&gt;::</span><span class="n">Visit</span><span class="p">(</span><span class="o">&amp;</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">true_case</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">true_case</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span> <span class="p">(</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">false_case</span><span class="p">.</span><span class="n">defined</span><span class="p">())</span>
</span></span><span class="line"><span class="cl">    <span class="n">IRVisitorRequireReImpl</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">,</span> <span class="n">T</span><span class="o">&gt;::</span><span class="n">Visit</span><span class="p">(</span><span class="o">&amp;</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">false_case</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                           <span class="o">&amp;</span><span class="n">node</span><span class="o">-&gt;</span><span class="n">false_case</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
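</div><p>In many cases this default behavior is unnecessary. For example, when merging two Ifs with the same condition, the Pass clearly does not need to be applied to the condition, nor does it need to visit Expr-level expressions.</p>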
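<p>In theory, developers can override the corresponding Visit method to cut the traversal short. However, this bloats the Pass code, and because no such convention existed when the Passes were written, it has become a legacy problem.</p>
<p>After this IR and Pass refactoring, the original IRMutator keeps only the Expr-level visiting logic, while Stmt- and Block-level traversal is handled by the PassManager. For example, a StmtPassManager traverses the whole function and invokes the StmtPass it manages once for every Stmt; inside the StmtPass, only the logic relevant to its goal needs to be handled.</p>
<p>In addition, the new StmtVisitor provides no default <code>VisitStmt</code> implementation, which forces developers to define the traversal logic themselves and to cut off unneeded traversal early.</p>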
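<p>The division of labor between a pass manager and a statement-level pass can be sketched in a few lines of Python. This is only a toy model: the <code>Stmt</code> class and <code>run_stmt_pass</code> function are hypothetical names invented for illustration and are not CINN APIs. The manager owns the recursion and visits children before parents (the DFS post-order described in the Tips section), while the pass itself only ever sees a single statement.</p>

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Stmt:
    """Toy statement node; `body` holds nested statements (hypothetical model)."""
    name: str
    body: List["Stmt"] = field(default_factory=list)


def run_stmt_pass(stmts: List[Stmt], stmt_pass: Callable[[Stmt], None]) -> None:
    """Apply `stmt_pass` to every statement, children before parents."""
    for stmt in stmts:
        run_stmt_pass(stmt.body, stmt_pass)  # the manager owns the recursion
        stmt_pass(stmt)                      # the pass only sees one statement


# A function body: for_i { for_j { store } }; ret
func_body = [Stmt("for_i", [Stmt("for_j", [Stmt("store")])]), Stmt("ret")]

visited: List[str] = []
run_stmt_pass(func_body, lambda s: visited.append(s.name))
print(visited)
```

<p>Because the traversal lives in one place, a pass written this way cannot accidentally walk into parts of the IR it does not care about.</p>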
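<h2 id="pass-编写范式">Pass writing pattern</h2>
<p>A Pass under the upgraded IR is generally written in three steps: (1) inherit from the Pass base class of the appropriate level; (2) use an inner class to traverse the Func/Block/Stmt and implement the core functionality, where this class may inherit from StmtMutator/IRMutator or call the Visit/Mutate methods to perform the traversal; (3) return Success.</p>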
<ol>
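<li>Inherit from the Pass base class of the appropriate level<br>
The first step is to work out which level the original Pass belongs to. The key is to identify what level of information the Pass needs and at what level it makes modifications. For example:</li>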
</ol>
<ul>
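<li><a href="https://github.com/PaddlePaddle/Paddle/pull/69611">IfFusionPass</a> merges two or more Ifs that share the same condition. It has to recognize and delete multiple Ifs, and only with access to the Block containing them can it recognize several statements and delete individual ones, so this is a Block-level Pass;</li>
<li><a href="https://github.com/PaddlePaddle/Paddle/pull/70437">RearrangeLoadInstruction</a> moves the Loads in a function to the front to increase instruction-level parallelism. It must guarantee that the Loads are the first statements executed <strong>within a function</strong> and replace <strong>all occurrences within the function</strong> of the same Load with a local variable, so this is a Func-level Pass;</li>
<li><a href="https://github.com/PaddlePaddle/Paddle/pull/70619/files">EliminateCommonFactorOfLocalIndex</a> needs the <strong>nesting information</strong> of the current For, so the Pass itself must drive the traversal of the IR; otherwise it cannot learn the nesting level of the current For. It is therefore a Func-level Pass.</li>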
</ul>
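<p>To see why If fusion needs Block-level access, here is a minimal Python sketch. It is a toy model only: <code>Stmt</code> and <code>fuse_adjacent_ifs</code> are hypothetical names, conditions are compared as plain strings, and this is not how the actual IfFusionPass is implemented. The point is that the function must see the whole statement list in order to compare neighbors and drop one of them.</p>

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Stmt:
    kind: str                                       # "if" or any other statement kind
    cond: str = ""                                  # condition text, only used for "if"
    body: List[str] = field(default_factory=list)   # toy body: a list of statement strings


def fuse_adjacent_ifs(block: List[Stmt]) -> List[Stmt]:
    """Merge consecutive `if` statements whose conditions are identical."""
    fused: List[Stmt] = []
    for stmt in block:
        prev = fused[-1] if fused else None
        same_if = (prev is not None and prev.kind == "if"
                   and stmt.kind == "if" and prev.cond == stmt.cond)
        if same_if:
            prev.body.extend(stmt.body)  # fold into the previous If, dropping this one
        else:
            fused.append(Stmt(stmt.kind, stmt.cond, list(stmt.body)))  # copy, keep input intact
    return fused


block = [
    Stmt("if", "i == 0", ["a = 1"]),
    Stmt("if", "i == 0", ["b = 2"]),
    Stmt("store"),
]
result = fuse_adjacent_ifs(block)
print([(s.kind, s.cond, s.body) for s in result])
```

<p>A statement-level pass would only ever receive one of the two Ifs at a time, so it could neither detect the duplicate condition nor remove the second statement.</p>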
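<p>My rule of thumb: if a Pass only needs information internal to the current Stmt, does not delete or replace the current Stmt, and places no requirement on the Stmt's nesting level (for example, it does not require the current For to be the outermost or innermost For), it is a Stmt-level Pass. If a Pass needs information across statements, or needs to delete, replace, or add a Stmt, it is a BlockPass. If a Pass needs to control the IR traversal itself, or needs the current nesting context, it is a Func-level Pass.</p>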
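<p>The rule of thumb above can be written down as a small decision function. This is just a restatement of the text as code; the function name, the parameter names, and the returned strings are all hypothetical and chosen for illustration.</p>

```python
def classify_pass_level(needs_cross_stmt_info: bool,
                        modifies_stmt_list: bool,
                        controls_traversal: bool,
                        needs_nesting_context: bool) -> str:
    """Pick the Pass level suggested by the rule of thumb.

    modifies_stmt_list: the Pass deletes, replaces, or adds a whole Stmt.
    """
    # The strongest requirement wins, so check the Func-level conditions first.
    if controls_traversal or needs_nesting_context:
        return "FuncPass"
    if needs_cross_stmt_info or modifies_stmt_list:
        return "BlockPass"
    return "StmtPass"


# IfFusionPass: compares neighboring Ifs and deletes statements.
print(classify_pass_level(True, True, False, False))
# EliminateCommonFactorOfLocalIndex: needs the nesting context of each For.
print(classify_pass_level(False, False, True, True))
```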
<ol start="2">
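<li>Write the implementation class<br>
Some simpler Passes are just an implementation class that inherits from IRMutator. Such a Pass usually only needs to additionally inherit from StmtMutator (and, if it does not touch the Expr level, drop the IRMutator inheritance), then migrate the original logic to the new IR. See <a href="https://github.com/PaddlePaddle/Paddle/pull/70334">RemoveScheduleBlock</a> for reference. Under the new IR, many members have been made private and must be read and written through getters and setters.</li>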
</ol>
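<p>Some more complex Passes use several Mutators that visit the IR multiple times: typically the first visit collects global information and the later ones make the modifications. Once you understand the source, follow the same pattern when porting.</p>
<p>Harder still are Passes that call old-IR utilities such as <code>ir::ir_utils::CollectIRNodesWithoutTensor</code>. In that case, check whether the argument being passed is an Expr. If it is, the utility can still be called, because Expr is closed: an Expr never contains a Stmt or a Block. Otherwise, the utility's logic has to be reimplemented under the new IR.</p>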
<ol start="3">
<li>Return Success<br>
Nothing much to say here: simply return <code>LogicalResult::success()</code>.</li>
</ol>
<h2 id="tips">Tips</h2>
<ol>
<li>
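<p>Passes should be stateless<br>
Stateless means that a Pass should not depend on earlier information or record persistent information. For example, a Pass that processes For statements should not internally record the names of the Fors it has seen in order to avoid duplicates. If you want to avoid repeated visits, implement it as a FuncPass and handle the traversal manually.</p>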
</li>
<li>
<p>PassManager traverses in DFS post-order<br>
This order guarantees that the innermost statements are visited first. You may rely on this behavior when porting a Pass.</p>
</li>
<li>
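<p>Passes lack a communication mechanism<br>
There is no communication mechanism between Passes. Some Passes need to check whether a transformation is legal before applying it. Such a checking Pass can be made an internal part of the transformation Pass, which instantiates a PassManager to run the checking Pass.</p>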
</li>
<li>
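<p>Refer to the <a href="https://halide-lang.org/docs/">Halide documentation</a><br>
CINN borrows from Halide and TVM in many design decisions. If you run into questions such as what a particular Stmt does, consult the documentation of these two projects, whose communities are richer; it often pays off.</p>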
</li>
</ol>
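<h1 id="后记">Afterword</h1>
<p>As my first open-source experience, I feel this was very rewarding. As the old verse goes, what is learned on paper always feels shallow; to truly understand something you must practice it yourself. Many techniques I had never used hands-on got exercised during this event, such as the Git and GitHub workflow, using VSCode together with CMake, and working with GLOG. The experience of porting Passes in CINN is also a great entry point for learning AI Sys.</p>
<p>I strongly recommend that anyone who has not tried this kind of event take part. It will be well worth it!</p>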
]]></content:encoded>
    </item>
    <item>
      <title>MIT 6.5940 EfficientML Lab 1 Notes</title>
      <link>https://www.zhouxin.space/notes/notes-on-mit-efficientml-lab-1/</link>
      <pubDate>Wed, 27 Nov 2024 14:53:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-mit-efficientml-lab-1/</guid>
      <description>&lt;h1 id=&#34;实验准备&#34;&gt;Preparation&lt;/h1&gt;
&lt;h2 id=&#34;python-环境&#34;&gt;Python environment&lt;/h2&gt;
&lt;p&gt;The following Python environment is required:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;PyTorch (GPU build)&lt;/li&gt;
&lt;li&gt;jupyter notebook&lt;/li&gt;
&lt;li&gt;tqdm&lt;/li&gt;
&lt;li&gt;matplotlib&lt;/li&gt;
&lt;li&gt;torchprofile&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;数据集准备&#34;&gt;Dataset preparation&lt;/h2&gt;
&lt;p&gt;Lab 1 uses the CIFAR-10 dataset. Download it directly from &lt;a href=&#34;https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz&#34;&gt;https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz&lt;/a&gt; and extract the whole &lt;code&gt;cifar-10-batches-py&lt;/code&gt; folder into the &lt;code&gt;data/cifar10&lt;/code&gt; folder.&lt;/p&gt;
&lt;h1 id=&#34;part-1-fine-grained-pruning&#34;&gt;Part 1: Fine-grained Pruning&lt;/h1&gt;
&lt;h2 id=&#34;question-1&#34;&gt;Question 1&lt;/h2&gt;
&lt;p&gt;&lt;img alt=&#34;Histogram of the weight distribution of each layer&#34; loading=&#34;lazy&#34; src=&#34;https://pics.zhouxin.space/20241127193237.png&#34;&gt;&lt;/p&gt;
&lt;p&gt;Except for the final classifier head, the weights of every layer follow an approximately normal distribution centered at 0. A large fraction of the parameters are therefore close to zero and can be removed, which leaves a lot of room for model compression.&lt;/p&gt;</description>
      <content:encoded><![CDATA[<h1 id="实验准备">Preparation</h1>
<h2 id="python-环境">Python environment</h2>
<p>The following Python environment is required:</p>
<ul>
<li>PyTorch (GPU build)</li>
<li>jupyter notebook</li>
<li>tqdm</li>
<li>matplotlib</li>
<li>torchprofile</li>
</ul>
<h2 id="数据集准备">Dataset preparation</h2>
<p>Lab 1 uses the CIFAR-10 dataset. Download it directly from <a href="https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz">https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz</a> and extract the whole <code>cifar-10-batches-py</code> folder into the <code>data/cifar10</code> folder.</p>
<h1 id="part-1-fine-grained-pruning">Part 1: Fine-grained Pruning</h1>
<h2 id="question-1">Question 1</h2>
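<p><img alt="Histogram of the weight distribution of each layer" loading="lazy" src="https://pics.zhouxin.space/20241127193237.png"></p>
<p>Except for the final classifier head, the weights of every layer follow an approximately normal distribution centered at 0. A large fraction of the parameters are therefore close to zero and can be removed, which leaves a lot of room for model compression.</p>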
<h2 id="question-2">Question 2</h2>
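<p>The second question asks for fine-grained pruning, i.e. pruning individual elements of the weight matrix. For an introduction to pruning at different granularities, see <a href="https://www.zhouxin.space/notes/notes-on-mit-efficientml-3rd-lecture/#%E5%89%AA%E6%9E%9D%E7%BB%86%E7%B2%92%E7%A8%8B%E5%BA%A6">my notes on Lecture 3</a>.</p>
<p>Here the absolute value of each parameter is used as its importance: unimportant parameters are pruned and important ones are kept.</p>
<p>This question is fairly simple: compute the number of parameters to prune from the sparsity, find the threshold with <code>torch.kthvalue</code>, and derive the mask from that threshold. The only thing to watch is that the mask is computed with <code>&gt;</code> rather than <code>&gt;=</code>, because the element equal to the threshold must be pruned as well.</p>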
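<p>The thresholding rule can be checked by hand with a pure-Python sketch (no PyTorch needed). The function name <code>fine_grained_mask</code> and the sample weights are made up for illustration; sorting and indexing at <code>num_zeros - 1</code> stands in for <code>torch.kthvalue</code>, which returns the k-th smallest value with k counted from 1.</p>

```python
def fine_grained_mask(weights, sparsity):
    """Return a 0/1 mask that zeroes the round(sparsity * n) smallest-magnitude weights."""
    num_zeros = round(sparsity * len(weights))
    if num_zeros == 0:
        return [1] * len(weights)          # nothing to prune
    importance = [abs(w) for w in weights]
    # k-th smallest importance, k counted from 1 (stand-in for torch.kthvalue)
    threshold = sorted(importance)[num_zeros - 1]
    # strict comparison: the element equal to the threshold is pruned too
    return [1 if imp > threshold else 0 for imp in importance]


weights = [0.5, -0.1, 0.3, -0.7, 0.2]
mask = fine_grained_mask(weights, 0.4)     # prune 2 of 5 elements
print(mask)
```

<p>With sparsity 0.4 the threshold is 0.2, the second-smallest magnitude. Using <code>&gt;=</code> would keep the 0.2 entry and prune only one element instead of two. Note that when several weights tie with the threshold, the strict comparison prunes all of them, so slightly more than the requested number can be removed.</p>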
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="c1">##################### YOUR CODE STARTS HERE #####################</span>
</span></span><span class="line"><span class="cl"><span class="c1"># Step 1: calculate the #zeros (please use round())</span>
</span></span><span class="line"><span class="cl"><span class="n">num_zeros</span> <span class="o">=</span> <span class="nb">round</span><span class="p">(</span><span class="n">sparsity</span> <span class="o">*</span> <span class="n">num_elements</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="c1"># Step 2: calculate the importance of weight</span>
</span></span><span class="line"><span class="cl"><span class="n">importance</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">abs</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="c1"># Step 3: calculate the pruning threshold</span>
</span></span><span class="line"><span class="cl"><span class="n">threshold</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">kthvalue</span><span class="p">(</span><span class="n">importance</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">num_zeros</span><span class="p">)</span><span class="o">.</span><span class="n">values</span>
</span></span><span class="line"><span class="cl"><span class="c1"># Step 4: get binary mask (1 for nonzeros, 0 for zeros)</span>
</span></span><span class="line"><span class="cl"><span class="n">mask</span> <span class="o">=</span> <span class="n">importance</span> <span class="o">&gt;</span> <span class="n">threshold</span>
</span></span><span class="line"><span class="cl"><span class="c1">##################### YOUR CODE ENDS HERE #######################</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="question-3">Question 3</h2>
<p>Question 3 asks us to keep 10 elements of a 5 x 5 matrix. The corresponding sparsity is $1-\frac{10}{25}=0.6$, and that is all this question requires.</p>
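<p>The same arithmetic as a tiny helper, in case the target changes. The function name is hypothetical.</p>

```python
def sparsity_for_keeping(num_kept: int, num_elements: int) -> float:
    """Fraction of elements that must be zeroed so that `num_kept` survive."""
    return 1.0 - num_kept / num_elements


target_sparsity = sparsity_for_keeping(10, 5 * 5)
num_zeros = round(target_sparsity * 25)    # how many entries get pruned
print(target_sparsity, num_zeros)
```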
<h2 id="question-4">Question 4</h2>
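<p>This question runs a sensitivity analysis on every layer of the VGG network. I suggest changing the step size to 0.2 or 0.1 to get smoother sensitivity curves.</p>
<p><img alt="Sensitivity analysis results for each VGG layer" loading="lazy" src="https://pics.zhouxin.space/20241127203439.png"></p>
<p>The plot shows that, for most layers, model accuracy drops as sparsity increases. Layers differ in how sensitive they are, and the 0th convolution layer is the most sensitive to sparsity.</p>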
<h2 id="question-5">Question 5</h2>
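<p>Question 5 asks us to set the pruning sparsity of each layer based on the earlier sensitivity analysis and on each layer's parameter count. ❗️Note that the sparsity of the whole model is largely determined by the sparsity of the layers with many parameters, so for those layers it is worth considering a relatively high sparsity.</p>
<p>The sparsity values I chose are:</p>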
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">sparsity_dict</span> <span class="o">=</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="c1">##################### YOUR CODE STARTS HERE #####################</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># please modify the sparsity value of each layer</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># please DO NOT modify the key of sparsity_dict</span>
</span></span><span class="line"><span class="cl">    <span class="s1">&#39;backbone.conv0.weight&#39;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s1">&#39;backbone.conv1.weight&#39;</span><span class="p">:</span> <span class="mf">0.6</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s1">&#39;backbone.conv2.weight&#39;</span><span class="p">:</span> <span class="mf">0.5</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s1">&#39;backbone.conv3.weight&#39;</span><span class="p">:</span> <span class="mf">0.5</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s1">&#39;backbone.conv4.weight&#39;</span><span class="p">:</span> <span class="mf">0.5</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s1">&#39;backbone.conv5.weight&#39;</span><span class="p">:</span> <span class="mf">0.6</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s1">&#39;backbone.conv6.weight&#39;</span><span class="p">:</span> <span class="mf">0.6</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s1">&#39;backbone.conv7.weight&#39;</span><span class="p">:</span> <span class="mf">0.75</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s1">&#39;classifier.weight&#39;</span><span class="p">:</span> <span class="mi">0</span>
</span></span><span class="line"><span class="cl"><span class="c1">##################### YOUR CODE ENDS HERE #######################</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
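</div><p>After pruning, the model is about 38.48% of the size of the original dense model, and accuracy drops from 92.9% to 91.50%. After 5 epochs of fine-tuning, accuracy recovers to 92.95%.</p>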
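<p>Why the large layers dominate the overall sparsity can be seen from a parameter-weighted average. The sketch below uses made-up layer names and parameter counts (they are not the real VGG sizes from the lab) and a hypothetical helper <code>overall_sparsity</code>.</p>

```python
def overall_sparsity(num_params: dict, sparsity: dict) -> float:
    """Fraction of all weights that are zeroed, given per-layer sparsities."""
    total = sum(num_params.values())
    pruned = sum(round(sparsity[name] * n) for name, n in num_params.items())
    return pruned / total


# Hypothetical parameter counts: one small layer, one large layer.
num_params = {"small_conv": 1_000, "large_conv": 99_000}

# Aggressive on the small layer, gentle on the large one.
a = overall_sparsity(num_params, {"small_conv": 0.9, "large_conv": 0.1})
# Gentle on the small layer, aggressive on the large one.
b = overall_sparsity(num_params, {"small_conv": 0.1, "large_conv": 0.9})
print(a, b)
```

<p>Pruning 90% of the small layer barely moves the total (about 0.108), whereas pruning 90% of the large layer yields about 0.892. This is the reasoning behind giving the later, larger convolutions the higher sparsity values.</p>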
<h1 id="part-2-channel-pruning">Part 2: Channel Pruning</h1>
<h2 id="question-6">Question 6</h2>
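<p>Question 6 asks for an implementation of Channel Pruning, where the pruning criterion is simply to keep the first k channels. The problem itself is simple; it just takes careful use of Python slicing:</p>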
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span><span class="lnt">50
</span><span class="lnt">51
</span><span class="lnt">52
</span><span class="lnt">53
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">get_num_channels_to_keep</span><span class="p">(</span><span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">prune_ratio</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;A function to calculate the number of layers to PRESERVE after pruning
</span></span></span><span class="line"><span class="cl"><span class="s2">    Note that preserve_rate = 1. - prune_ratio
</span></span></span><span class="line"><span class="cl"><span class="s2">    &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="c1">##################### YOUR CODE STARTS HERE #####################</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="n">channels</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">prune_ratio</span><span class="p">)))</span>
</span></span><span class="line"><span class="cl">    <span class="c1">##################### YOUR CODE ENDS HERE #####################</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="nd">@torch.no_grad</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">channel_prune</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                  <span class="n">prune_ratio</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;Apply channel pruning to each of the conv layer in the backbone
</span></span></span><span class="line"><span class="cl"><span class="s2">    Note that for prune_ratio, we can either provide a floating-point number,
</span></span></span><span class="line"><span class="cl"><span class="s2">    indicating that we use a uniform pruning rate for all layers, or a list of
</span></span></span><span class="line"><span class="cl"><span class="s2">    numbers to indicate per-layer pruning rate.
</span></span></span><span class="line"><span class="cl"><span class="s2">    &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># sanity check of provided prune_ratio</span>
</span></span><span class="line"><span class="cl">    <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">prune_ratio</span><span class="p">,</span> <span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="nb">list</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="n">n_conv</span> <span class="o">=</span> <span class="nb">len</span><span class="p">([</span><span class="n">m</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">backbone</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">)])</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># note that for the ratios, it affects the previous conv output and next</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># conv input, i.e., conv0 - ratio0 - conv1 - ratio1-...</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">prune_ratio</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">prune_ratio</span><span class="p">)</span> <span class="o">==</span> <span class="n">n_conv</span> <span class="o">-</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">    <span class="k">else</span><span class="p">:</span>  <span class="c1"># convert float to list</span>
</span></span><span class="line"><span class="cl">        <span class="n">prune_ratio</span> <span class="o">=</span> <span class="p">[</span><span class="n">prune_ratio</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">n_conv</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># we prune the convs in the backbone with a uniform ratio</span>
</span></span><span class="line"><span class="cl">    <span class="n">model</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>  <span class="c1"># prevent overwrite</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># we only apply pruning to the backbone features</span>
</span></span><span class="line"><span class="cl">    <span class="n">all_convs</span> <span class="o">=</span> <span class="p">[</span><span class="n">m</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">backbone</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">    <span class="n">all_bns</span> <span class="o">=</span> <span class="p">[</span><span class="n">m</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">backbone</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># apply pruning. we naively keep the first k channels</span>
</span></span><span class="line"><span class="cl">    <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">all_convs</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">all_bns</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">i_ratio</span><span class="p">,</span> <span class="n">p_ratio</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">prune_ratio</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">prev_conv</span> <span class="o">=</span> <span class="n">all_convs</span><span class="p">[</span><span class="n">i_ratio</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">prev_bn</span> <span class="o">=</span> <span class="n">all_bns</span><span class="p">[</span><span class="n">i_ratio</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">next_conv</span> <span class="o">=</span> <span class="n">all_convs</span><span class="p">[</span><span class="n">i_ratio</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">original_channels</span> <span class="o">=</span> <span class="n">prev_conv</span><span class="o">.</span><span class="n">out_channels</span>  <span class="c1"># same as next_conv.in_channels</span>
</span></span><span class="line"><span class="cl">        <span class="n">n_keep</span> <span class="o">=</span> <span class="n">get_num_channels_to_keep</span><span class="p">(</span><span class="n">original_channels</span><span class="p">,</span> <span class="n">p_ratio</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1"># prune the output of the previous conv and bn</span>
</span></span><span class="line"><span class="cl">        <span class="n">prev_conv</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">set_</span><span class="p">(</span><span class="n">prev_conv</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">detach</span><span class="p">()[:</span><span class="n">n_keep</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">        <span class="n">prev_bn</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">set_</span><span class="p">(</span><span class="n">prev_bn</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">detach</span><span class="p">()[:</span><span class="n">n_keep</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">        <span class="n">prev_bn</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">set_</span><span class="p">(</span><span class="n">prev_bn</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">detach</span><span class="p">()[:</span><span class="n">n_keep</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">        <span class="n">prev_bn</span><span class="o">.</span><span class="n">running_mean</span><span class="o">.</span><span class="n">set_</span><span class="p">(</span><span class="n">prev_bn</span><span class="o">.</span><span class="n">running_mean</span><span class="o">.</span><span class="n">detach</span><span class="p">()[:</span><span class="n">n_keep</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">        <span class="n">prev_bn</span><span class="o">.</span><span class="n">running_var</span><span class="o">.</span><span class="n">set_</span><span class="p">(</span><span class="n">prev_bn</span><span class="o">.</span><span class="n">running_var</span><span class="o">.</span><span class="n">detach</span><span class="p">()[:</span><span class="n">n_keep</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1"># prune the input of the next conv (hint: just one line of code)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">##################### YOUR CODE STARTS HERE #####################</span>
</span></span><span class="line"><span class="cl">        <span class="n">next_conv</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">set_</span><span class="p">(</span><span class="n">next_conv</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">detach</span><span class="p">()[:,</span> <span class="p">:</span><span class="n">n_keep</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">        <span class="c1">##################### YOUR CODE ENDS HERE #####################</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">model</span>
</span></span></code></pre></td></tr></table>
</div>
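</div><p>One thing worth noting about the code the framework already provides: channels only exist in convolutions, and pruning here is applied to output channels. For example, if the current conv layer originally has k output channels and l remain after pruning, then the input channels of the next conv layer's kernel must shrink from k to l accordingly. In addition, a Conv is usually followed by a Batch Norm, so that BN's weight, bias, running_mean, and running_var must be pruned together with it.</p>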
<h2 id="question-7">Question 7</h2>
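<p>Improve channel pruning by using the Frobenius norm to measure the importance of each channel. The core of this question is computing the Frobenius norm. The instructions recommend <a href="https://pytorch.org/docs/main/generated/torch.norm.html#torch.norm">torch.norm</a>, but the official documentation marks that API as deprecated, so <a href="https://pytorch.org/docs/main/generated/torch.linalg.vector_norm.html#torch.linalg.vector_norm" title="torch.linalg.vector_norm">torch.linalg.vector_norm()</a> is used here instead. According to the documentation, <code>dim</code> lists the dimensions to flatten into a vector before taking the norm, here dimensions <code>[0, 2, 3]</code>, which leaves one value per input channel (dimension 1).</p>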
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="c1"># function to sort the channels from important to non-important</span>
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">get_input_channel_importance</span><span class="p">(</span><span class="n">weight</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">in_channels</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># importances = []</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># # compute the importance for each input channel</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># for i_c in range(weight.shape[1]):</span>
</span></span><span class="line"><span class="cl">    <span class="c1">#     channel_weight = weight.detach()[:, i_c]</span>
</span></span><span class="line"><span class="cl">    <span class="c1">#     ##################### YOUR CODE STARTS HERE #####################</span>
</span></span><span class="line"><span class="cl">    <span class="c1">#     importance = torch.linalg.norm(channel_weight, ord=&#34;fro&#34;, dim</span>
</span></span><span class="line"><span class="cl">    <span class="c1">#     ##################### YOUR CODE ENDS HERE #####################</span>
</span></span><span class="line"><span class="cl">    <span class="c1">#     importances.append(importance.view(1))</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># return torch.cat(importances)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">vector_norm</span><span class="p">(</span><span class="n">weight</span><span class="p">,</span> <span class="nb">ord</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="nd">@torch.no_grad</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">apply_channel_sorting</span><span class="p">(</span><span class="n">model</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">model</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>  <span class="c1"># do not modify the original model</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># fetch all the conv and bn layers from the backbone</span>
</span></span><span class="line"><span class="cl">    <span class="n">all_convs</span> <span class="o">=</span> <span class="p">[</span><span class="n">m</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">backbone</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">    <span class="n">all_bns</span> <span class="o">=</span> <span class="p">[</span><span class="n">m</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">backbone</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># iterate through conv layers</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">i_conv</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">all_convs</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># each channel sorting index, we need to apply it to:</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># - the output dimension of the previous conv</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># - the previous BN layer</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># - the input dimension of the next conv (we compute importance here)</span>
</span></span><span class="line"><span class="cl">        <span class="n">prev_conv</span> <span class="o">=</span> <span class="n">all_convs</span><span class="p">[</span><span class="n">i_conv</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">prev_bn</span> <span class="o">=</span> <span class="n">all_bns</span><span class="p">[</span><span class="n">i_conv</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">next_conv</span> <span class="o">=</span> <span class="n">all_convs</span><span class="p">[</span><span class="n">i_conv</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># note that we always compute the importance according to input channels</span>
</span></span><span class="line"><span class="cl">        <span class="n">importance</span> <span class="o">=</span> <span class="n">get_input_channel_importance</span><span class="p">(</span><span class="n">next_conv</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># sorting from large to small</span>
</span></span><span class="line"><span class="cl">        <span class="n">sort_idx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">importance</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1"># apply to previous conv and its following bn</span>
</span></span><span class="line"><span class="cl">        <span class="n">prev_conv</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">index_select</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="n">prev_conv</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">detach</span><span class="p">(),</span> <span class="mi">0</span><span class="p">,</span> <span class="n">sort_idx</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">tensor_name</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;weight&#39;</span><span class="p">,</span> <span class="s1">&#39;bias&#39;</span><span class="p">,</span> <span class="s1">&#39;running_mean&#39;</span><span class="p">,</span> <span class="s1">&#39;running_var&#39;</span><span class="p">]:</span>
</span></span><span class="line"><span class="cl">            <span class="n">tensor_to_apply</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">prev_bn</span><span class="p">,</span> <span class="n">tensor_name</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">tensor_to_apply</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">                <span class="n">torch</span><span class="o">.</span><span class="n">index_select</span><span class="p">(</span><span class="n">tensor_to_apply</span><span class="o">.</span><span class="n">detach</span><span class="p">(),</span> <span class="mi">0</span><span class="p">,</span> <span class="n">sort_idx</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1"># apply to the next conv input (hint: one line of code)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">##################### YOUR CODE STARTS HERE #####################</span>
</span></span><span class="line"><span class="cl">        <span class="n">next_conv</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">index_select</span><span class="p">(</span><span class="n">next_conv</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">detach</span><span class="p">(),</span> <span class="mi">1</span><span class="p">,</span> <span class="n">sort_idx</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="c1">##################### YOUR CODE ENDS HERE #####################</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">model</span>
</span></span></code></pre></td></tr></table>
</div>
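</div><p>Compared with channel pruning that ignores importance, the improved pruning raises the accuracy of the pruned model from 28.15% to 36.81%. After fine-tuning, the accuracy recovers to 92.41%.</p>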
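<p>To make the importance score concrete, here is a minimal pure-Python sketch of what <code>torch.linalg.vector_norm(weight, ord=2, dim=(0, 2, 3))</code> computes: one L2 (Frobenius) norm per input channel, followed by a descending sort. The helper name <code>input_channel_importance</code> and the toy weight values are assumptions made for illustration; they are not part of the lab code.</p>

```python
import math

def input_channel_importance(weight):
    """weight: nested lists shaped [out_ch][in_ch][kh][kw].

    Returns one L2 norm per input channel, i.e. the norm taken over
    dimensions (0, 2, 3) of the weight tensor.
    """
    n_in = len(weight[0])
    squared_sums = [0.0] * n_in
    for out_filter in weight:
        for i_c, kernel in enumerate(out_filter):
            squared_sums[i_c] += sum(v * v for row in kernel for v in row)
    return [math.sqrt(s) for s in squared_sums]

# Toy weight (made-up values): 2 output channels, 3 input channels, 1x1 kernels.
toy_weight = [
    [[[3.0]], [[0.0]], [[1.0]]],
    [[[4.0]], [[1.0]], [[2.0]]],
]
importance = input_channel_importance(toy_weight)
# Indices of input channels, most important first (mirrors argsort(descending=True)).
sort_idx = sorted(range(len(importance)), key=lambda i: importance[i], reverse=True)
print(importance, sort_idx)
```

<p>The same <code>sort_idx</code> has to be applied to the output dimension of the previous conv, to its BN statistics, and to the input dimension of the next conv, otherwise the reordered channels no longer line up.</p>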
<h2 id="question-8">Question 8</h2>
<ol>
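<li>Why does pruning 30% of the channels reduce the computation by roughly 50%?<br>
VGG consists mainly of convolutional layers, and the FLOPs of a convolutional layer are:</li>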
</ol>


<div>$$

FLOPs = K\times K\times C_{in}\times C_{out}\times H \times W

$$</div>

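<p>Both the input and output channel counts shrink to 70% of the original, so the total computation becomes 0.7 × 0.7 = 49% of the original.</p>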
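<p>The arithmetic can be checked with a few lines of Python. This is only a sketch: the layer shape (3x3 kernel, 128 channels, 16x16 feature map) is a hypothetical example, not a specific VGG layer from the lab.</p>

```python
def conv_flops(k, c_in, c_out, h, w):
    """FLOPs of one conv layer following FLOPs = K * K * C_in * C_out * H * W."""
    return k * k * c_in * c_out * h * w

channels = 128                     # hypothetical channel count
kept = round(channels * 0.7)       # 30% of the channels are pruned
dense = conv_flops(3, channels, channels, 16, 16)
pruned = conv_flops(3, kept, kept, 16, 16)
ratio = pruned / dense
print(f"remaining FLOPs: {ratio:.1%}")
```

<p>The ratio is not exactly 49% because the kept channel count has to be rounded to an integer.</p>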
<ol start="2">
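<li>Explain why the latency reduction ratio is slightly smaller than the computation reduction ratio.<br>
Latency comes not only from computation but also from data movement, and without operator fusion the time spent on data movement does not shrink significantly.</li>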
</ol>
<h2 id="question-9">Question 9</h2>
<ol>
<li>
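<p>Discuss the pros and cons of fine-grained pruning and channel pruning.<br>
Fine-grained pruning: higher compression ratio, but hardware-unfriendly and high latency;<br>
Channel pruning: lower compression ratio, but hardware-friendly, low latency, and easy to fine-tune.</p>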
</li>
<li>
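<p>If you want to accelerate a model on a smartphone, which approach is more suitable?<br>
Channel pruning. Smartphones generally lack support for sparse matrices, so the more hardware-friendly approach is the better choice.</p>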
</li>
</ol>
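<h1 id="小结">Summary</h1>
<p>The first lab is fairly easy, and finishing it gives a first feel for pruning. Hopefully the later labs are more demanding; there is just too little code to write here😂.</p>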
]]></content:encoded>
    </item>
    <item>
      <title>Notes on MIT 6.5940 EfficientML Lecture 5</title>
      <link>https://www.zhouxin.space/notes/notes-on-mit-efficientml-5th-lecture/</link>
      <pubDate>Mon, 18 Nov 2024 09:34:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-mit-efficientml-5th-lecture/</guid>
      <description>&lt;blockquote&gt;
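&lt;p&gt;This lecture begins the topic of quantization. It first covers various numeric data formats, then introduces two quantization techniques, K-means and linear quantization, and finally touches on the model compression pipeline.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h1 id=&#34;数值数据类型&#34;&gt;Numeric Data Types&lt;/h1&gt;
&lt;p&gt;The first part of the lecture covers the representations of integers, fixed-point numbers, and floating-point numbers. This is basic computer organization material and is not repeated here.&lt;/p&gt;</description>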
      <content:encoded><![CDATA[<blockquote>
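<p>This lecture begins the topic of quantization. It first covers various numeric data formats, then introduces two quantization techniques, K-means and linear quantization, and finally touches on the model compression pipeline.</p>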
</blockquote>
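<h1 id="数值数据类型">Numeric Data Types</h1>
<p>The first part of the lecture covers the representations of integers, fixed-point numbers, and floating-point numbers. This is basic computer organization material and is not repeated here.</p>
<p>It is worth noting that in a floating-point format the exponent width determines the range of representable values, while the mantissa width determines the precision. As shown below, Google proposed the BF16 format, which occupies 2 bytes in total but has the same exponent width as the 4-byte single-precision format of IEEE 754, so the two cover the same range.<br>
<img alt="Google BF16 format" loading="lazy" src="https://pics.zhouxin.space/202411180953293.webp"></p>
<p>NVIDIA proposed two different 8-bit floating-point formats with different exponent and mantissa widths:<br>
<img alt="The two FP8 formats proposed by NVIDIA" loading="lazy" src="https://pics.zhouxin.space/202411180959315.webp"></p>
<p>If only 4 bits are used to represent an integer or a floating-point number, the representable values can simply be plotted on a number line:<br>
<img alt="INT4 and FP4 formats and their representable ranges" loading="lazy" src="https://pics.zhouxin.space/202411181004629.webp"></p>
<h1 id="量化">Quantization</h1>
<h2 id="什么是量化">What Is Quantization</h2>
<p>Quantization is the process of approximating continuous values, or a large set of possible discrete values, with a small set of values. The figure below shows the effect of quantizing a continuous signal and a high-resolution image.<br>
<img alt="Illustration of quantization" loading="lazy" src="https://pics.zhouxin.space/202411181014481.webp"></p>
<p>The difference before and after quantization is defined as the quantization error, and the quantization process should minimize it.</p>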
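<h2 id="k-means-量化">K-means Quantization</h2>
<p>As shown below, K-means-based quantization clusters all weights with the K-means algorithm. Each weight takes the centroid of its cluster as its quantized value, and the weight tensor stores only the cluster index.<br>
<img alt="K-means quantization" loading="lazy" src="https://pics.zhouxin.space/202411261948134.webp"><br>
Suppose the weights are quantized into $2^n$ clusters, so each weight needs an $n$-bit integer for its cluster index, and the original weights have $m$ parameters with $m \gg 2^n$. Before quantization the weight matrix occupies $32m$ bits; afterwards it occupies $nm+32\times 2^n$ bits ($nm$ bits of indices plus a codebook of $2^n$ 32-bit centroids). Since $m \gg 2^n$, this is approximately $nm$ bits, i.e. $n/32$ of the original memory.</p>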
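<p>The procedure above can be sketched in pure Python as a 1-D K-means over a flat list of weights. This is an illustrative sketch rather than the lecture's implementation: the function names, the linear centroid initialization, and the 16 toy weight values are all assumptions.</p>

```python
def nearest(value, centroids):
    """Index of the centroid closest to value."""
    return min(range(len(centroids)), key=lambda c: abs(value - centroids[c]))

def kmeans_quantize(weights, n_bits, n_iters=20):
    """Cluster weights into 2**n_bits centroids; return (codes, codebook)."""
    k = 2 ** n_bits
    lo, hi = min(weights), max(weights)
    # Assumption: centroids start evenly spaced between the min and max weight.
    centroids = [lo + (hi - lo) * i / (k - 1) for i in range(k)]
    for _ in range(n_iters):
        codes = [nearest(w, centroids) for w in weights]      # assignment step
        for c in range(k):                                    # update step
            members = [w for w, code in zip(weights, codes) if code == c]
            if members:
                centroids[c] = sum(members) / len(members)
    codes = [nearest(w, centroids) for w in weights]
    return codes, centroids

def storage_bits(m, n_bits, float_bits=32):
    """Bits before quantization, and bits after (indices plus codebook)."""
    return float_bits * m, n_bits * m + float_bits * 2 ** n_bits

# Toy weights (made-up values).
weights = [2.09, -0.98, 1.48, 0.09, 0.05, -0.14, -1.08, 2.12,
           -0.91, 1.92, 0.0, -1.03, 1.87, 0.0, 1.53, 1.49]
codes, codebook = kmeans_quantize(weights, n_bits=2)
dequantized = [codebook[c] for c in codes]
before, after = storage_bits(len(weights), n_bits=2)
print(codebook, before, after)   # storage: 512 bits before, 160 bits after
```

<p>With only 16 weights the codebook is a large share of the total, so the saving here is 512 to 160 bits rather than the asymptotic $n/32$; the approximation only holds once $m \gg 2^n$.</p>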
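<p>A quantized model can achieve better results through fine-tuning. The fine-tuning algorithm for K-means quantization sums the gradients of all weights in the same cluster, uses the sum as the gradient of that cluster's centroid, and updates the centroid accordingly.<br>
<img alt="Fine-tuning algorithm for K-means quantization" loading="lazy" src="https://pics.zhouxin.space/202411262008697.webp"></p>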
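<p>A minimal sketch of that update rule, assuming plain SGD with a hypothetical learning rate <code>lr</code>; the function name and the numbers are made up for illustration:</p>

```python
def finetune_codebook(codebook, codes, grads, lr):
    """One SGD step on the centroids.

    codes[i] is the cluster index of weight i, grads[i] is the gradient of
    the loss with respect to weight i. Gradients of weights that share a
    cluster are summed and applied to that cluster's centroid.
    """
    grad_sum = [0.0] * len(codebook)
    for code, g in zip(codes, grads):
        grad_sum[code] += g
    return [c - lr * g for c, g in zip(codebook, grad_sum)]

codebook = [-1.0, 0.0, 1.5, 2.0]
codes = [3, 0, 0, 1]                # weights 1 and 2 share cluster 0
grads = [0.5, 0.25, 0.25, -1.0]
updated = finetune_codebook(codebook, codes, grads, lr=0.1)
print(updated)
```

<p>Cluster 2 receives no gradient in this example, so its centroid stays unchanged. The cluster assignment itself is not modified during this step.</p>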
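<p>Combining pruning and quantization can achieve 20x model compression.<br>
<img alt="Performance of K-means quantization" loading="lazy" src="https://pics.zhouxin.space/202411262010786.webp"></p>
<p>In general, pruning is done first and quantization second. Pruning keeps only the useful parameters, which reduces the number of parameters that need to be quantized.</p>
<p>The required bit width depends on the model. As shown below, convolutional layers typically need 4 bits, while fully connected layers need only 2 bits.<br>
<img alt="Relationship between quantization bit width and accuracy loss" loading="lazy" src="https://pics.zhouxin.space/202411270935213.webp"></p>
<p>In addition, Huffman coding can further shrink the model by assigning shorter codes to the more frequent values.</p>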
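<p>As a rough illustration of the Huffman step, the sketch below computes code lengths for a skewed stream of cluster indices using the standard-library <code>heapq</code>. The index stream is hypothetical, and only code lengths are computed, not the actual bit strings.</p>

```python
import heapq
from collections import Counter

def huffman_code_lengths(symbols):
    """Return {symbol: code length in bits} for a Huffman code over symbols."""
    freq = Counter(symbols)
    if len(freq) == 1:
        return {next(iter(freq)): 1}
    # Heap entries: (count, unique tie-breaker, {symbol: depth so far}).
    heap = [(n, i, {s: 0}) for i, (s, n) in enumerate(freq.items())]
    heapq.heapify(heap)
    next_id = len(heap)
    while len(heap) > 1:
        n1, _, d1 = heapq.heappop(heap)
        n2, _, d2 = heapq.heappop(heap)
        # Merging two subtrees pushes every symbol in them one level deeper.
        merged = {s: depth + 1 for s, depth in {**d1, **d2}.items()}
        heapq.heappush(heap, (n1 + n2, next_id, merged))
        next_id += 1
    return heap[0][2]

# Hypothetical 2-bit cluster indices with a skewed distribution.
indices = [0] * 8 + [1] * 4 + [2] * 2 + [3] * 2
lengths = huffman_code_lengths(indices)
huffman_bits = sum(lengths[i] for i in indices)
fixed_bits = 2 * len(indices)
print(lengths, huffman_bits, fixed_bits)
```

<p>For this stream the most frequent index gets a 1-bit code, and the total drops from 32 bits with fixed 2-bit indices to 28 bits. The gain depends entirely on how skewed the index distribution is.</p>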
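<h2 id="线性量化">Linear Quantization</h2>
<p>Linear quantization maps low-precision integers to floating-point values through an affine transformation. In the figure below, an affine transformation approximately restores the original weight tensor from the quantized 2-bit weight tensor.</p>
<p><img alt="Linear quantization" loading="lazy" src="https://pics.zhouxin.space/202411270951609.webp"></p>
<p>Two parameters of this transformation need to be determined: an integer zero point $Z$ and a floating-point scale factor $S$. The affine transformation can be written as:</p>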


<div>$$

r = (q-Z)\times S

$$</div>

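<p>Here q is a value in the quantized weight matrix and r is the corresponding restored value. Elements of the quantized matrix that equal Z are restored to exactly 0, and S determines the range after restoration.</p>
<p>Let $r^\prime$ denote the true values of the weight matrix. The range of the true parameters is $r_{max}^\prime-r_{min}^\prime$, while the range of the restored parameters is $(q_{max}-q_{min})\times S$. Setting the two equal gives the formula for the scale factor:</p>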


<div>$$

S = \frac{r_{max}^\prime-r_{min}^\prime}{q_{max}-q_{min}}

$$</div>

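<p>In the lecture, $Z$ is computed by aligning the minimum of the original matrix with the minimum of the restored matrix, i.e. $r_{min}^\prime = (q_{min} - Z)\times S$, which yields the formula for the zero point:</p>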


<div>$$

Z = \text{round}(q_{min} - \frac{r_{min}^\prime}{S})

$$</div>

<blockquote>
<p>This seems a little odd: why not minimize the total error over the whole weight matrix instead?</p>
</blockquote>
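<p>The two formulas can be exercised with a short sketch. It assumes a signed $n$-bit integer range, so $q_{min}=-2^{n-1}$ and $q_{max}=2^{n-1}-1$; the function names and the example range $[-1.08, 2.12]$ are illustrative assumptions.</p>

```python
def linear_quant_params(r_min, r_max, n_bits):
    """Scale S and zero point Z from min-max alignment (signed n-bit integers assumed)."""
    q_min, q_max = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
    scale = (r_max - r_min) / (q_max - q_min)
    zero_point = round(q_min - r_min / scale)
    return scale, zero_point, q_min, q_max

def quantize(r, scale, zero_point, q_min, q_max):
    """Inverse of r = (q - Z) * S, rounded and clamped to the integer range."""
    q = round(r / scale) + zero_point
    return max(q_min, min(q_max, q))

def dequantize(q, scale, zero_point):
    """r = (q - Z) * S."""
    return (q - zero_point) * scale

scale, zero_point, q_min, q_max = linear_quant_params(-1.08, 2.12, n_bits=2)
q = quantize(1.48, scale, zero_point, q_min, q_max)
print(scale, zero_point, q, dequantize(q, scale, zero_point))
```

<p>Because $Z$ is rounded to an integer, the real value 0 maps exactly onto an integer code and is restored without error, which matters for zero padding and pruned weights.</p>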
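<h2 id="使用线性量化计算矩乘">Matrix Multiplication with Linear Quantization</h2>
<p>In the matrix multiplication $\mathbf{Y} = \mathbf{WX}$, all three matrices are represented with linear quantization. Rearranging gives:</p>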


<div>$$

\begin{align*}
\mathbf{Y} &amp;= \mathbf{WX} \\
S_Y \left( \mathbf{q_Y} - Z_Y \right) &amp;= S_W \left( \mathbf{q_W} - Z_W \right) \cdot S_X \left( \mathbf{q_X} - Z_X \right) \\
\mathbf{q_Y} &amp;= \frac{S_W S_X}{S_Y} \left( \mathbf{q_W} - Z_W \right) \left( \mathbf{q_X} - Z_X \right) &#43; Z_Y \\
\mathbf{q_Y} &amp;= \frac{S_W S_X}{S_Y} \left( \mathbf{q_W q_X} - Z_W \mathbf{q_X} - Z_X \mathbf{q_W} &#43; Z_W Z_X \right) &#43; Z_Y
\end{align*}

$$</div>

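<p>Here every quantity related to $W$ is a constant, and $Z_X$ is also a constant (the reason is not obvious to me; presumably it is fixed offline by calibration rather than recomputed for each input), so $-Z_X \mathbf{q_W} + Z_W Z_X$ can be precomputed at quantization time and adds no runtime cost.</p>
<p>As for the scale factor $\frac{S_W S_X}{S_Y}$, empirically it lies in $(0,1)$, so it can be expressed as a fixed-point number followed by a right shift, as shown below. Only a fixed-point number and a shift count need to be stored, rather than a high-precision floating-point number.<br>
<img alt="Storing the scale factor" loading="lazy" src="https://pics.zhouxin.space/202411271051921.webp"></p>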
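<p>One way to sketch the fixed-point trick in Python is to split the multiplier into a mantissa in $[0.5, 1)$ and a power of two with <code>math.frexp</code>. The function names, the 31 fractional bits, and the example multiplier 0.0072 are assumptions for illustration, not values from the lecture.</p>

```python
import math

def to_fixed_point_multiplier(m, frac_bits=31):
    """Split m in (0, 1) into (m0_int, shift) with m ≈ m0_int * 2**-(frac_bits + shift)."""
    mantissa, exponent = math.frexp(m)          # m == mantissa * 2**exponent, mantissa in [0.5, 1)
    m0_int = round(mantissa * 2 ** frac_bits)   # fixed-point mantissa stored as an integer
    return m0_int, -exponent                    # exponent is not positive here, so shift is at least 0

def apply_multiplier(acc, m0_int, shift, frac_bits=31):
    """Approximate round(acc * m) with one integer multiply and one right shift."""
    total_shift = frac_bits + shift
    rounding = 2 ** (total_shift - 1)           # half an output unit, so the shift rounds to nearest
    return (acc * m0_int + rounding) >> total_shift

m0_int, shift = to_fixed_point_multiplier(0.0072)
scaled = apply_multiplier(1000, m0_int, shift)  # 1000 * 0.0072 = 7.2
print(m0_int, shift, scaled)
```

<p>Only <code>m0_int</code> and <code>shift</code> need to be stored, and the runtime path uses integer arithmetic exclusively.</p>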
<p>Since weights are usually roughly symmetric around 0, it is reasonable to assume $Z_W=0$, which eliminates the two terms $Z_W\mathbf{q_X}$ and $Z_W Z_X$. The formula then becomes:</p>


<div>$$

\mathbf{q_Y} = \frac{S_W S_X}{S_Y} \left( \mathbf{q_W q_X}  - Z_X \mathbf{q_W} \right) &#43; Z_Y

$$</div>

<p>The main computational cost is the low-precision integer matrix multiplication $\mathbf{q_W q_X}$.</p>
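<p>The simplified formula can be checked numerically with a small matrix-vector sketch. Since $Z_X$ is a scalar subtracted from every element of $\mathbf{q_X}$, the $Z_X \mathbf{q_W}$ term becomes $Z_X$ times the row sums of $\mathbf{q_W}$ in code. All names and numbers below are made up for illustration, and the scale factor is kept as a float for brevity.</p>

```python
def quantized_matvec(q_w, q_x, s_w, s_x, z_x, s_y, z_y):
    """Integer mat-vec product under linear quantization, assuming Z_W = 0."""
    # Depends only on the weights and Z_X, so it can be precomputed offline.
    offsets = [z_x * sum(row) for row in q_w]
    multiplier = s_w * s_x / s_y
    q_y = []
    for row, offset in zip(q_w, offsets):
        acc = sum(w * x for w, x in zip(row, q_x))   # integer multiply-accumulate
        q_y.append(round(multiplier * (acc - offset)) + z_y)
    return q_y

# Quantized operands and their parameters (hypothetical values).
q_w, s_w = [[1, -2], [3, 0]], 0.5           # real W = [[0.5, -1.0], [1.5, 0.0]]
q_x, s_x, z_x = [10, 4], 0.25, 2            # real X = [2.0, 0.5]
s_y, z_y = 0.5, -3

q_y = quantized_matvec(q_w, q_x, s_w, s_x, z_x, s_y, z_y)
y = [(q - z_y) * s_y for q in q_y]          # dequantized output
print(q_y, y)
```

<p>The dequantized result matches the floating-point product $\mathbf{WX} = [0.5, 3.0]$ for this example, because the chosen values happen to be exactly representable.</p>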
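<h1 id="模型压缩的流水线">The Model Compression Pipeline</h1>
<p>The full model compression pipeline is shown below. Pruning and fine-tuning first remove ineffective parameters, quantization then reduces the number of bits used to represent each parameter, and finally Huffman coding further compresses the model size.<br>
<img alt="The model compression pipeline" loading="lazy" src="https://pics.zhouxin.space/202411270940517.webp"></p>
<p>Again, the pipeline reaches a very high compression ratio while keeping accuracy unchanged.<br>
<img alt="Model compression performance" loading="lazy" src="https://pics.zhouxin.space/202411270943756.webp"></p>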
]]></content:encoded>
    </item>
    <item>
      <title>How to Configure CMake “Elegantly” in VSCode: A PaddlePaddle Example</title>
      <link>https://www.zhouxin.space/notes/%E5%A6%82%E4%BD%95%E5%9C%A8vscode%E4%B8%AD%E4%BC%98%E9%9B%85%E5%9C%B0%E9%85%8D%E7%BD%AEcmake--%E4%BB%A5paddlepaddle%E4%B8%BA%E4%BE%8B/</link>
      <pubDate>Fri, 15 Nov 2024 11:09:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/%E5%A6%82%E4%BD%95%E5%9C%A8vscode%E4%B8%AD%E4%BC%98%E9%9B%85%E5%9C%B0%E9%85%8D%E7%BD%AEcmake--%E4%BB%A5paddlepaddle%E4%B8%BA%E4%BE%8B/</guid>
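      <description>&lt;p&gt;This article shows how to set up a CMake project in VSCode, including but not limited to syntax highlighting, code navigation, CMake configuration, building, and testing.&lt;/p&gt;
&lt;h2 id=&#34;环境说明&#34;&gt;Environment&lt;/h2&gt;
&lt;p&gt;This article uses WSL Ubuntu 22.04 as the demo environment, VSCode version &lt;code&gt;1.95.2&lt;/code&gt;, and &lt;a href=&#34;https://github.com/PaddlePaddle/paddle&#34;&gt;PaddlePaddle&lt;/a&gt; as the example project.&lt;/p&gt;</description>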
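      <content:encoded><![CDATA[<p>This article shows how to set up a CMake project in VSCode, including but not limited to syntax highlighting, code navigation, CMake configuration, building, and testing.</p>
<h2 id="环境说明">Environment</h2>
<p>This article uses WSL Ubuntu 22.04 as the demo environment, VSCode version <code>1.95.2</code>, and <a href="https://github.com/PaddlePaddle/paddle">PaddlePaddle</a> as the example project.</p>
<p>The following extensions need to be installed in VSCode:</p>
<ul>
<li><a href="https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd">clangd</a>: works with the clangd server to provide highlighting, completion, navigation, refactoring, and more for C/C++</li>
<li><a href="https://marketplace.visualstudio.com/items?itemName=ms-vscode.cmake-tools">CMake Tools</a>: provides support for CMake projects</li>
<li><a href="https://marketplace.visualstudio.com/items?itemName=ms-python.python">Python</a>: provides Python support for the project, including highlighting, navigation, and debugging</li>
</ul>
<p>The following tool is also required:</p>
<ul>
<li><a href="https://clangd.llvm.org/installation">clangd</a>: the clangd server, which parses C/C++ code</li>
</ul>
<p>CMake, a compiler, a debugger, and similar tools are assumed to be available.</p>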
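<h2 id="配置-cmake-项目">Configuring the CMake Project</h2>
<p>The first time a CMake project is opened after installing the CMake Tools extension, VSCode configures it automatically by default, i.e. it runs the <code>CMake: Configure</code> command. If several compilers are detected, it asks the user to pick one. At this point the extension has not been set up at all, so a Configure run will most likely not do what you want, and you can press <code>ESC</code> to cancel it.<br>
<img alt="Compiler selection prompt" loading="lazy" src="https://pics.zhouxin.space/202411151536207.webp"></p>
<p>Start by configuring the VSCode extension. Open <code>settings-Workspace</code> in VSCode; settings changed at the workspace level are saved to <code>.vscode/settings.json</code> in the project root. Keeping separate settings per project makes it easy, and “elegant”, to switch between projects. Following the screenshot below, locate the CMake Tools settings.</p>
<p><img alt="CMake Tools settings page" loading="lazy" src="https://pics.zhouxin.space/202411151546650.webp"></p>
<p>A few settings deserve attention and can be adjusted as needed:</p>
<ul>
<li>Build Directory: the path of the CMake build directory</li>
<li>Build Environment &amp; Configure Environment: environment variables for the configure and build stages</li>
<li>Build Args &amp; Configure Args: extra command-line arguments for the configure and build stages</li>
<li>Cmake Path: the path of the CMake executable</li>
<li>Generator: the generator to use, for example Ninja</li>
</ul>
<p>Once the extension is configured, the corresponding changes appear in <code>.vscode/settings.json</code>:</p>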
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;cmake.configureArgs&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;-DPY_VERSION=3.12&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;-DWITH_GPU=OFF&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;-DWITH_TESTING=ON&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;-DPYTHON_EXECUTABLE=/home/zhouxin/miniconda3/envs/paddle-dev/bin/python&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">],</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;cmake.configureSettings&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;CMAKE_EXPORT_COMPILE_COMMANDS&#34;</span><span class="p">:</span> <span class="kc">true</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;cmake.buildDirectory&#34;</span><span class="p">:</span> <span class="s2">&#34;${workspaceFolder}/build_mask&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;cmake.automaticReconfigure&#34;</span><span class="p">:</span> <span class="kc">false</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;cmake.configureOnOpen&#34;</span><span class="p">:</span> <span class="kc">false</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;cmake.configureOnEdit&#34;</span><span class="p">:</span> <span class="kc">false</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;cmake.generator&#34;</span><span class="p">:</span> <span class="s2">&#34;Ninja&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
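</div><p>Next, specify a compiler toolchain for the whole project. Press <code>Ctrl+Shift+P</code> to open the VSCode Command Palette and search for <code>cmake: kits</code> to find <code>CMake: Edit User-Local CMake Kits</code>. Multiple toolchains can be configured in that file; see <a href="https://github.com/microsoft/vscode-cmake-tools/blob/main/docs/kits.md">vscode-cmake-tools/docs/kits.md at main · microsoft/vscode-cmake-tools · GitHub</a> for details. Then run <code>CMake: Select a Kit</code> to choose a suitable toolchain for this project.</p>
<p>With everything in place, the CMake project can now be configured. Find <code>CMake: Configure</code> in the Command Palette and run it. Switch the VSCode Output panel to CMake to see the log; if there are errors, debug them based on the messages.</p>
<p><img alt="CMake log output panel" loading="lazy" src="https://pics.zhouxin.space/202411171337718.webp"></p>
<p>In general, once errors in the CMakeLists.txt files themselves and an unprepared environment have been ruled out, the cause is most likely an environment variable. Adjusting the relevant environment variables in <code>.vscode/settings.json</code>, or passing arguments that point to the paths of the tools in question, usually fixes it.</p>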
<h2 id="code-intelligence">Code Intelligence</h2>
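<p>Code Intelligence refers to a collection of features such as syntax highlighting, code navigation, auto-completion, and error detection. In short, it is what lets the IDE understand your code. In a project without working Code Intelligence, opening an arbitrary file may show many "header not found" errors, and jumping between function calls mostly fails because the IDE<sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup> cannot locate the function's source.</p>
<p>A build tool such as CMake, on the other hand, knows the dependencies between files completely <sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup>. During configuration, pass <code>-DCMAKE_EXPORT_COMPILE_COMMANDS=1</code>, or add the following to the extension settings, to make CMake generate <code>compile_commands.json</code>, a file containing this dependency information, in the build directory.</p>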
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="s2">&#34;cmake.configureSettings&#34;</span><span class="err">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="nt">&#34;CMAKE_EXPORT_COMPILE_COMMANDS&#34;</span><span class="p">:</span> <span class="kc">true</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span><span class="err">,</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Once the build tool has produced this information, clangd still needs to be told where to find it. Add the following to <code>.vscode/settings.json</code>:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="s2">&#34;clangd.arguments&#34;</span><span class="err">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;--compile-commands-dir=${workspaceFolder}/build&#34;</span><span class="p">,</span> <span class="c1">// directory containing the compilation database
</span></span></span><span class="line"><span class="cl">    <span class="s2">&#34;-j=20&#34;</span><span class="p">,</span>                                        <span class="c1">// use 20 parallel workers
</span></span></span><span class="line"><span class="cl">    <span class="s2">&#34;--background-index&#34;</span><span class="p">,</span>                           <span class="c1">// enable background indexing
</span></span></span><span class="line"><span class="cl">    <span class="s2">&#34;--pch-storage=memory&#34;</span><span class="p">,</span>                         <span class="c1">// keep precompiled headers in memory
</span></span></span><span class="line"><span class="cl">    <span class="s2">&#34;--limit-results=500&#34;</span><span class="p">,</span>                          <span class="c1">// limit the number of results to 500
</span></span></span><span class="line"><span class="cl">    <span class="s2">&#34;--log=info&#34;</span>                                    <span class="c1">// set the log level to info
</span></span></span><span class="line"><span class="cl"><span class="p">]</span><span class="err">,</span>
</span></span></code></pre></td></tr></table>
</div>
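</div><p>The first argument specifies the directory containing <code>compile_commands.json</code>, which should be the CMake build directory (the same path as <code>cmake.buildDirectory</code>). The remaining arguments are recommended settings that speed up parsing of large projects.</p>
<p>Next, run <code>CMake: Configure</code> again, and after it finishes restart the clangd server with <code>clangd: Restart language server</code>. After all this, the IDE can resolve most header files and navigation within files works correctly.</p>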
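<p>If navigation still does not work, a quick sanity check is to confirm that the compilation database actually exists in the directory clangd was pointed at. The sketch below uses only the Python standard library; the helper name and the <code>build</code> directory are assumptions, so substitute the path you pass to <code>--compile-commands-dir</code>.</p>

```python
import json
from pathlib import Path

def summarize_compile_commands(build_dir):
    """Return basic stats about build_dir/compile_commands.json, or None if it is missing."""
    db_path = Path(build_dir) / "compile_commands.json"
    if not db_path.is_file():
        return None
    # The database is a JSON array; each entry describes how one file is compiled.
    entries = json.loads(db_path.read_text(encoding="utf-8"))
    files = {entry["file"] for entry in entries}
    return {"path": str(db_path), "entries": len(entries), "files": len(files)}

# Hypothetical build directory; prints None when no database is found there.
print(summarize_compile_commands("build"))
```

<p>A result of <code>None</code> usually means Configure has not been run with <code>CMAKE_EXPORT_COMPILE_COMMANDS</code> enabled, or that clangd is looking at a different directory than the one CMake writes to.</p>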
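<p>However, in the Paddle project many functions still cannot be resolved. The reason is that <code>Paddle/third_party</code> contains the source code of many third-party tools. These tools are installed into the build directory during the CMake build, and the functions that still fail to resolve depend on them. Build the whole project with <code>CMake: Build</code>; once it finishes, all functions resolve correctly!</p>
<h2 id="debug">Debug</h2>
<p>Debugging depends on the specific project. In VSCode, debugging is configured through <code>.vscode/launch.json</code>. For joint debugging of Python and CUDA/C/C++ code, see another article: <a href="https://www.zhouxin.space/notes/joint-debgugging-of-cuda-and-python-in-vscode/">Joint debugging of CUDA and Python code in VSCode | 周鑫的个人博客</a>.</p>
<h2 id="测试">Testing</h2>
<p>In VSCode, tests are managed and configured through the Test Explorer <sup id="fnref:3"><a href="#fn:3" class="footnote-ref" role="doc-noteref">3</a></sup>. VSCode itself does not provide language-specific test configuration; instead, extensions add test support for specific languages. Searching the extension marketplace for <code>@category:&quot;testing&quot;</code> lists all testing extensions. The <code>Python</code> and <code>CMake Tools</code> extensions appear to ship with support for Python tests and CTest, so no additional testing extension is needed.</p>
<p>The VSCode Testing panel lists all test items. The screenshot below shows the CTest items from CMake and the Python Tests items from Python; I have not yet figured out where the Paddle entry comes from. Anyway, as long as it runs😎.</p>
<p><img alt="Testing panel" loading="lazy" src="https://pics.zhouxin.space/202411171438048.webp"></p>
<h1 id="参考">References</h1>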
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
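<p>VSCode is generally not considered an integrated development environment, but with a set of extensions it is fair to call it an IDE, and a highly customizable one at that&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>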
</li>
<li id="fn:2">
<p><a href="https://clang.llvm.org/docs/JSONCompilationDatabase.html">JSON Compilation Database Format Specification — Clang 20.0.0git documentation</a>&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:3">
<p><a href="https://code.visualstudio.com/docs/editor/testing">Testing in Visual Studio Code</a>&#160;<a href="#fnref:3" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>Notes on MIT 6.824 Distributed Systems Spring 2023 Lecture 3</title>
      <link>https://www.zhouxin.space/notes/notes-on-3rd-lecture-of-mit-6-824-distributed-system-spring-2023/</link>
      <pubDate>Wed, 13 Nov 2024 10:40:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-3rd-lecture-of-mit-6-824-distributed-system-spring-2023/</guid>
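      <description>&lt;blockquote&gt;
&lt;p&gt;Notes on Lecture 3 of MIT 6.824 Distributed Systems: a brief introduction to storage systems and consistency, followed by the file read and write flows in GFS.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h1 id=&#34;存储系统概述&#34;&gt;Storage System Overview&lt;/h1&gt;
&lt;p&gt;Storage systems matter a great deal in distributed systems: with a reliable storage system in place, other applications can be built stateless and keep their persistent state in storage, which greatly simplifies application design. An application that crashes can then restart quickly and recover by reading its state back from the storage system.&lt;/p&gt;</description>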
      <content:encoded><![CDATA[<blockquote>
<p>Notes on Lecture 3 of MIT 6.824 Distributed Systems: a brief introduction to storage systems and consistency, followed by the file read and write flows in GFS.</p>
</blockquote>
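<h1 id="存储系统概述">Storage System Overview</h1>
<p>Storage systems matter a great deal in distributed systems: with a reliable storage system in place, other applications can be built stateless and keep their persistent state in storage, which greatly simplifies application design. An application that crashes can then restart quickly and recover by reading its state back from the storage system.</p>
<p>This also means the storage system itself must be highly fault tolerant, and designing one is hard:</p>
<ul>
<li>High performance: data must be spread across multiple servers</li>
<li>Many servers: some server being down is the normal case</li>
<li>Fault tolerance: requires replication</li>
<li>Replication: can introduce inconsistency</li>
<li>Strong consistency: requires a consistency protocol and network communication, which lowers performance</li>
</ul>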
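<h1 id="ideal-consistensy">Ideal consistency</h1>
<p>An ideally consistent system should behave like a single machine.</p>
<p>Consistency has to deal with concurrency even on a single machine. For example, suppose A and B send requests to write 1 and 2 to X, in that order, and then C and D send requests to read X, in that order. The pair of values read by C and D can be (1, 1), (1, 2), or (2, 2), but never (2, 1).</p>
<p>A distributed system adds further concurrency problems. With the same order of reads and writes, A's write may be received only by server S1 and B's write only by server S2. If C's read is then served by S2 and D's read by S1, C and D read (2, 1). A synchronization protocol is needed to coordinate readers and writers. The second half of the course spends a lot of time on different protocols, and almost all of them are trade-offs between fault tolerance and consistency.</p>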
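<p>The scenario above can be reproduced with a toy simulation. This is only a sketch: the <code>Replica</code> class and both function names are hypothetical, and real replicas would talk over a network rather than through direct method calls.</p>

```python
class Replica:
    """A toy server that stores a single value X (hypothetical, for illustration)."""

    def __init__(self):
        self.x = None

    def write(self, value):
        self.x = value

    def read(self):
        return self.x


def single_server():
    # A writes 1, then B writes 2, then C and D read in that order.
    s = Replica()
    s.write(1)
    s.write(2)
    return (s.read(), s.read())


def unsynchronized_replicas():
    # A's write reaches only S1 and B's write reaches only S2.
    s1, s2 = Replica(), Replica()
    s1.write(1)
    s2.write(2)
    # C is served by S2, D is served by S1.
    return (s2.read(), s1.read())


print(single_server())            # one of the allowed outcomes
print(unsynchronized_replicas())  # the outcome a single machine never produces
```

<p>Without any protocol between S1 and S2, the second function returns the pair that a single server can never return.</p>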
<h1 id="gfs-google-file-system">GFS: Google File System</h1>
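<p>GFS is the first case study of the course: a distributed file system designed for high performance that also provides replication, fault tolerance, and consistency.</p>
<h2 id="特点">Characteristics</h2>
<p>The paper was published in the early 2000s. Academia already had a fairly mature theory of distributed systems by then, but industry lacked practical systems. GFS succeeded as a distributed file system even though it departs from the standard designs studied in academia:</p>
<ul>
<li>A single master</li>
<li>Inconsistency is possible</li>
</ul>
<p>Features of GFS:</p>
<ul>
<li>Big: holds large data sets</li>
<li>Fast: automatically shards files across multiple servers</li>
<li>Global: all applications see the same files</li>
<li>Fault tolerance: automatic fault tolerance and recovery</li>
</ul>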
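<h2 id="整体设计">Overall Design</h2>
<p>The GFS architecture is shown below:<br>
<img alt="GFS architecture" loading="lazy" src="https://pics.zhouxin.space/202411131245904.webp"><br>
Each file is split into chunks, and chunks are fairly large, about 64 MB. The application sends the master a file name and chunk index; the master returns the chunk handle and the chunk locations; the application then contacts a chunk server at one of those locations and sends its read or write request.</p>
<h2 id="master">Master</h2>
<p>The master maintains a table mapping each file name to an array of chunk handles. For each chunk handle it tracks the version number, the list of servers holding the chunk, which of those servers is primary and which are secondaries, and the lease time. The master is also responsible for the log and for checkpoints. It writes each request to the log before responding, so even if the master crashes it can be rebuilt from the log and resume serving requests.</p>
<p>The mapping from file names to chunk handles must be persistent state that the master saves to disk. The chunk-location information need not be: after a restart the master can ask the other servers to report which chunks they hold and rebuild it. The chunk version number must also be persistent, because the master has to know the latest version of every handle on its own rather than rely on what servers report; the servers that actually hold the latest version may be down at the same time.</p>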
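<h2 id="读写流程">Read and Write Flows</h2>
<ul>
<li>Reading a file</li>
</ul>
<ol>
<li>The client sends the file name and offset to the master.</li>
<li>The master returns the chunk handle, the list of servers holding that chunk, and the version number.</li>
<li>The client caches this information, which reduces load on the master, network traffic, and the client's own file-access latency.</li>
<li>The client tries the servers in the list from nearest to farthest.</li>
<li>The chunk server checks the version number and, if it matches, sends the data to the client.</li>
</ol>
<ul>
<li>Writing a file: the append operation</li>
</ul>
<ol>
<li>The client sends the file name to the master.</li>
<li>The master looks up the chunk handle for the file name, and then the list of chunk servers holding that chunk. If there is no primary, the master picks one and grants it a lease, during which it may modify the chunk. The master also increments the version number and sends it to every server holding the chunk; those servers must persist the version number to disk.</li>
<li>The master tells the client the primary and secondary chunk servers and the version number.</li>
<li>The client pushes the data to the chunk servers, nearest first.</li>
<li>If a secondary receives the data, it forwards it to the primary, and the primary forwards it to the secondaries. If all goes well, the primary reports success to the client once the write completes; otherwise it reports the cause of the error.</li>
<li>If the client receives an error, it retries automatically until the write succeeds, which gives at-least-once RPC semantics.</li>
<li>On a retry, the primary does not write at the original offset, because the first attempt's data is already there; it picks a new position further on. The file may therefore contain duplicate records.</li>
</ol>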
</ol>
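<p>Steps 5 to 7 of the append flow explain why at-least-once retries leave duplicate records. The following is a minimal sketch under simplifying assumptions: <code>PrimaryChunkServer</code>, <code>fail_first_n</code>, and <code>client_append</code> are hypothetical names, secondaries are not modeled, and a failure is simulated by a counter rather than a real network error.</p>

```python
class PrimaryChunkServer:
    """Toy primary: stores records in a list; the list index plays the role of the offset."""

    def __init__(self, fail_first_n=1):
        self.records = []
        self.failures_left = fail_first_n  # simulated failures (assumption)

    def append(self, data):
        offset = len(self.records)
        self.records.append(data)  # the primary has already written locally
        if self.failures_left > 0:
            # Pretend forwarding to a secondary failed: report an error to the client.
            self.failures_left -= 1
            return None
        return offset


def client_append(primary, data, max_retries=5):
    """Retry until the append succeeds (at-least-once semantics)."""
    for _ in range(max_retries):
        offset = primary.append(data)
        if offset is not None:
            return offset
    raise RuntimeError("append failed after retries")


primary = PrimaryChunkServer(fail_first_n=1)
offset = client_append(primary, "record-A")
print(offset, primary.records)
```

<p>The successful attempt lands at a later offset than the failed one, and the data written by the failed attempt stays in place, so readers must tolerate duplicates.</p>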
]]></content:encoded>
    </item>
    <item>
      <title>MIT 6.5940 EfficientML: Lecture 4 Notes</title>
      <link>https://www.zhouxin.space/notes/notes-on-mit-efficientml-4th-lecture/</link>
      <pubDate>Mon, 11 Nov 2024 13:52:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-mit-efficientml-4th-lecture/</guid>
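      <description>&lt;blockquote&gt;
&lt;p&gt;Continuing from the previous lecture, this lecture covers two methods for choosing pruning ratios: sensitivity analysis and reinforcement learning. It also covers systems and hardware accelerators that support sparse networks, including EIE, NVIDIA Tensor Core, and TorchSparse.&lt;/p&gt;</description>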
      <content:encoded><![CDATA[<blockquote>
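<p>Continuing from the previous lecture, this lecture covers two methods for choosing pruning ratios: sensitivity analysis and reinforcement learning. It also covers systems and hardware accelerators that support sparse networks, including EIE, NVIDIA Tensor Core, and TorchSparse.</p>
</blockquote>
<p>Unless otherwise noted, figures are taken from the <a href="https://efficientml.ai">EfficientML</a> course slides.</p>
<h1 id="lecture-4-pruning-and-sparsity-剪枝和稀疏性">Lecture 4: Pruning and Sparsity</h1>
<h1 id="剪枝率">Pruning Ratio</h1>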
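<p>As the figure below shows, research finds that non-uniform per-layer pruning ratios work markedly better than uniform pruning. The question is how to choose the ratio for each layer.<br>
<img alt="Non-uniform pruning ratios clearly outperform uniform pruning" loading="lazy" src="https://pics.zhouxin.space/202411101855239.png?x-oss-process=image/quality,q_90/format,webp"></p>
<h2 id="灵敏度分析">Sensitivity Analysis</h2>
<p>Sensitivity analysis prunes each layer in turn at several pruning ratios and measures the effect on final accuracy, which shows how sensitive accuracy is to each layer. In the figure below, accuracy is least sensitive to layer L1 and most sensitive to layer L0.</p>
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202411101920552.png?x-oss-process=image/quality,q_90/format,webp"><br>
When implementing a pruning algorithm, the sensitivity curve also serves as a sanity check: accuracy should barely drop at low pruning ratios and drop sharply at high ones.</p>
<p>After the sensitivity analysis, pick an acceptable accuracy-drop threshold and use it to set the pruning ratio of each layer.</p>
<p>This rests on an implicit assumption that the layers are independent; interactions between layers are not considered.</p>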
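<p>The threshold step can be sketched as follows. The accuracy numbers are made up for illustration and <code>pick_ratios</code> is a hypothetical helper; it treats each layer independently, which is exactly the assumption noted above.</p>

```python
def pick_ratios(sensitivity, baseline_acc, max_drop):
    """For each layer, pick the largest pruning ratio whose accuracy drop stays within max_drop.

    sensitivity maps layer name -> {pruning ratio: accuracy with only that layer pruned}.
    """
    ratios = {}
    for layer, curve in sensitivity.items():
        acceptable = [r for r, acc in curve.items() if baseline_acc - acc <= max_drop]
        ratios[layer] = max(acceptable, default=0.0)  # 0.0 means "do not prune this layer"
    return ratios


# Hypothetical sensitivity curves: L0 is sensitive, L1 is not.
sensitivity = {
    "L0": {0.1: 74.8, 0.3: 73.0, 0.5: 60.0},
    "L1": {0.1: 75.0, 0.3: 74.9, 0.5: 74.6},
}
print(pick_ratios(sensitivity, baseline_acc=75.0, max_drop=0.5))
```

<p>With these numbers the sensitive layer gets a small ratio and the insensitive layer a large one.</p>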
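<h2 id="自动剪枝">Automatic Pruning</h2>
<p>The pruning schemes discussed so far rely on a human to choose the pruning strategy, which is neither elegant nor scalable.</p>
<p>This section introduces a reinforcement-learning approach to choosing pruning ratios. The idea is to train a model that takes information about each layer as input and outputs the corresponding pruning ratio.</p>
<p>I have little background in reinforcement learning, so the terminology below may be imprecise.</p>
<p>The model setup is:</p>
<ul>
<li>State (input)
<ul>
<li>Features describing each layer, including layer index, number of channels, kernel size, FLOPs&hellip;</li>
</ul>
</li>
<li>Action (presumably the model output)
<ul>
<li>A number between 0 and 1, the pruning ratio</li>
</ul>
</li>
<li>Agent
<ul>
<li>DDPG agent (I do not know the details)</li>
</ul>
</li>
<li>Reward (the objective function?)
<ul>
<li>-Error if the constraint is satisfied</li>
<li>-inf otherwise</li>
</ul>
</li>
</ul>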
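<p>As the figure below shows, compared with time-consuming manual tuning, the reinforcement-learning approach takes less time and achieves better results.<br>
<img alt="Manual vs. automatic pruning" loading="lazy" src="https://pics.zhouxin.space/202411112329452.png?x-oss-process=image/quality,q_90/format,webp"></p>
<h2 id="微调剪枝后的神经网络">Fine-tuning the Pruned Network</h2>
<p>As noted earlier, a pruned network needs fine-tuning to recover its accuracy or even improve on it.</p>
<ul>
<li>Learning rate: usually 1/10 to 1/100 of the original learning rate</li>
<li>Iterative pruning: as an earlier lecture noted, repeating the prune and fine-tune steps works markedly better than a single prune and fine-tune pass</li>
<li>Regularization: use L1 or L2 regularization during fine-tuning</li>
</ul>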
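<h1 id="为稀疏性提供支持的系统">Systems That Support Sparsity</h1>
<h2 id="eie-efficient-inference-engine">EIE: Efficient Inference Engine</h2>
<p>EIE uses three optimizations:<br>
<img alt="EIE optimization techniques" loading="lazy" src="https://pics.zhouxin.space/20241112110905.png?x-oss-process=image/quality,q_90/format,webp"></p>
<ul>
<li>Weight sparsity: saves 90% of computation and 80% of memory; the memory saving is slightly smaller because of the overhead of storing the sparsity metadata</li>
<li>Activation sparsity: saves 66% of computation</li>
<li>Quantization</li>
</ul>
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20241112112715.png?x-oss-process=image/quality,q_90/format,webp"><br>
For the sparse matrix-vector multiplication (<strong>Sp</strong>arse <strong>M</strong>atrix <strong>V</strong>ector Multiplication, SpMV) above, the upper half shows the logical form of SpMV. The weight matrix is colored in four colors, one per processing element (PE), and the PEs run in parallel. The lower half shows the weights stored in the first PE. PE0 stores its sparse matrix in CSC format <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup>. Memory holds Virtual Weight, which contains only the 5 non-zero elements; Relative Index, the distance from each non-zero element to the previous one (measured in column-major order); and Column Pointer. <code>Pointer[i]</code> is the position in Virtual Weight at which the elements of column <code>i</code> of the original sparse matrix begin.</p>
<p>Note that the figure is not standard CSC. The standard format uses a Row Index that records the row number directly, rather than a Relative Index that records the distance between adjacent elements.</p>
<p>During computation, the input vector $\vec{a}$ is scanned element by element and zero entries are skipped outright. Each non-zero entry is broadcast to the PEs, and each PE multiplies it by its non-zero weights in the matching column to compute its share of the result.</p>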
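<p>The storage format and the zero-skipping loop can be sketched in software. This is a simplified model, not the EIE hardware: it uses a single PE, it assumes the relative index counts the zeros since the previous non-zero within the same column, and it ignores the fixed bit width of the index. The function names are hypothetical.</p>

```python
def encode_csc_relative(dense):
    """Encode a dense matrix (list of rows) column by column.

    Returns (values, rel_index, col_ptr), where rel_index[k] is the number of
    zeros between values[k] and the previous non-zero in the same column.
    """
    rows, cols = len(dense), len(dense[0])
    values, rel_index, col_ptr = [], [], [0]
    for j in range(cols):
        zeros = 0
        for i in range(rows):
            w = dense[i][j]
            if w == 0:
                zeros += 1
            else:
                values.append(w)
                rel_index.append(zeros)
                zeros = 0
        col_ptr.append(len(values))  # start of the next column
    return values, rel_index, col_ptr


def spmv(values, rel_index, col_ptr, a, rows):
    """Compute W @ a, skipping zero activations and zero weights."""
    out = [0] * rows
    for j, a_j in enumerate(a):
        if a_j == 0:
            continue  # zero activation: skip the whole column
        row = -1
        for k in range(col_ptr[j], col_ptr[j + 1]):
            row += rel_index[k] + 1  # recover the absolute row number
            out[row] += values[k] * a_j
    return out


W = [[0, 2, 0],
     [1, 0, 0],
     [0, 0, 3],
     [4, 0, 5]]
values, rel_index, col_ptr = encode_csc_relative(W)
print(values, rel_index, col_ptr)
print(spmv(values, rel_index, col_ptr, [1, 0, 2], rows=4))
```

<p>Only the columns matching non-zero activations are touched, and within each column only the stored non-zero weights are multiplied.</p>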
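<p>The microarchitecture of a PE is shown below:<br>
<img alt="PE microarchitecture" loading="lazy" src="https://pics.zhouxin.space/202411121605329.png?x-oss-process=image/quality,q_90/format,webp"><br>
First, the Activation Queue holds the non-zero activations. The Activation Index of each non-zero entry is used to look up the Weight Column Pointer, which gives the start and end indices of the weights that pair with that activation. The weights are then fetched and decoded (decoding is covered later). Finally, the products are accumulated and passed through ReLU before being output.</p>
<p>Because both the activations and the weight matrix are sparse, they can be kept in SRAM for fast access.</p>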
<h2 id="nvidia-tensor-core">NVIDIA Tensor Core</h2>
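<p>Lecture 3 introduced the M:N sparsity pattern, in which every N consecutive elements contain M zeros. For example, a 2:4 sparse matrix can be represented as:<br>
<img alt="Representation of a 2:4 sparse matrix" loading="lazy" src="https://pics.zhouxin.space/202411121650058.png?x-oss-process=image/quality,q_90/format,webp"><br>
For a 2:4 sparse matrix of shape R x C, memory only needs to hold the non-zero elements, an R x C/2 matrix. Extra metadata records where each non-zero element sat in the original matrix. Each one came from position 0 to 3 of its group of four, so it needs 2 bits to record its original position within the group.<br>
<img alt="Sparse matmul acceleration in Tensor Core" loading="lazy" src="https://pics.zhouxin.space/202411121658663.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>The right side of the figure above shows how Tensor Core performs sparse matrix multiplication. Take the first row of A, which must be dotted with the first column of B. The indices saved when A was sparsified select only the elements of B's first column that line up with A's non-zeros, and those are multiplied and accumulated.</p>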
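<p>The compressed layout and the index-driven selection can be sketched on a single row. This is a plain software model with hypothetical function names, not the Tensor Core data path; groups with fewer than two non-zeros are padded with an explicit zero so that every group stores exactly two entries.</p>

```python
def compress_2_4(row):
    """Compress a row obeying 2:4 sparsity into (values, indices).

    Every group of 4 stores 2 values plus the position (0..3) of each value
    inside its group.
    """
    assert len(row) % 4 == 0
    values, indices = [], []
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        nonzero = [i for i, v in enumerate(group) if v != 0]
        assert len(nonzero) <= 2, "row violates the 2:4 pattern"
        # Pad with zero-valued positions so each group keeps exactly two slots.
        keep = nonzero + [i for i in range(4) if i not in nonzero]
        keep = sorted(keep[:2])
        values += [group[i] for i in keep]
        indices += keep
    return values, indices


def sparse_dot(values, indices, b):
    """Dot product that gathers from b only the positions named by indices."""
    total = 0
    for k, (v, idx) in enumerate(zip(values, indices)):
        group = k // 2  # two stored values per group of four
        total += v * b[4 * group + idx]
    return total


a_row = [0, 3, 0, 5, 7, 0, 0, 0]
b_col = [1, 2, 3, 4, 5, 6, 7, 8]
values, indices = compress_2_4(a_row)
print(values, indices, sparse_dot(values, indices, b_col))
```

<p>The compressed row is half the length of the original, and the dot product performs half as many multiplications as the dense version.</p>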
<h2 id="torchsparse">TorchSparse</h2>
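<p>This section introduces sparse convolution. As the left of the figure below shows, applying conventional convolution to a sparse input reduces the sparsity of the activations, because non-zero values spread into neighboring zeros. Sparse convolution, on the right, preserves the sparsity pattern of the input: positions that are zero in the input stay zero after convolution.<br>
<img alt="Sparse convolution diagram" loading="lazy" src="https://pics.zhouxin.space/202411121833124.png?x-oss-process=image/quality,q_90/format,webp"><br>
In the implementation, only non-zero elements take part in the multiplication with the kernel. The approach described here builds (input element, output element, kernel weight) triples and sorts them by weight, which in effect finds every element that must be multiplied by a given weight.<br>
<img alt="Sparse convolution implementation" loading="lazy" src="https://pics.zhouxin.space/202411121839341.png?x-oss-process=image/quality,q_90/format,webp"><br>
After grouping, an adaptive grouping algorithm turns these scalar multiply-adds into matrix multiplications (my intuition is that this resembles how im2col turns convolution into matrix multiplication). This step carries considerable overhead.<br>
<img alt="Overall flow of the sparse convolution implementation" loading="lazy" src="https://pics.zhouxin.space/202411121844924.png?x-oss-process=image/quality,q_90/format,webp"></p>
<ul>
<li>Trade-off between computation regularity and computation overhead<br>
To merge several vector multiplications into one matrix multiplication, shorter vectors have to be zero-padded. This makes the computation more regular but adds redundant work. As shown below, the compromise is to group each computation batch dynamically<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202411121955937.webp?x-oss-process=image/quality,q_90/format,webp"></li>
</ul>
<h2 id="torchsparse-1">TorchSparse++</h2>
<p>As the left of the figure below shows, adjacent rows are computed in one group. Taking the first row as an example, computing B1 forces $W_{0,-1}$ and $W_{1,0}$ into the computation of B0 as well, which is the redundant work mentioned above. TorchSparse++ proposes a row-reordering algorithm whose regrouping reduces this redundancy. The weights can also be split to remove redundancy at a finer granularity.<br>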
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202411122002814.webp"></p>
<h2 id="pointacc-稀疏卷积的硬件加速器">PointAcc: 稀疏卷积的硬件加速器</h2>
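<p>This part presents the algorithm that builds the (input element, output element, kernel weight) triples used throughout, which can be accelerated in hardware. I did not fully follow how it works and it does not seem central, so it is skipped here.</p>
<h1 id="参考">References</h1>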
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p><a href="https://docs.nvidia.com/nvpl/_static/sparse/storage_format/sparse_matrix.html#compressed-sparse-column-csc">Sparse Matrix Formats — NVPL SPARSE documentation</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>MIT 6.824 Distributed Systems Spring 2023: Lecture 2 Notes</title>
      <link>https://www.zhouxin.space/notes/notes-on-2nd-lecture-of-mit-6-824-distributed-system-spring-2023/</link>
      <pubDate>Sun, 10 Nov 2024 09:10:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-2nd-lecture-of-mit-6-824-distributed-system-spring-2023/</guid>
      <description>&lt;blockquote&gt;
&lt;p&gt;Notes on Lecture 2 of MIT 6.824 Distributed Systems: a brief introduction to Go and concurrent programming, plus RPC and its semantics under failures.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h1 id=&#34;lecture-2-rpc-and-threads&#34;&gt;Lecture 2: RPC and Threads&lt;/h1&gt;
&lt;h2 id=&#34;为什么选择-go-语言&#34;&gt;Why Go&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Good support for threads and RPC&lt;/li&gt;
&lt;li&gt;Garbage collection&lt;/li&gt;
&lt;li&gt;Type safety&lt;/li&gt;
&lt;li&gt;Simple&lt;/li&gt;
&lt;li&gt;Compiled&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;并行编程入门&#34;&gt;Introduction to Concurrent Programming&lt;/h2&gt;
&lt;p&gt;In Go, a thread is called a goroutine. Each thread has its own PC, stack, and registers; threads share a single address space.&lt;/p&gt;</description>
      <content:encoded><![CDATA[<blockquote>
<p>Notes on Lecture 2 of MIT 6.824 Distributed Systems: a brief introduction to Go and concurrent programming, plus RPC and its semantics under failures.</p>
</blockquote>
<h1 id="lecture-2-rpc-and-threads">Lecture 2: RPC and Threads</h1>
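<h2 id="为什么选择-go-语言">Why Go</h2>
<ul>
<li>Good support for threads and RPC</li>
<li>Garbage collection</li>
<li>Type safety</li>
<li>Simple</li>
<li>Compiled</li>
</ul>
<h2 id="并行编程入门">Introduction to Concurrent Programming</h2>
<p>In Go, a thread is called a goroutine. Each thread has its own PC, stack, and registers; threads share a single address space.</p>
<p>A thread can also be viewed as a structure for which the Go runtime provides start, exit, stop, and resume operations.</p>
<h2 id="为什么要使用线程">Why Use Threads</h2>
<p>In this course, threads are used for the sake of concurrency:</p>
<ul>
<li>I/O concurrency: when a thread issues, say, a network I/O request, it can block while waiting for the reply and other threads can be scheduled;</li>
<li>Multi-core parallelism: on a multi-core processor, different threads can run on different cores, which raises throughput significantly;</li>
<li>Convenience: the labs may need to run a task periodically, and goroutines make that easy.</li>
</ul>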
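<h2 id="线程中的挑战">Challenges with Threads</h2>
<ul>
<li>Race conditions: the result of the program depends on the exact order in which threads execute
<ul>
<li>Mitigation 1: do not share variables between threads</li>
<li>Mitigation 2: use locks</li>
<li>Go has a race detector that finds many potential races</li>
</ul>
</li>
<li>Coordination (this seems to mean synchronization)
<ul>
<li>Channels</li>
<li>Condition variables</li>
</ul>
</li>
<li>Deadlock</li>
</ul>
<h2 id="go-的解决方案">Go's Solutions</h2>
<p>Go offers matching solutions for the challenges above:</p>
<ul>
<li>Channels (no shared memory)</li>
<li>Locks and condition variables</li>
</ul>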
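<h2 id="rpc-远程过程调用">RPC: Remote Procedure Call</h2>
<ul>
<li>Goal
<ul>
<li>Make remote calls transparent so that they behave like local calls. A procedure call means invoking executable code such as a function, method, or subroutine.</li>
</ul>
</li>
<li>Implementation
<ul>
<li>On the client side, a stub stands in for the remote procedure. The stub holds the information needed for the call, such as the function signature, and sends the call over the network to the server.</li>
<li>The server has a similar stub that decodes the call received from the network, invokes the server-side implementation, and returns the result.</li>
<li>The client stub receives the result from the server, decodes it, and returns it to the local caller.</li>
<li>All of this is handled by the compiler and runtime and is transparent to the programmer.</li>
</ul>
</li>
</ul>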
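<h2 id="故障情况下的-rpc-语义行为">RPC Semantics and Behavior under Failures</h2>
<p>Under failures, RPC can behave in several different ways:</p>
<ul>
<li>at least once
<ul>
<li>If the server fails, the client keeps retrying until the call has executed at least once</li>
</ul>
</li>
<li>at most once
<ul>
<li>The call executes at most once</li>
<li>The server must make sure duplicate requests are not executed more than once</li>
<li>This is the mode Go uses</li>
<li>It relies on TCP so that each RPC is sent at most once</li>
</ul>
</li>
<li>exactly once
<ul>
<li>The call executes exactly once</li>
<li>This is hard to implement</li>
</ul>
</li>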
</ul>
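<p>The server-side requirement for at-most-once, not executing a duplicate request twice, is commonly met by remembering replies per request ID. The sketch below illustrates that idea only; it is not how Go's RPC library works internally, and <code>AtMostOnceServer</code> and its fields are hypothetical names.</p>

```python
class AtMostOnceServer:
    """Toy server that filters duplicate requests by request ID."""

    def __init__(self, handler):
        self.handler = handler
        self.replies = {}    # request_id -> cached reply
        self.executions = 0  # how many times the handler actually ran

    def handle(self, request_id, arg):
        if request_id in self.replies:
            # Duplicate (for example, a client retry): return the cached reply.
            return self.replies[request_id]
        self.executions += 1
        reply = self.handler(arg)
        self.replies[request_id] = reply
        return reply


server = AtMostOnceServer(lambda x: x + 1)
print(server.handle("req-1", 41))  # first delivery executes the handler
print(server.handle("req-1", 41))  # retry gets the cached reply
print(server.executions)
```

<p>An at-least-once client can retry freely against such a server without the operation taking effect twice, as long as the reply cache survives. Making that hold across server crashes is part of what makes exactly-once hard.</p>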
]]></content:encoded>
    </item>
    <item>
      <title>MIT 6.5940 EfficientML: Lecture 3 Notes</title>
      <link>https://www.zhouxin.space/notes/notes-on-mit-efficientml-3rd-lecture/</link>
      <pubDate>Sat, 09 Nov 2024 14:01:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-mit-efficientml-3rd-lecture/</guid>
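      <description>&lt;blockquote&gt;
&lt;p&gt;Notes on Lecture 3 of MIT 6.5940 EfficientML: the definition, effect, and granularity of pruning, plus a detailed look at several pruning criteria.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h1 id=&#34;lecture-3-pruning-and-sparsity-剪枝和稀疏性&#34;&gt;Lecture 3: Pruning and Sparsity&lt;/h1&gt;
&lt;h2 id=&#34;剪枝的动机&#34;&gt;Motivation for Pruning&lt;/h2&gt;
&lt;p&gt;The previous lecture noted that memory operations are very expensive. One way to speed up a model is therefore to shrink everything it keeps in memory: reduce the model size, and reduce the size and number of activations.&lt;/p&gt;</description>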
      <content:encoded><![CDATA[<blockquote>
<p>Notes on Lecture 3 of MIT 6.5940 EfficientML: the definition, effect, and granularity of pruning, plus a detailed look at several pruning criteria.</p>
</blockquote>
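<h1 id="lecture-3-pruning-and-sparsity-剪枝和稀疏性">Lecture 3: Pruning and Sparsity</h1>
<h2 id="剪枝的动机">Motivation for Pruning</h2>
<p>The previous lecture noted that memory operations are very expensive. One way to speed up a model is therefore to shrink everything it keeps in memory: reduce the model size, and reduce the size and number of activations.</p>
<h2 id="剪枝定义">Definition of Pruning</h2>
<p>The mathematical definition of pruning is shown below:<br>
<img alt="Mathematical definition of pruning" loading="lazy" src="https://pics.zhouxin.space/202411091410114.png?x-oss-process=image/quality,q_90/format,webp"><br>
Concretely, pruning removes some of the parameters of a neural network so that the network becomes relatively sparse.</p>
<h2 id="剪枝效果">Effect of Pruning</h2>
<p>As the figure below shows, pruning causes no significant accuracy loss. If the model is fine-tuned after pruning, or pruning and fine-tuning are iterated, it can match or exceed the original accuracy with only a fraction of the parameters.<br>
<img alt="Effect of pruning" loading="lazy" src="https://pics.zhouxin.space/202411091418290.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>This result rather stunned me: pruning can remove 90% of a model's parameters without losing accuracy, and sometimes even improves it!</p>
<h2 id="剪枝细粒程度">Pruning Granularity</h2>
<p>Pruning methods can be classified by the granularity at which they operate.</p>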
<ul>
<li>
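<p>Fine-grained pruning<br>
Fine-grained pruning can remove any individual weight. Its advantage is a higher pruning ratio; its drawback is that the weight tensor becomes a sparse matrix, which is harder to accelerate in hardware.</p>
</li>
<li>
<p>Coarse-grained pruning<br>
Coarse-grained pruning can only remove an entire row of the weight matrix at once. Its drawback is less flexibility; its advantage is that the pruned weights still form a dense (just smaller) matrix, which is easy to accelerate.<br>
<img alt="Pruning a fully connected layer" loading="lazy" src="https://pics.zhouxin.space/202411091436763.png?x-oss-process=image/quality,q_90/format,webp"></p>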
</li>
</ul>
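<p>The example above used a fully connected layer; convolution comes next. A convolution layer's weight tensor has shape $[c_o, c_i, k_h, k_w]$. With four dimensions, there are many more choices of pruning granularity:</p>
<p><img alt="Notation for visualizing convolution kernels" loading="lazy" src="https://pics.zhouxin.space/202411091440974.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p><img alt="Pruning granularities for convolution kernels" loading="lazy" src="https://pics.zhouxin.space/202411091439970.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>The first option is the finest-grained and can often prune by an order of magnitude, but it generally needs specialized hardware to reach good inference speed; designing efficient algorithms for it on GPUs is difficult.</p>
<p>The second option is pattern-based: kernels are pruned with a fixed pattern. A typical pattern is N:M, in which N out of every M consecutive parameters are pruned.</p>
<p>The last option prunes whole channels. Its advantage is that the result is still a general dense matrix multiplication; its drawback is a relatively low pruning ratio.</p>
<h2 id="剪枝标准">Pruning Criteria</h2>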
<ul>
<li>
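<p>Magnitude-based pruning<br>
An intuitive and effective criterion ranks parameters by absolute value and removes those with the smallest magnitude. For row-wise pruning, the importance of a row is its L1 norm, the sum of the absolute values of its entries.</p>
</li>
<li>
<p>Scaling-based pruning<br>
A convolution layer has multiple channels, and each channel can be given a learnable scaling factor. Training determines each channel's scaling factor, which is taken as the channel's importance and used for pruning.</p>
</li>
<li>
<p>Second-order-based pruning<br>
The error of the pruned model can be written with a Taylor expansion:</p>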
</li>
</ul>


<div>$$

\delta L = L(\mathbf{x}; \mathbf{W}) - L(\mathbf{x}; \mathbf{W}_\rho = \mathbf{W} - \delta\mathbf{W}) = \sum_i g_i \delta w_i &#43; \frac{1}{2} \sum_i h_{ii} \delta w_i^2 &#43; \frac{1}{2} \sum_{i\neq j} h_{ij} \delta w_i \delta w_j &#43; O(\|\delta\mathbf{W}\|^3)

$$</div>

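<p>where $g_i = \frac{\partial L}{\partial w_i}$ and $h_{ij} = \frac{\partial ^2 L}{\partial w_i \partial w_j}$.</p>
<p>Second-order-based pruning assumes that:<br>
the loss function is approximately quadratic, so the higher-order terms can be ignored;<br>
the network has converged during training, so the first derivative of L with respect to w is 0;<br>
the errors caused by pruning different parameters are independent, so $\frac{1}{2} \sum_{i\neq j} h_{ij} \delta w_i \delta w_j$ is 0.</p>
<p>The error of the pruned model is therefore:</p>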


<div>$$

\delta L = L(\mathbf{x}; \mathbf{W}) - L(\mathbf{x}; \mathbf{W}_\rho) \approx \frac{1}{2} \sum_i h_{ii} \delta w_i^2

$$</div>

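<p>Pruning a weight sets $\delta w_i = w_i$, so to minimize the error after pruning we remove the weights whose error term is smallest and keep the rest. The importance of a weight is therefore:</p>


<div>$$

\text{importance}_{w_i} = \frac{1}{2}h_{ii}w_{i}^2

$$</div>

<p>where $h_{ii}$ is the $i$-th diagonal entry of the Hessian matrix.</p>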
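<p>The magnitude and second-order criteria can be compared on a toy weight vector. This is an illustrative sketch with made-up numbers and hypothetical function names. The second-order score uses $\delta w_i = w_i$ for a pruned weight, so the error term is $\frac{1}{2}h_{ii}w_i^2$, and the diagonal Hessian entries are simply given as input rather than computed.</p>

```python
def magnitude_importance(weights):
    """Importance of each weight is its absolute value."""
    return [abs(w) for w in weights]


def row_l1_importance(matrix):
    """Importance of each row is its L1 norm (used for row-wise pruning)."""
    return [sum(abs(w) for w in row) for row in matrix]


def second_order_importance(weights, hessian_diag):
    """Error caused by removing each weight: 0.5 * h_ii * w_i^2."""
    return [0.5 * h * w * w for w, h in zip(weights, hessian_diag)]


def keep_mask(importance, prune_ratio):
    """Prune the least important entries; True means the entry is kept."""
    n_prune = int(len(importance) * prune_ratio)
    order = sorted(range(len(importance)), key=lambda i: importance[i])
    pruned = set(order[:n_prune])
    return [i not in pruned for i in range(len(importance))]


weights = [0.1, -2.0, 0.5, -0.05]
hessian_diag = [100.0, 0.01, 1.0, 1.0]  # hypothetical curvature values
print(keep_mask(magnitude_importance(weights), 0.5))
print(keep_mask(second_order_importance(weights, hessian_diag), 0.5))
```

<p>The two criteria disagree here: the large weight sits in a flat direction of the loss, so the second-order criterion prunes it, while the small weight in a sharply curved direction is kept.</p>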
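<ul>
<li>
<p>Pruning activations<br>
Pruning activations is in essence coarse-grained weight pruning. As shown below, removing a node from an activation layer means removing a row of the weight matrix in an FC network, or removing a channel in a convolution.<br>
<img alt="Relationship between weight pruning and activation pruning" loading="lazy" src="https://pics.zhouxin.space/202411091608696.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
<li>
<p>Pruning based on the fraction of zeros<br>
The output of a ReLU layer is zero with some probability. Counting how often each position outputs zero across a batch identifies the positions with a high zero rate, which can then be pruned.</p>
</li>
</ul>
<p>Note that this prunes activations rather than parameters directly as before. My understanding is that pruning an activation amounts to pruning the parameters of the two layers adjacent to it.</p>
<ul>
<li>Regression-based pruning<br>
Evaluating the error of the whole model before and after pruning can be expensive. Regression-based pruning evaluates and prunes the network one layer at a time.</li>
</ul>
<p>The figure below sketches the algorithm for pruning the input channels of a fully connected layer. The matmul result is viewed as a sum of per-channel outer products, and each channel gets a scaling coefficient $\beta_{c}$; the closer a coefficient is to 0, the less important the channel.</p>
<p><img alt="Regression-based pruning" loading="lazy" src="https://pics.zhouxin.space/202411101545235.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>The optimization alternates: first fix the weights and optimize $\beta$, pruning the channels with small coefficients; then fix $\beta$ and optimize $W$. The two steps can be repeated.</p>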
]]></content:encoded>
    </item>
    <item>
      <title>MIT 6.5940 EfficientML: Lecture 2 Notes</title>
      <link>https://www.zhouxin.space/notes/notes-on-mit-efficientml-2nd-lecture/</link>
      <pubDate>Tue, 05 Nov 2024 22:33:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-mit-efficientml-2nd-lecture/</guid>
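      <description>&lt;p&gt;Unless otherwise noted, figures in this post are taken from the &lt;a href=&#34;efficient.ml&#34;&gt;EfficientML&lt;/a&gt; course slides.&lt;/p&gt;
&lt;h1 id=&#34;lecture-2-basics-of-neural-networks-神经网络基础&#34;&gt;Lecture 2: Basics of Neural Networks&lt;/h1&gt;
&lt;h2 id=&#34;神经网络&#34;&gt;Neural Networks&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Basic terminology&lt;br&gt;
As shown below, the terms synapses, weights, and parameters all refer to the parameters of the network, while neurons, features, and activations refer to the values computed at each layer.&lt;br&gt;
&lt;img alt=&#34;A three-layer neural network&#34; loading=&#34;lazy&#34; src=&#34;https://pics.zhouxin.space/202411060903228.png?x-oss-process=image/quality,q_90/format,webp&#34;&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;The width of a model is the dimension of its hidden layers. For the same number of parameters, a wide, shallow model is more computationally efficient than a narrow, deep one, because it launches fewer kernels and parallelizes better. The latter usually achieves higher accuracy, however, so a trade-off is needed.&lt;/p&gt;</description>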
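      <content:encoded><![CDATA[<p>Unless otherwise noted, figures in this post are taken from the <a href="efficient.ml">EfficientML</a> course slides.</p>
<h1 id="lecture-2-basics-of-neural-networks-神经网络基础">Lecture 2: Basics of Neural Networks</h1>
<h2 id="神经网络">Neural Networks</h2>
<ul>
<li>Basic terminology<br>
As shown below, the terms synapses, weights, and parameters all refer to the parameters of the network, while neurons, features, and activations refer to the values computed at each layer.<br>
<img alt="A three-layer neural network" loading="lazy" src="https://pics.zhouxin.space/202411060903228.png?x-oss-process=image/quality,q_90/format,webp"></li>
</ul>
<p>The width of a model is the dimension of its hidden layers. For the same number of parameters, a wide, shallow model is more computationally efficient than a narrow, deep one, because it launches fewer kernels and parallelizes better. The latter usually achieves higher accuracy, however, so a trade-off is needed.</p>
<ul>
<li>
<p>Fully connected layer<br>
A fully connected layer computes a weighted sum of its inputs and adds a bias term, as shown below:<br>
<img alt="Fully connected layer" loading="lazy" src="https://pics.zhouxin.space/202411060913049.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
<li>
<p>2D convolution<br>
<img alt="2D convolution" loading="lazy" src="https://pics.zhouxin.space/202411060918066.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
</ul>
<p>The rest of the basic terminology is skipped, since most of it is introductory deep learning.</p>
<h2 id="神经网络效率的评价指标">Efficiency Metrics for Neural Networks</h2>
<ul>
<li>Latency<br>
The time a neural network needs to complete a given task. Latency is determined mainly by compute time and memory time:</li>
</ul>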


<div>$$

\begin{align*}
\text{latency} &amp;\approx \max(T_{\text{computation}}, T_{\text{memory}})\\
T_{\text{computation}} &amp;\approx \frac{\text{number of ops in model}}{\text{number of ops that processor can process per second}}\\
T_{\text{memory}} &amp;\approx T_{\text{data movement of activations}} &#43; T_{\text{data movement of weights}}\\
T_{\text{data movement of activations}} &amp;\approx \frac{\text{input and output activation size}}{\text{memory bandwidth of processor}}\\
T_{\text{data movement of weights}} &amp;\approx \frac{\text{model size}}{\text{memory bandwidth of processor}}
\end{align*}

$$</div>

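<p>Memory operations cost far more than computation. As shown below, a 32-bit memory access consumes hundreds of times the energy of common arithmetic operations.<br>
<img alt="Energy cost of different 32-bit operations" loading="lazy" src="https://pics.zhouxin.space/202411081715363.png?x-oss-process=image/quality,q_90/format,webp"></p>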
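<p>The latency model above can be turned into a small estimator. All numbers in the example call are hypothetical, and <code>estimate_latency</code> is an illustrative helper; it is a first-order estimate that ignores caching, kernel launch overhead, and overlap details.</p>

```python
def estimate_latency(num_ops, ops_per_sec, activation_bytes, weight_bytes, mem_bandwidth):
    """Return (latency, t_computation, t_memory) in seconds.

    latency ~ max(T_computation, T_memory), where T_memory covers moving
    activations and weights at the given memory bandwidth (bytes per second).
    """
    t_computation = num_ops / ops_per_sec
    t_memory = (activation_bytes + weight_bytes) / mem_bandwidth
    return max(t_computation, t_memory), t_computation, t_memory


# Hypothetical model and processor: 2 GOPs, 1 TOPS, 1 MB activations,
# 4 MB weights, 1 GB/s memory bandwidth.
latency, t_comp, t_mem = estimate_latency(2e9, 1e12, 1e6, 4e6, 1e9)
print(latency, t_comp, t_mem)
print("memory-bound" if t_mem > t_comp else "compute-bound")
```

<p>In this made-up configuration the memory term dominates, so reducing the model or activation size would help more than adding compute.</p>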
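<ul>
<li>Throughput<br>
The amount of data a neural network can process per unit time.</li>
</ul>
<p>Because of batching, a high-latency model may process many inputs at once, so high latency does not imply low throughput. Mobile devices care more about latency, while high-performance devices care more about throughput.</p>
<ul>
<li>Number of parameters<br>
Formulas for the parameter count of different layers are shown below:<br>
<img alt="Parameter-count formulas for different layers" loading="lazy" src="https://pics.zhouxin.space/202411081832724.png?x-oss-process=image/quality,q_90/format,webp"></li>
<li>Model size<br>
Model size depends on the number of parameters and the bit width of each parameter:</li>
</ul>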


<div>$$

\text{model size} = \text{number of parameters} \times \text{bit width}

$$</div>
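<p>As a quick worked example of the formula, assuming a hypothetical model with 61 million parameters (<code>model_size_bytes</code> is an illustrative helper):</p>

```python
def model_size_bytes(num_parameters, bit_width):
    """Model size = number of parameters x bit width, converted from bits to bytes."""
    return num_parameters * bit_width // 8


params = 61_000_000  # hypothetical parameter count
for bits in (32, 8):
    print(bits, "bit:", model_size_bytes(params, bits) / 1e6, "MB")
```

<p>Going from 32-bit to 8-bit parameters cuts the size by a factor of four without changing the parameter count.</p>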

<ul>
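<li>
<p>Total / peak number of activations<br>
In inference, the bottleneck is usually activation size rather than parameter count. As shown below, among models with similar accuracy the parameter count can be cut substantially while the activation size stays almost the same. The peak activation size determines the maximum memory the model uses during inference.<br>
<img alt="Parameter size vs. activation size in inference" loading="lazy" src="https://pics.zhouxin.space/202411081838915.png?x-oss-process=image/quality,q_90/format,webp"><br>
In training, activations are even more of a bottleneck: they are several times the size of the parameters.<br>
<img alt="Parameter size vs. activation size in training" loading="lazy" src="https://pics.zhouxin.space/202411081841611.png?x-oss-process=image/quality,q_90/format,webp"><br>
When training a CNN, the early layers have high image resolution, so activations are the bottleneck; the later layers have many feature channels, so weights are the bottleneck.<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202411081844392.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
<li>
<p>MAC<br>
A MAC is a multiply-accumulate operation, one $a\leftarrow a+b\cdot c$ step; on a GPU it usually maps to a single instruction.</p>
</li>
</ul>
<p>In the matrix-vector multiplication (MV) shown below, the MAC count is $m\cdot n$: each output element takes $n$ MACs. In general matrix multiplication (GEMM), the MAC count is $m\cdot n\cdot k$: each output element takes $k$ MACs.<br>
<img alt="Matrix-vector multiplication and general matrix multiplication" loading="lazy" src="https://pics.zhouxin.space/202411081854699.png?x-oss-process=image/quality,q_90/format,webp"><br>
The MAC counts of different layers are shown below:<br>
<img alt="MAC counts of different layers" loading="lazy" src="https://pics.zhouxin.space/202411081859325.png?x-oss-process=image/quality,q_90/format,webp"></p>
<ul>
<li>
<p>FLOP: floating-point operation<br>
One MAC equals two FLOPs</p>
</li>
<li>
<p>FLOPS: floating-point operations per second</p>
</li>
<li>
<p>OP: operation<br>
Models are not always represented and computed in floating point; non-floating-point operations are called OPs.</p>
</li>
<li>
<p>OPS: operations per second</p>
</li>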
</ul>
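<p>The MAC and FLOP counts above are easy to check with a few lines of code. The helper names are hypothetical, and the matrix shapes in the example are arbitrary.</p>

```python
def macs_matvec(m, n):
    """MACs for an (m x n) matrix times an n-vector: n MACs per output element."""
    return m * n


def macs_gemm(m, n, k):
    """MACs for an (m x k) by (k x n) matrix product: k MACs per output element."""
    return m * n * k


def flops_from_macs(macs):
    """One MAC is one multiply plus one add, counted as two FLOPs."""
    return 2 * macs


print(macs_matvec(4, 3))                  # MACs for a 4 x 3 matrix-vector product
print(macs_gemm(4, 3, 5))                 # MACs for a (4 x 5) @ (5 x 3) product
print(flops_from_macs(macs_gemm(4, 3, 5)))
```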
]]></content:encoded>
    </item>
    <item>
      <title>Operator-Precedence Parsing with Precedence Climbing: Algorithm and Implementation</title>
      <link>https://www.zhouxin.space/notes/operator-precedence-parsing-algorithm-precedence-climbing-algorithm-and-implementation/</link>
      <pubDate>Tue, 29 Oct 2024 19:01:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/operator-precedence-parsing-algorithm-precedence-climbing-algorithm-and-implementation/</guid>
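      <description>&lt;p&gt;Chapter 2 of the official LLVM introductory tutorial, &lt;a href=&#34;https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/index.html&#34;&gt;My First Language Frontend with LLVM Tutorial&lt;/a&gt;, parses operator precedence while building the AST, using the &lt;a href=&#34;https://en.wikipedia.org/wiki/Operator-precedence_parser#Precedence_climbing_method&#34;&gt;precedence climbing method&lt;/a&gt;. The tutorial opens by saying no compiler background is required, but I still found the code hard to follow directly. This post records my own understanding of the method; it may contain mistakes, and corrections are welcome.&lt;/p&gt;
&lt;h1 id=&#34;算法原理&#34;&gt;How the Algorithm Works&lt;/h1&gt;
&lt;h2 id=&#34;约定和前置知识&#34;&gt;Conventions and Background&lt;/h2&gt;
&lt;p&gt;In precedence climbing, an infix expression is broken into primary expressions and operators. For example, in &lt;code&gt;a+b*c-d&lt;/code&gt; the primary expressions are &lt;code&gt;[&#39;a&#39;, &#39;b&#39;, &#39;c&#39;, &#39;d&#39;]&lt;/code&gt; and the operators are &lt;code&gt;[&#39;+&#39;, &#39;*&#39;, &#39;-&#39;]&lt;/code&gt;. Every operator has a precedence and an associativity. Precedence is a positive integer that expresses relative order; among the four arithmetic operators, multiplication and division bind tighter than addition and subtraction, and all are left-associative. In this example, addition and subtraction have precedence 10 and multiplication and division have precedence 20.&lt;/p&gt;</description>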
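      <content:encoded><![CDATA[<p>Chapter 2 of the official LLVM introductory tutorial, <a href="https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/index.html">My First Language Frontend with LLVM Tutorial</a>, parses operator precedence while building the AST, using the <a href="https://en.wikipedia.org/wiki/Operator-precedence_parser#Precedence_climbing_method">precedence climbing method</a>. The tutorial opens by saying no compiler background is required, but I still found the code hard to follow directly. This post records my own understanding of the method; it may contain mistakes, and corrections are welcome.</p>
<h1 id="算法原理">How the Algorithm Works</h1>
<h2 id="约定和前置知识">Conventions and Background</h2>
<p>In precedence climbing, an infix expression is broken into primary expressions and operators. For example, in <code>a+b*c-d</code> the primary expressions are <code>['a', 'b', 'c', 'd']</code> and the operators are <code>['+', '*', '-']</code>. Every operator has a precedence and an associativity. Precedence is a positive integer that expresses relative order; among the four arithmetic operators, multiplication and division bind tighter than addition and subtraction, and all are left-associative. In this example, addition and subtraction have precedence 10 and multiplication and division have precedence 20.</p>
<p>An infix expression can be parsed into an <a href="https://www.geeksforgeeks.org/expression-tree/">expression tree</a>, which reflects the order of evaluation. Operator-precedence parsing is in essence the task of producing the correct expression tree.</p>
<p><img alt="Expression tree (source: GeeksforGeeks)" loading="lazy" src="https://pics.zhouxin.space/202410300948832.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>This post only covers binary operators, which take two operands and produce one result.</p>
<h2 id="原理">Principle</h2>
<p>Take the expression <code>a+b*c-d</code>. Scanning left to right, when we reach the first operator <code>+</code> we cannot build <code>a+b</code> in the tree right away, because <code>b*c</code> binds tighter and must be built first. Must a higher-precedence operator always be evaluated first? No. In <code>a+b+c*d</code>, evaluating <code>a+b</code> before <code>c*d</code> still yields the correct expression tree.</p>
<p>So when we scan an operator, when can we build its node right away, and when must a higher-precedence node be built first? In an infix expression every primary has at most two neighboring operators, and <strong>a primary binds to whichever of them has higher precedence</strong>. When we scan an operator, we therefore parse the next primary and peek at the operator after it. If the current operator has higher precedence than the next one (or the same precedence and is left-associative), the next primary is the current operator's second operand, rhs. Otherwise the next primary is the first operand, lhs, of the next operator; the next operator must be parsed first, and its result becomes the current operator's rhs.</p>
<p>For example, in <code>a+b*c-d</code> we first scan the primary <code>a</code> and record it as lhs. We scan the first operator <code>+</code>, then continue to the primary <code>b</code> and the operator next to it, <code>*</code>. Since <code>*</code> has higher precedence than <code>+</code>, we make a recursive call (a new stack frame) with <code>lhs = b</code> to parse <code>*</code>. Scanning on, we find the primary <code>c</code> and the operator next to it, <code>-</code>. Since <code>*</code> has higher precedence than <code>-</code>, <code>c</code> is the rhs of <code>*</code>; we build <code>b*c</code> and return it. Back in the frame that is parsing the first <code>+</code>, <code>b*c</code> arrives as the new rhs and we look at the next operator <code>-</code>. <code>+</code> has the same precedence as <code>-</code> and is left-associative, so we build <code>a+b*c</code> and make it the new lhs. We then scan the next primary <code>d</code>; no operator follows, so we build <code>a+b*c-d</code> and parsing ends.</p>
<p>The pseudocode is as follows <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup>:</p>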
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="nf">parse_expression</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="nf">parse_expression_1</span><span class="p">(</span><span class="nf">parse_primary</span><span class="p">(),</span> <span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="nf">parse_expression_1</span><span class="p">(</span><span class="n">lhs</span><span class="p">,</span> <span class="n">min_precedence</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="nl">_lookahead_</span> <span class="p">:</span><span class="o">=</span> <span class="n">peek</span> <span class="n">next</span> <span class="n">token</span>
</span></span><span class="line"><span class="cl">    <span class="k">while</span> <span class="n">_lookahead_</span> <span class="n">is</span> <span class="n">a</span> <span class="n">binary</span> <span class="n">operator</span> <span class="n">whose</span> <span class="n">precedence</span> <span class="n">is</span> <span class="o">&gt;=</span> <span class="n">_min_precedence_</span>
</span></span><span class="line"><span class="cl">        <span class="nl">_op_</span> <span class="p">:</span><span class="o">=</span> <span class="n">_lookahead_</span>
</span></span><span class="line"><span class="cl">        <span class="n">advance</span> <span class="n">to</span> <span class="n">next</span> <span class="n">token</span>
</span></span><span class="line"><span class="cl">        <span class="nl">_rhs_</span> <span class="p">:</span><span class="o">=</span> <span class="nf">_parse_primary_</span> <span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="nl">_lookahead_</span> <span class="p">:</span><span class="o">=</span> <span class="n">peek</span> <span class="n">next</span> <span class="n">token</span>
</span></span><span class="line"><span class="cl">        <span class="k">while</span> <span class="n">_lookahead_</span> <span class="n">is</span> <span class="n">a</span> <span class="n">binary</span> <span class="n">operator</span> <span class="n">whose</span> <span class="n">precedence</span> <span class="n">is</span> <span class="n">greater</span>
</span></span><span class="line"><span class="cl">                 <span class="n">than</span> <span class="n">_op_</span><span class="err">&#39;</span><span class="n">s</span><span class="p">,</span> <span class="n">or</span> <span class="n">a</span> <span class="n">right</span><span class="o">-</span><span class="n">associative</span> <span class="n">operator</span>
</span></span><span class="line"><span class="cl">                 <span class="n">whose</span> <span class="n">precedence</span> <span class="n">is</span> <span class="n">equal</span> <span class="n">to</span> <span class="n">_op_</span><span class="err">&#39;</span><span class="n">s</span>
</span></span><span class="line"><span class="cl">            <span class="nl">_rhs_</span> <span class="p">:</span><span class="o">=</span> <span class="nf">_parse_expression_1_</span> <span class="p">(</span><span class="n">_rhs_</span><span class="p">,</span> <span class="n">precedence</span> <span class="n">of</span> <span class="n">_op_</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="k">if</span> <span class="n">_lookahead_</span> <span class="n">precedence</span> <span class="n">is</span> <span class="n">greater</span><span class="p">,</span> <span class="k">else</span> <span class="mi">0</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">            <span class="nl">_lookahead_</span> <span class="p">:</span><span class="o">=</span> <span class="n">peek</span> <span class="n">next</span> <span class="n">token</span>
</span></span><span class="line"><span class="cl">        <span class="nl">_lhs_</span> <span class="p">:</span><span class="o">=</span> <span class="n">the</span> <span class="n">result</span> <span class="n">of</span> <span class="n">applying</span> <span class="n">_op_</span> <span class="n">with</span> <span class="n">operands</span> <span class="n">_lhs_</span> <span class="n">and</span> <span class="n">_rhs_</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">_lhs_</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>其中 <code>parse_expression_1</code> 是一个递归调用的函数，其接收两个参数 <code>lhs</code> 和 <code>min_precedence</code>，功能是对中缀表达式进行顺序解析，直至碰到优先级低于 <code>min_precedence</code> 返回解析结果。</p>
<p>而在 <code>parse_expression_1</code> 内部，其如果碰到了可以直接计算的情形，即当前运算符的优先级更高，或者相等且为左结合，则将 <code>lhs __0p__ rhs</code> 的结果作为新的 <code>lhs</code>，并进行下一趟外部循环；如果需要先计算下一个右结合，则递归调用自身，并将 <code>min_precedence</code> 参数设置为当前运算符的优先级 +0/1。</p>
<p>+0 是为了正确处理右结合的情况，当下一个运算符与当前运算符优先级相当且右结合时，<code>min_precedence</code> 需要设置为与当前运算符优先级相等的值，以确保递归调用时碰到 <code>1^2^3</code>（右结合运算）时能够继续递归解析，而非像左结合一样直接返回；+1 是为了正确处理左结合的情况，以确保当碰到与当前优先级相同的运算符时其能够及时返回，而不是继续向后解析。</p>
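<h2 id="括号处理">Handling parentheses</h2>
<p>The algorithm above breaks down as soon as it meets parentheses. For one thing, a parenthesis is not a binary operator, so it cannot be folded into the algorithm with a small tweak; for another, parentheses have the highest precedence, so they need special treatment. Fortunately, the LLVM tutorial offers another approach: whatever sits inside a pair of parentheses is itself an expression, so it is enough to call <code>parse_expression()</code> on that content and treat the result as a primary in our algorithm.</p>
<p>The next chapter shows the concrete implementation.</p>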
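<p>As a rough illustration of this idea, the sketch below treats a parenthesized group as a primary by calling the expression parser recursively. It evaluates integers instead of building an AST, handles only left-associative <code>+ - * /</code>, and omits error handling; the class <code>ExprEval</code> and the helper <code>evalWithParens</code> are hypothetical names, not part of the LLVM tutorial.</p>

```cpp
#include <cctype>
#include <cstddef>
#include <string>

// Illustrative evaluator; ExprEval is a hypothetical name.
class ExprEval {
public:
    explicit ExprEval(const std::string& s) : src_(s), pos_(0) {}

    // expression ::= primary (binop primary)*
    long parseExpression() { return parseBinOpRHS(0, parsePrimary()); }

private:
    std::string src_;
    std::size_t pos_;

    // Return the next non-space character without consuming it ('\0' at end of input).
    char peek() {
        while (pos_ < src_.size() && std::isspace(static_cast<unsigned char>(src_[pos_]))) ++pos_;
        return pos_ < src_.size() ? src_[pos_] : '\0';
    }

    static int precedence(char c) {
        switch (c) {
            case '+': case '-': return 20;
            case '*': case '/': return 30;
            default: return -1;  // not a binary operator
        }
    }

    // primary ::= number | '(' expression ')'
    long parsePrimary() {
        if (peek() == '(') {
            ++pos_;                           // eat '('
            const long v = parseExpression(); // the bracketed content is a full expression
            if (peek() == ')') ++pos_;        // eat ')'; a real parser would report an error otherwise
            return v;
        }
        long v = 0;
        while (pos_ < src_.size() && std::isdigit(static_cast<unsigned char>(src_[pos_])))
            v = v * 10 + (src_[pos_++] - '0');
        return v;
    }

    long parseBinOpRHS(int minPrec, long lhs) {
        while (true) {
            const char op = peek();
            const int prec = precedence(op);
            if (prec < minPrec) return lhs;   // also stops at ')' and at end of input
            ++pos_;                           // eat the operator
            long rhs = parsePrimary();
            if (prec < precedence(peek()))    // the next operator binds tighter
                rhs = parseBinOpRHS(prec + 1, rhs);
            switch (op) {
                case '+': lhs += rhs; break;
                case '-': lhs -= rhs; break;
                case '*': lhs *= rhs; break;
                default:  lhs /= rhs; break;  // '/'
            }
        }
    }
};

long evalWithParens(const std::string& s) { return ExprEval(s).parseExpression(); }
```

<p>Because <code>)</code> is not in the precedence table, the inner call returns on its own when it reaches the closing parenthesis, so no extra bookkeeping is needed; <code>evalWithParens("(1+2)*3")</code> yields 9.</p>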
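<h1 id="实现">Implementation</h1>
<p>This chapter walks through the LLVM tutorial's implementation of precedence climbing. The tutorial builds a lexer and a parser for the <a href="https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl01.html#id1">Kaleidoscope</a> language, which supports the four basic arithmetic operations. Chapter 2 of the tutorial parses source code into an abstract syntax tree (AST); a few scaffolding routines used by the parser are introduced first.</p>
<p><code>ParsePrimary</code> parses the primary expressions mentioned earlier. A primary expression can be an identifier, a number, or a parenthesized expression:</p>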
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">std</span><span class="o">::</span><span class="n">unique_ptr</span><span class="o">&lt;</span><span class="n">ExprAST</span><span class="o">&gt;</span> <span class="n">ParsePrimary</span><span class="p">(){</span>
</span></span><span class="line"><span class="cl">    <span class="k">switch</span> <span class="p">(</span><span class="n">CurTok</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">default</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">LogError</span><span class="p">(</span><span class="s">&#34;Unknown token when expecting an expression&#34;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">case</span> <span class="nl">tok_identifier</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">ParseIdentifierExpr</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="k">case</span> <span class="nl">tok_number</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">ParseNumberExpr</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="k">case</span> <span class="sc">&#39;(&#39;</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">ParseParenExpr</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Operator precedences are recorded in a <code>map</code>:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">std</span><span class="o">::</span><span class="n">map</span><span class="o">&lt;</span><span class="kt">char</span><span class="p">,</span> <span class="kt">int</span><span class="o">&gt;</span> <span class="n">BinopPredence</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">InstallBinopPredence</span><span class="p">();</span>
</span></span><span class="line"><span class="cl"><span class="kt">int</span> <span class="nf">GetTokPrecidence</span><span class="p">()</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">BinopPredence</span><span class="p">.</span><span class="n">empty</span><span class="p">())</span>
</span></span><span class="line"><span class="cl">        <span class="n">InstallBinopPredence</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="o">!</span><span class="n">isascii</span><span class="p">(</span><span class="n">CurTok</span><span class="p">)){</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="o">-</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">TokPrec</span> <span class="o">=</span> <span class="n">BinopPredence</span><span class="p">[</span><span class="n">CurTok</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">TokPrec</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">)</span>   <span class="k">return</span> <span class="o">-</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">TokPrec</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">InstallBinopPredence</span><span class="p">(){</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// 1 is the lowest precedence
</span></span></span><span class="line"><span class="cl">    <span class="n">BinopPredence</span><span class="p">[</span><span class="sc">&#39;&lt;&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="mi">10</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">BinopPredence</span><span class="p">[</span><span class="sc">&#39;&gt;&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="mi">10</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">BinopPredence</span><span class="p">[</span><span class="sc">&#39;+&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="mi">20</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">BinopPredence</span><span class="p">[</span><span class="sc">&#39;-&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="mi">20</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">BinopPredence</span><span class="p">[</span><span class="sc">&#39;*&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="mi">30</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">BinopPredence</span><span class="p">[</span><span class="sc">&#39;/&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="mi">30</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p><code>ParseBinOpRHS</code> is the core of precedence climbing and corresponds to <code>parse_expression_1</code> in the pseudocode; note the comments:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">std</span><span class="o">::</span><span class="n">unique_ptr</span><span class="o">&lt;</span><span class="n">ExprAST</span><span class="o">&gt;</span> <span class="n">ParseBinOpRHS</span><span class="p">(</span><span class="kt">int</span> <span class="n">ExprPrec</span><span class="p">,</span> 
</span></span><span class="line"><span class="cl">                                        <span class="n">std</span><span class="o">::</span><span class="n">unique_ptr</span><span class="o">&lt;</span><span class="n">ExprAST</span><span class="o">&gt;</span> <span class="n">LHS</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="k">while</span> <span class="p">(</span><span class="nb">true</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kt">int</span> <span class="n">TokPrec</span> <span class="o">=</span> <span class="n">GetTokPrecidence</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="c1">// If the current token is not an operator, TokPrec is -1 and parsing stops here
</span></span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="n">TokPrec</span> <span class="o">&lt;</span> <span class="n">ExprPrec</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="n">LHS</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="kt">int</span> <span class="n">BinOp</span> <span class="o">=</span> <span class="n">CurTok</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="n">getNextToken</span><span class="p">();</span> <span class="c1">// eat BinOp
</span></span></span><span class="line"><span class="cl">        <span class="k">auto</span> <span class="n">RHS</span> <span class="o">=</span> <span class="n">ParsePrimary</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="o">!</span><span class="n">RHS</span><span class="p">)</span> <span class="k">return</span> <span class="k">nullptr</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="kt">int</span> <span class="n">NextPrec</span> <span class="o">=</span> <span class="n">GetTokPrecidence</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="n">TokPrec</span> <span class="o">&lt;</span> <span class="n">NextPrec</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">            <span class="c1">// Right associativity is not handled here
</span></span></span><span class="line"><span class="cl">            <span class="n">RHS</span> <span class="o">=</span> <span class="n">ParseBinOpRHS</span><span class="p">(</span><span class="n">TokPrec</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">move</span><span class="p">(</span><span class="n">RHS</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="c1">// Build the AST node for BinOp
</span></span></span><span class="line"><span class="cl">        <span class="n">LHS</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">make_unique</span><span class="o">&lt;</span><span class="n">BinaryExprAST</span><span class="o">&gt;</span><span class="p">(</span><span class="n">BinOp</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">move</span><span class="p">(</span><span class="n">LHS</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">            <span class="n">std</span><span class="o">::</span><span class="n">move</span><span class="p">(</span><span class="n">RHS</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The main function for parsing an expression is:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">std</span><span class="o">::</span><span class="n">unique_ptr</span><span class="o">&lt;</span><span class="n">ExprAST</span><span class="o">&gt;</span> <span class="n">ParseExpression</span><span class="p">()</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">auto</span> <span class="n">LHS</span> <span class="o">=</span> <span class="n">ParsePrimary</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="o">!</span><span class="n">LHS</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="k">nullptr</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="nf">ParseBinOpRHS</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">move</span><span class="p">(</span><span class="n">LHS</span><span class="p">));</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="参考">References</h1>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p><a href="https://en.wikipedia.org/wiki/Operator-precedence_parser#Pseudocode">Operator-precedence parser - Wikipedia</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>Programming Massively Parallel Processors A Hands-on Approach 4th Edition Study Notes Part 2</title>
      <link>https://www.zhouxin.space/notes/note-on-programming-massively-parallel-processors-a-hands-on-approach-4th-edition-part-2/</link>
      <pubDate>Thu, 10 Oct 2024 20:09:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/note-on-programming-massively-parallel-processors-a-hands-on-approach-4th-edition-part-2/</guid>
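      <description>&lt;p&gt;Unless stated otherwise, all figures in this post are taken from the original book.&lt;/p&gt;
&lt;h1 id=&#34;chapter-07-convolution-卷积&#34;&gt;Chapter 07: Convolution&lt;/h1&gt;
&lt;p&gt;This chapter covers the implementation of 2D convolution, starting from a naive version and optimizing it step by step with constant memory, tiling with shared memory, and caches.&lt;/p&gt;
&lt;h2 id=&#34;71-background-背景&#34;&gt;7.1 Background&lt;/h2&gt;
&lt;p&gt;The definition of convolution is not repeated here; in short, each output element is a weighted sum of an input element and its neighbors.&lt;/p&gt;</description>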
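      <content:encoded><![CDATA[<p>Unless stated otherwise, all figures in this post are taken from the original book.</p>
<h1 id="chapter-07-convolution-卷积">Chapter 07: Convolution</h1>
<p>This chapter covers the implementation of 2D convolution, starting from a naive version and optimizing it step by step with constant memory, tiling with shared memory, and caches.</p>
<h2 id="71-background-背景">7.1 Background</h2>
<p>The definition of convolution is not repeated here; in short, each output element is a weighted sum of an input element and its neighbors.</p>
<h2 id="72-parallel-convolution-a-basic-algorithm-并行卷积">7.2 Parallel convolution: a basic algorithm</h2>
<p>This section uses 2D convolution as the running example.</p>
<p>Since every output element of a convolution is computed independently, a parallel convolution kernel can assign one output element to each thread. First fix the parameter list: a pointer to the input matrix <code>N</code>, a pointer to the filter <code>F</code>, a pointer to the output matrix <code>P</code>, the filter radius <code>r</code>, and the input matrix dimensions <code>height</code> and <code>width</code>.</p>
<p>Next, decide how threads map to output elements. Since the output is a 2D matrix, the threads can be organized in 2D as well, with each thread computing one element. A block holds at most 1024 threads, so it computes at most 1024 elements. The corresponding kernel is:</p>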
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">convolution_2D_basic_kernel</span><span class="p">(</span><span class="kt">float</span> <span class="o">*</span><span class="n">N</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">F</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">P</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">r</span><span class="p">,</span> <span class="kt">int</span> <span class="n">width</span><span class="p">,</span> <span class="kt">int</span> <span class="n">height</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">outCol</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">outRow</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">y</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="n">Pvalue</span> <span class="o">=</span> <span class="mf">0.0f</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">fRow</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">fRow</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="o">*</span><span class="n">r</span><span class="o">+</span><span class="mi">1</span><span class="p">;</span> <span class="n">fRow</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">fCol</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">fCol</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="o">*</span><span class="n">r</span><span class="o">+</span><span class="mi">1</span><span class="p">;</span> <span class="n">fCol</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="kt">int</span> <span class="n">inRow</span> <span class="o">=</span> <span class="n">outRow</span> <span class="o">-</span> <span class="n">r</span> <span class="o">+</span> <span class="n">fRow</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">            <span class="kt">int</span> <span class="n">inCol</span> <span class="o">=</span> <span class="n">outCol</span> <span class="o">-</span> <span class="n">r</span> <span class="o">+</span> <span class="n">fCol</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">            <span class="k">if</span> <span class="p">(</span><span class="n">inRow</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">inRow</span> <span class="o">&lt;</span> <span class="n">height</span> <span class="o">&amp;&amp;</span> <span class="n">inCol</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">inCol</span> <span class="o">&lt;</span> <span class="n">width</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="n">Pvalue</span> <span class="o">+=</span> <span class="n">F</span><span class="p">[</span><span class="n">fRow</span><span class="o">*</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">r</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">fCol</span><span class="p">]</span><span class="o">*</span><span class="n">N</span><span class="p">[</span><span class="n">inRow</span><span class="o">*</span><span class="n">width</span> <span class="o">+</span> <span class="n">inCol</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="n">P</span><span class="p">[</span><span class="n">outRow</span><span class="o">*</span><span class="n">width</span> <span class="o">+</span> <span class="n">outCol</span><span class="p">]</span> <span class="o">=</span> <span class="n">Pvalue</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
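</div><p>The kernel walks over the receptive field with two nested loops, accumulates into the register variable <code>Pvalue</code>, and uses a single <code>if</code> to check whether each input position lies inside the matrix.</p>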
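<p>To check a kernel like this one, a sequential host-side reference with the same loop structure can be handy. The function below is a minimal sketch under that assumption; <code>convolution2DReference</code> is a hypothetical helper that does not appear in the book, and it replaces the thread indices with two outer loops over the output.</p>

```cpp
#include <cstddef>
#include <vector>

// Hypothetical host-side reference: same arithmetic as the basic kernel,
// with the (outRow, outCol) thread indices replaced by two outer loops.
std::vector<float> convolution2DReference(const std::vector<float>& N,
                                          const std::vector<float>& F,
                                          int r, int width, int height) {
    std::vector<float> P(static_cast<std::size_t>(width) * height, 0.0f);
    const int filterWidth = 2 * r + 1;
    for (int outRow = 0; outRow < height; ++outRow) {
        for (int outCol = 0; outCol < width; ++outCol) {
            float Pvalue = 0.0f;
            for (int fRow = 0; fRow < filterWidth; ++fRow) {
                for (int fCol = 0; fCol < filterWidth; ++fCol) {
                    const int inRow = outRow - r + fRow;
                    const int inCol = outCol - r + fCol;
                    // Positions outside the input contribute nothing (zero padding).
                    if (inRow >= 0 && inRow < height && inCol >= 0 && inCol < width) {
                        Pvalue += F[fRow * filterWidth + fCol] * N[inRow * width + inCol];
                    }
                }
            }
            P[outRow * width + outCol] = Pvalue;
        }
    }
    return P;
}
```

<p>For a 3×3 input holding 1 through 9 and an all-ones filter of radius 1, the center output is 45, and the top-left corner, which sees only four valid inputs, is 12.</p>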
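<p>It is easy to see that the code above suffers from control-flow divergence: threads that handle elements near the four edges take different paths through the conditional. How much this matters depends on the matrix size. With a large input and a small filter, only a small fraction of threads diverge; otherwise the impact is large.</p>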
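<p>A more serious limiting factor is memory bandwidth. The ratio of floating-point operations to bytes accessed in the code above is 0.25 OP/B: line 11 performs two operations (one multiply and one add) for every 8 bytes loaded (two 4-byte floats). Memory access therefore drags down the computation considerably.</p>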
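<p>The ratio can be written down as a tiny compile-time calculation. This is just the arithmetic from the paragraph above; the names <code>opsPerByte</code> and <code>kBasicKernelIntensity</code> are made up for illustration.</p>

```cpp
#include <cstddef>

// Floating-point operations per byte loaded from global memory.
constexpr double opsPerByte(int flops, std::size_t bytes) {
    return static_cast<double>(flops) / static_cast<double>(bytes);
}

// Basic kernel, per inner-loop iteration: one multiply and one add,
// against one float loaded from F and one float loaded from N.
constexpr double kBasicKernelIntensity = opsPerByte(2, 2 * sizeof(float));
```

<p>Assuming a 4-byte <code>float</code>, <code>kBasicKernelIntensity</code> is 0.25; removing one of the two loads from the count would double it.</p>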
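<h2 id="73-constant-memory-and-caching-常量内存和缓存">7.3 Constant memory and caching</h2>
<p>In convolution the filter has three convenient properties: (1) it is usually small, with a radius of at most 7, and even a 3D filter typically holds no more than 7^3 = 343 elements; (2) its weights do not change during the convolution; (3) all threads access the same filter in the same order.</p>
<p>These three properties make the filter a good fit for constant memory and caching. Constant memory cannot be modified while a kernel is running and is only 64 KB in size. It has to be declared and populated from the host side. Assuming the compile-time constant <code>FILTER_RADIUS</code> gives the filter radius, the following code declares the constant memory:</p>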
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="cp">#define FILTER_RADIUS 2
</span></span></span><span class="line"><span class="cl"><span class="n">__constant__</span> <span class="kt">float</span> <span class="n">F</span><span class="p">[</span><span class="mi">2</span><span class="o">*</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="mi">1</span><span class="p">][</span><span class="mi">2</span><span class="o">*</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="mi">1</span><span class="p">];</span>
</span></span></code></pre></td></tr></table>
</div>
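</div><p>Note that constant memory must be declared at global scope, that is, outside any function, so it cannot be declared inside a host function.</p>
<p><code>cudaMemcpyToSymbol</code> copies the data from the host into constant memory:</p>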
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="nf">cudaMemcpyToSymbol</span><span class="p">(</span><span class="n">F</span><span class="p">,</span> <span class="n">F_h</span><span class="p">,</span> <span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">));</span>
</span></span></code></pre></td></tr></table>
</div>
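</div><p>Here <code>F_h</code> is the copy of F in host memory.</p>
<p>Variables stored in constant memory are global, so the filter no longer needs to be passed to the kernel as an argument. Compared with the first kernel, almost nothing changes apart from the function signature:</p>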
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">convolution_2D_const_mem_kernel</span><span class="p">(</span><span class="kt">float</span> <span class="o">*</span><span class="n">N</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">P</span><span class="p">,</span> <span class="kt">int</span> <span class="n">r</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                                <span class="kt">int</span> <span class="n">width</span><span class="p">,</span> <span class="kt">int</span> <span class="n">height</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">outCol</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">outRow</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">y</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="n">Pvalue</span> <span class="o">=</span> <span class="mf">0.0f</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">fRow</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">fRow</span> <span class="o">&lt;</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">r</span> <span class="o">+</span> <span class="mi">1</span><span class="p">;</span> <span class="n">fRow</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">fCol</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">fCol</span> <span class="o">&lt;</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">r</span> <span class="o">+</span> <span class="mi">1</span><span class="p">;</span> <span class="n">fCol</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="kt">int</span> <span class="n">inRow</span> <span class="o">=</span> <span class="n">outRow</span> <span class="o">-</span> <span class="n">r</span> <span class="o">+</span> <span class="n">fRow</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">            <span class="kt">int</span> <span class="n">inCol</span> <span class="o">=</span> <span class="n">outCol</span> <span class="o">-</span> <span class="n">r</span> <span class="o">+</span> <span class="n">fCol</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">            <span class="k">if</span> <span class="p">(</span><span class="n">inRow</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">inRow</span> <span class="o">&lt;</span> <span class="n">height</span> <span class="o">&amp;&amp;</span> <span class="n">inCol</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">inCol</span> <span class="o">&lt;</span> <span class="n">width</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="n">Pvalue</span> <span class="o">+=</span> <span class="n">F</span><span class="p">[</span><span class="n">fRow</span><span class="p">][</span><span class="n">fCol</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span><span class="p">[</span><span class="n">inRow</span> <span class="o">*</span> <span class="n">width</span> <span class="o">+</span> <span class="n">inCol</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="n">P</span><span class="p">[</span><span class="n">outRow</span> <span class="o">*</span> <span class="n">width</span> <span class="o">+</span> <span class="n">outCol</span><span class="p">]</span> <span class="o">=</span> <span class="n">Pvalue</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
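</div><p>Variable scoping in CUDA C follows the rules of C, so a global variable that is declared in one file and referenced from another must be declared with the <code>extern</code> keyword in the referencing file.</p>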
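<p>Constant memory variables also reside in DRAM, but because the runtime knows they cannot change while a kernel is running, it directs the hardware to cache them more aggressively.</p>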
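<p>Unlike shared memory or registers, caches are invisible to the programmer and are managed by the hardware and the runtime. Caches are expensive to build, especially when they must support writes. Because constant variables are read-only and small, the hardware can implement a dedicated constant cache at low cost.</p>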
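<p>With the filter in constant memory, the ratio of floating-point operations to bytes accessed from global memory doubles to 0.5 OP/B: each loop iteration still performs 2 operations (one multiply and one add), but only the 4-byte input element has to come from global memory, instead of 8 bytes for both the input element and the filter element.</p>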
<h2 id="74-tiled-convolution-with-halo-cells-带有边界单元的分块卷积">7.4 Tiled convolution with halo cells</h2>
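<p>Tiled convolution alleviates the memory bottleneck. We first define input and output tiles. An output tile is the set of output elements computed by all threads of one block. If each output element is computed by one thread and each block contains 16 threads, the output matrix is partitioned into 4×4 tiles. In practice, of course, each block should contain at least a warp's worth of threads so as to maximize occupancy and data reuse.</p>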
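<p>Naturally, an input tile is defined as the set of input elements needed to compute one output tile. As the figure below shows, with a filter radius of 2 the input tile is the blue region (dark and light blue) and the output tile is the green region. The light blue elements are called halo cells.<br>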
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202410112036856.png?x-oss-process=image/quality,q_90/format,webp"></p>
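<p>The relationship between the two tile sizes can be sketched in plain C. This is only an illustration: the struct <code>TileGeometry</code> and the function <code>tile_geometry</code> are hypothetical names, and a square tile and square filter are assumed.</p>

```c
#include <assert.h>

/* Geometry of one tile for a square filter of the given radius.
 * All names here are illustrative assumptions, not part of the book's code. */
typedef struct {
    int in_dim;      /* side length of the input tile */
    int out_dim;     /* side length of the output tile */
    int halo_cells;  /* input elements that lie outside the output tile */
} TileGeometry;

/* Derive the input tile from the output tile: the filter reaches
 * `radius` elements past the output tile on every side. */
TileGeometry tile_geometry(int out_dim, int radius) {
    assert(out_dim > 0 && radius >= 0);
    TileGeometry g;
    g.out_dim = out_dim;
    g.in_dim = out_dim + 2 * radius;
    g.halo_cells = g.in_dim * g.in_dim - out_dim * out_dim;
    return g;
}
```

<p>For the 4×4 output tile and radius 2 mentioned above, <code>tile_geometry(4, 2)</code> gives an 8×8 input tile with 48 halo cells, so most of the input tile is halo when the output tile is this small.</p>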
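<p>With tiling, the threads of a block first cooperate to load the input tile into shared memory using coalesced accesses. Because the input tile and the output tile differ in size, there are two ways to organize the threads. The first launches as many threads as there are input tile elements, which makes loading simple but leaves some threads idle during computation. The second launches as many threads as there are output tile elements, which complicates loading but gives higher overall thread utilization. The book uses the first approach as its example.</p>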
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="cp">#define IN_TILE_DIM 32
</span></span></span><span class="line"><span class="cl"><span class="cp">#define OUT_TILE_DIM ((IN_TILE_DIM) - 2 * (FILTER_RADIUS))
</span></span></span><span class="line"><span class="cl"><span class="n">__constant__</span> <span class="kt">float</span> <span class="n">F_c</span><span class="p">[</span><span class="mi">2</span> <span class="o">*</span> <span class="n">FILTER_RADIUS</span> <span class="o">+</span> <span class="mi">1</span><span class="p">][</span><span class="mi">2</span> <span class="o">*</span> <span class="n">FILTER_RADIUS</span> <span class="o">+</span> <span class="mi">1</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">convolution_tiled_2D_const_mem_kernel</span><span class="p">(</span><span class="kt">float</span> <span class="o">*</span><span class="n">N</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">P</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                                      <span class="kt">int</span> <span class="n">width</span><span class="p">,</span> <span class="kt">int</span> <span class="n">height</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">col</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">OUT_TILE_DIM</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">-</span> <span class="n">FILTER_RADIUS</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">row</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">OUT_TILE_DIM</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">-</span> <span class="n">FILTER_RADIUS</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// Loading input tile
</span></span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">N_s</span><span class="p">[</span><span class="n">IN_TILE_DIM</span><span class="p">][</span><span class="n">IN_TILE_DIM</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">row</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">row</span> <span class="o">&lt;</span> <span class="n">height</span> <span class="o">&amp;&amp;</span> <span class="n">col</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">col</span> <span class="o">&lt;</span> <span class="n">width</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">N_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">N</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">width</span> <span class="o">+</span> <span class="n">col</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">N_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.0f</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// Calculating output elements
</span></span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">tileCol</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">-</span> <span class="n">FILTER_RADIUS</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">tileRow</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">-</span> <span class="n">FILTER_RADIUS</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// Turning off the threads at the edges of the block
</span></span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">col</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">col</span> <span class="o">&lt;</span> <span class="n">width</span> <span class="o">&amp;&amp;</span> <span class="n">row</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">row</span> <span class="o">&lt;</span> <span class="n">height</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="n">tileCol</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">tileCol</span> <span class="o">&lt;</span> <span class="n">OUT_TILE_DIM</span> <span class="o">&amp;&amp;</span> <span class="n">tileRow</span> <span class="o">&gt;=</span> <span class="mi">0</span>
</span></span><span class="line"><span class="cl">            <span class="o">&amp;&amp;</span> <span class="n">tileRow</span> <span class="o">&lt;</span> <span class="n">OUT_TILE_DIM</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="kt">float</span> <span class="n">Pvalue</span> <span class="o">=</span> <span class="mf">0.0f</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">fRow</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">fRow</span> <span class="o">&lt;</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">FILTER_RADIUS</span> <span class="o">+</span> <span class="mi">1</span><span class="p">;</span> <span class="n">fRow</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">fCol</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">fCol</span> <span class="o">&lt;</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">FILTER_RADIUS</span> <span class="o">+</span> <span class="mi">1</span><span class="p">;</span> <span class="n">fCol</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                    <span class="n">Pvalue</span> <span class="o">+=</span> <span class="n">F_c</span><span class="p">[</span><span class="n">fRow</span><span class="p">][</span><span class="n">fCol</span><span class="p">]</span> <span class="o">*</span> <span class="n">N_s</span><span class="p">[</span><span class="n">tileRow</span> <span class="o">+</span> <span class="n">fRow</span><span class="p">][</span><span class="n">tileCol</span> <span class="o">+</span> <span class="n">fCol</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">                <span class="p">}</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">            <span class="n">P</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">width</span> <span class="o">+</span> <span class="n">col</span><span class="p">]</span> <span class="o">=</span> <span class="n">Pvalue</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
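</div><p>The code above is largely self-explanatory: each block launches one thread per input tile element, all threads cooperate to load the tile into shared memory, and only the threads that map to an output element compute a result.</p>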
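<p>Next we compute the ratio of floating-point operations to bytes accessed from global memory for the code above. For simplicity, a boundary thread that skips the load for a ghost cell is still counted as one memory access. Each block loads <code>IN_TILE_DIM*IN_TILE_DIM</code> floats (4 bytes each) into shared memory and performs <code>OUT_TILE_DIM*OUT_TILE_DIM*(2*FILTER_RADIUS+1)*(2*FILTER_RADIUS+1)</code> multiply-add pairs, that is, twice that many floating-point operations. For a 32×32 input tile and a 5×5 filter, the ratio is (28×28×25×2)/(32×32×4) ≈ 9.57 OP/B.</p>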
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202410112305262.png?x-oss-process=image/quality,q_90/format,webp"></p>
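<p>The table above lists the ratio of floating-point operations to bytes accessed for different input tile dimensions. It is easy to see that the larger the filter, the higher the ratio.</p>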
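<p>The ratio discussed in this section can be reproduced with a small host-side helper. This is a sketch under the same counting assumptions as the text (2 operations per filter tap, 4 bytes per loaded float, ghost cells counted as loads); the function name <code>tiled_conv_op_per_byte</code> is hypothetical.</p>

```c
#include <assert.h>

/* Ratio of floating-point operations to global-memory bytes for one block
 * of the tiled kernel. Each filter tap costs 2 operations (multiply + add)
 * and each element loaded into shared memory costs 4 bytes. */
double tiled_conv_op_per_byte(int in_tile_dim, int filter_radius) {
    int out_tile_dim = in_tile_dim - 2 * filter_radius;
    assert(out_tile_dim > 0);
    int filter_dim = 2 * filter_radius + 1;
    double ops = 2.0 * out_tile_dim * out_tile_dim * filter_dim * filter_dim;
    double bytes = 4.0 * in_tile_dim * in_tile_dim;
    return ops / bytes;
}
```

<p>Calling <code>tiled_conv_op_per_byte(32, 2)</code> returns about 9.57. As the tile grows, the halo becomes negligible and the ratio approaches <code>2*(2*FILTER_RADIUS+1)*(2*FILTER_RADIUS+1)/4</code>, which is 12.5 OP/B for a 5×5 filter.</p>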
<h2 id="75-tiled-convolution-using-caches-for-halo-cells-为边界元素使用缓存的空洞卷积">7.5 Tiled convolution using caches for halo cells</h2>
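<p>Note the following fact: the halo cells of one tile may be internal elements of a neighboring tile, so by the time a block accesses its halo cells they have quite likely already been brought into the L2 cache. Exploiting this, this section presents a tiled convolution algorithm whose input and output tiles have the same size: it loads only the internal elements into shared memory and does not load halo cells explicitly.</p>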
<p>The code is as follows:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="cp">#define TILE_DIM 32
</span></span></span><span class="line"><span class="cl"><span class="n">__constant__</span> <span class="kt">float</span> <span class="n">F</span><span class="p">[</span><span class="mi">2</span><span class="o">*</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="mi">1</span><span class="p">][</span><span class="mi">2</span><span class="o">*</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="mi">1</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">convolution_cached_tiled_2D_const_mem_kernel</span><span class="p">(</span><span class="kt">float</span> <span class="o">*</span><span class="n">N</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                                            <span class="kt">float</span> <span class="o">*</span><span class="n">P</span><span class="p">,</span> <span class="kt">int</span> <span class="n">width</span><span class="p">,</span> <span class="kt">int</span> <span class="n">height</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">col</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">TILE_DIM</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">row</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span><span class="o">*</span><span class="n">TILE_DIM</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="c1">//loading input tile
</span></span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">N_s</span><span class="p">[</span><span class="n">TILE_DIM</span><span class="p">][</span><span class="n">TILE_DIM</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">row</span><span class="o">&lt;</span><span class="n">height</span> <span class="o">&amp;&amp;</span> <span class="n">col</span><span class="o">&lt;</span><span class="n">width</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">N_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">N</span><span class="p">[</span><span class="n">row</span><span class="o">*</span><span class="n">width</span> <span class="o">+</span> <span class="n">col</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">N_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.0f</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// Calculating output elements
</span></span></span><span class="line"><span class="cl">    <span class="c1">// turning off the threads at the edges of the block
</span></span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">col</span> <span class="o">&lt;</span> <span class="n">width</span> <span class="o">&amp;&amp;</span> <span class="n">row</span> <span class="o">&lt;</span> <span class="n">height</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kt">float</span> <span class="n">Pvalue</span> <span class="o">=</span> <span class="mf">0.0f</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">fRow</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">fRow</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="o">*</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="mi">1</span><span class="p">;</span> <span class="n">fRow</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">fCol</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">fCol</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="o">*</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="mi">1</span><span class="p">;</span> <span class="n">fCol</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="k">if</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">-</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="n">fCol</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span>
</span></span><span class="line"><span class="cl">                    <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">-</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="n">fCol</span> <span class="o">&lt;</span> <span class="n">TILE_DIM</span> <span class="o">&amp;&amp;</span>
</span></span><span class="line"><span class="cl">                    <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="o">-</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="n">fRow</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span>
</span></span><span class="line"><span class="cl">                    <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="o">-</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="n">fRow</span> <span class="o">&lt;</span> <span class="n">TILE_DIM</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                    <span class="n">Pvalue</span> <span class="o">+=</span> <span class="n">F</span><span class="p">[</span><span class="n">fRow</span><span class="p">][</span><span class="n">fCol</span><span class="p">]</span><span class="o">*</span><span class="n">N_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="o">-</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="n">fRow</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">-</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="n">fCol</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">                <span class="p">}</span>
</span></span><span class="line"><span class="cl">                <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                    <span class="k">if</span> <span class="p">(</span><span class="n">row</span><span class="o">-</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="n">fRow</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span>
</span></span><span class="line"><span class="cl">                        <span class="n">row</span><span class="o">-</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="n">fRow</span> <span class="o">&lt;</span> <span class="n">height</span> <span class="o">&amp;&amp;</span>
</span></span><span class="line"><span class="cl">                        <span class="n">col</span><span class="o">-</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="n">fCol</span> <span class="o">&gt;=</span><span class="mi">0</span> <span class="o">&amp;&amp;</span>
</span></span><span class="line"><span class="cl">                        <span class="n">col</span><span class="o">-</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="n">fCol</span> <span class="o">&lt;</span> <span class="n">width</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                        <span class="n">Pvalue</span> <span class="o">+=</span> <span class="n">F</span><span class="p">[</span><span class="n">fRow</span><span class="p">][</span><span class="n">fCol</span><span class="p">]</span><span class="o">*</span>
</span></span><span class="line"><span class="cl">                            <span class="n">N</span><span class="p">[(</span><span class="n">row</span><span class="o">-</span><span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="n">fRow</span><span class="p">)</span><span class="o">*</span><span class="n">width</span><span class="o">+</span><span class="n">col</span><span class="o">-</span>
</span></span><span class="line"><span class="cl">                              <span class="n">FILTER_RADIUS</span><span class="o">+</span><span class="n">fCol</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">                    <span class="p">}</span>
</span></span><span class="line"><span class="cl">                <span class="p">}</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="n">P</span><span class="p">[</span><span class="n">row</span><span class="o">*</span><span class="n">width</span><span class="o">+</span><span class="n">col</span><span class="p">]</span> <span class="o">=</span> <span class="n">Pvalue</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
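</div><p>The code has to check for both ghost cells and halo cells. Two nested for loops iterate over the filter window. Inside the loop, the code first checks whether the input element is an internal element of the tile, in which case it is read from shared memory. Otherwise the element is a halo cell, and the code goes on to check whether it is also a ghost cell; only halo cells that are not ghost cells are read from global memory.</p>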
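<p>When experimenting with kernels like the ones in this chapter, a sequential reference is handy for checking results. The following is a minimal sketch, not code from the book: the function name <code>convolution_2D_reference</code> and the flat row-major layout of the filter are assumptions. Ghost cells are treated as zeros, as in the kernels above.</p>

```c
#include <assert.h>
#include <stddef.h>

/* Sequential reference for the 2D convolution.
 * N and P are row-major height x width arrays; F is a row-major
 * (2r+1) x (2r+1) filter. Ghost cells (outside the image) contribute 0. */
void convolution_2D_reference(const float *N, const float *F, float *P,
                              int r, int width, int height) {
    assert(N != NULL && F != NULL && P != NULL && r >= 0);
    int fdim = 2 * r + 1;
    for (int outRow = 0; outRow < height; outRow++) {
        for (int outCol = 0; outCol < width; outCol++) {
            float Pvalue = 0.0f;
            for (int fRow = 0; fRow < fdim; fRow++) {
                for (int fCol = 0; fCol < fdim; fCol++) {
                    int inRow = outRow - r + fRow;
                    int inCol = outCol - r + fCol;
                    /* Skip ghost cells, which is the same as adding 0. */
                    if (inRow >= 0 && inRow < height &&
                        inCol >= 0 && inCol < width) {
                        Pvalue += F[fRow * fdim + fCol] * N[inRow * width + inCol];
                    }
                }
            }
            P[outRow * width + outCol] = Pvalue;
        }
    }
}
```

<p>Comparing the output of a GPU kernel against this function on a small image, including the corners and edges, exercises both the halo and the ghost cell paths.</p>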
<h1 id="chapter-08-stencil-模板计算">Chapter 08: Stencil</h1>
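<p>Note that in this chapter a stencil is a computation pattern common in scientific computing and is entirely unrelated to templates in C++. A stencil computes a set of discretized quantities that carry physical meaning. It resembles convolution in that a new value is computed from an element and its neighbors. Unlike convolution, the elements involved and their weights are determined by a differential equation. In addition, output values may depend on boundary conditions, successive iterations of a stencil computation may depend on one another, and scientific computing usually demands higher floating-point precision. These differences mean that stencils and convolution call for different optimization techniques.</p>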
<h2 id="81-background-背景">8.1 Background</h2>
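<p>The first step of numerical computation on a computer is discretization. We use a structured grid to partition n-dimensional Euclidean space regularly: line segments in 1D, rectangles in 2D, and cuboids in 3D. The figure below partitions the one-dimensional function $y=\sin (x)$ into intervals of length $\pi/6$.<br>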
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202410150922860.png?x-oss-process=image/quality,q_90/format,webp"></p>
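<p>In the discrete representation, values at points that are not on the grid must be computed from the surrounding grid points with techniques such as linear or spline interpolation. Accuracy depends on the grid density: the denser the grid, the more accurate the result. It also depends on the precision of the number representation. Double-precision floats are more precise than half-precision floats, for example, but higher precision consumes more on-chip memory and may become a performance bottleneck.</p>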
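<p>A stencil specifies how a mathematical quantity at a point is computed by finite differences from the values at that point and its neighbors, while the partial differential equation determines the exact expression for that quantity. For example, a classic way to compute the first derivative of a one-dimensional function is:</p>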


<div>$$

f^\prime(x) = \frac{f(x&#43;h)-f(x-h)}{2h} &#43; O(h^2)

$$</div>

<p>where $O(h^2)$ is the error term. The error is thus on the order of $h^2$, so it depends on how finely the grid is divided.</p>
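<p>The order of the error term can be checked with a Taylor expansion. The following is a sketch that assumes $f$ is smooth enough for the expansions to hold. Expanding around $x$ in both directions gives:</p>

<div>$$

f(x+h) = f(x) + h f^\prime(x) + \frac{h^2}{2} f^{\prime\prime}(x) + \frac{h^3}{6} f^{\prime\prime\prime}(x) + O(h^4)

$$</div>

<div>$$

f(x-h) = f(x) - h f^\prime(x) + \frac{h^2}{2} f^{\prime\prime}(x) - \frac{h^3}{6} f^{\prime\prime\prime}(x) + O(h^4)

$$</div>

<p>Subtracting the second line from the first cancels the even-order terms, and dividing by $2h$ gives:</p>

<div>$$

\frac{f(x+h)-f(x-h)}{2h} = f^\prime(x) + \frac{h^2}{6} f^{\prime\prime\prime}(x) + O(h^3)

$$</div>

<p>The leading error term is proportional to $h^2$, which is where the $O(h^2)$ in the formula comes from.</p>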
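<p>Suppose <code>F[i]</code> is the array holding the function values and we need the first derivative <code>FD[i]</code>. It can be computed for each <code>i</code> as <code>FD[i] = (F[i+1]-F[i-1])/(2*h)</code>, which is equivalent to <code>FD[i] = F[i+1]/(2*h)-F[i-1]/(2*h)</code>. This expression can be described as applying a stencil to <code>[i-1, i, i+1]</code> with weights <code>[-1/(2*h), 0, 1/(2*h)]</code>.</p>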
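<p>The expression above can be sketched as a short sequential C function. The function name <code>central_difference_1D</code> and the decision to leave the two end points untouched are assumptions made for this illustration.</p>

```c
#include <assert.h>

/* First derivative by central differences on a uniform grid with spacing h.
 * Interior points use the stencil weights [-1/(2h), 0, 1/(2h)];
 * the two end points lack a neighbor on one side and are left untouched. */
void central_difference_1D(const float *F, float *FD, int n, float h) {
    assert(n >= 3 && h > 0.0f);
    float w = 1.0f / (2.0f * h);
    for (int i = 1; i < n - 1; i++) {
        FD[i] = w * F[i + 1] - w * F[i - 1];
    }
}
```

<p>For a quadratic such as $f(x)=x^2$ the central difference is exact, because the third derivative in the error term is zero. Sampling $x^2$ with <code>h = 0.5</code> and applying the function yields <code>FD[i] = i</code> at the interior points, which equals $2x$ at $x = i \cdot h$.</p>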
<p>Clearly, solving partial differential equations requires partitioning and computing on a multidimensional grid.</p>
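<p>This chapter focuses on one computation pattern: the stencil is applied at every grid point to compute the quantity over the whole domain. This pattern is called a stencil sweep.</p>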
<h2 id="82-parallel-stencil-a-basic-algorithm-一种基本算法并行模板">8.2 Parallel stencil: a basic algorithm</h2>
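<p>Assume that the output elements of a stencil sweep are independent of one another, and that the grid's boundary elements hold the boundary values of the differential equation and are not modified during a sweep. In the figure below, for example, the shaded part of the output holds these boundary values, which the sweep leaves unchanged. This assumption is reasonable because stencils are mainly used for differential equation problems with boundary conditions.<br>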
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202410151057594.png?x-oss-process=image/quality,q_90/format,webp"></p>
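<p>Under these assumptions, one sweep can be sketched sequentially in C. This is an illustration only: the function name <code>stencil_sweep_reference</code>, the choice of a seven-point 3D stencil, and the weight array <code>c</code> with its ordering are assumptions. Boundary cells of the output are never written, so they keep whatever boundary values they already hold.</p>

```c
#include <assert.h>
#include <stddef.h>

/* Sequential seven-point stencil sweep over an N x N x N grid stored in
 * row-major order. c[0..6] are the weights for the center, k-1, k+1, j-1,
 * j+1, i-1 and i+1 neighbors. Boundary cells of `out` are never written. */
void stencil_sweep_reference(const float *in, float *out,
                             const float c[7], unsigned int N) {
    assert(in != NULL && out != NULL && N >= 3);
    for (unsigned int i = 1; i < N - 1; i++) {
        for (unsigned int j = 1; j < N - 1; j++) {
            for (unsigned int k = 1; k < N - 1; k++) {
                out[i*N*N + j*N + k] = c[0] * in[i*N*N + j*N + k]
                    + c[1] * in[i*N*N + j*N + (k - 1)]
                    + c[2] * in[i*N*N + j*N + (k + 1)]
                    + c[3] * in[i*N*N + (j - 1)*N + k]
                    + c[4] * in[i*N*N + (j + 1)*N + k]
                    + c[5] * in[(i - 1)*N*N + j*N + k]
                    + c[6] * in[(i + 1)*N*N + j*N + k];
            }
        }
    }
}
```

<p>Every interior output element reads only from <code>in</code> and writes only its own entry of <code>out</code>, so the three loops carry no dependences and each iteration can be assigned to its own thread.</p>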
<p>The code below shows a kernel that computes a 3D stencil. Each block computes one tile of the output, and each thread computes one element of that tile.</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">stencil_kernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">in</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">N</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">z</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">z</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">z</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">y</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">i</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">out</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">c0</span><span class="o">*</span><span class="n">in</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">            <span class="o">+</span> <span class="n">c1</span><span class="o">*</span><span class="n">in</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">            <span class="o">+</span> <span class="n">c2</span><span class="o">*</span><span class="n">in</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">            <span class="o">+</span> <span class="n">c3</span><span class="o">*</span><span class="n">in</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="p">(</span><span class="n">j</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">            <span class="o">+</span> <span class="n">c4</span><span class="o">*</span><span class="n">in</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="p">(</span><span class="n">j</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">            <span class="o">+</span> <span class="n">c5</span><span class="o">*</span><span class="n">in</span><span class="p">[(</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">            <span class="o">+</span> <span class="n">c6</span><span class="o">*</span><span class="n">in</span><span class="p">[(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
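</div><p>The kernel above performs 13 floating-point operations per output element (7 multiplications and 6 additions) while loading 7 four-byte values from global memory, so its ratio of floating-point operations to memory traffic is 13/(7*4) = 0.46 OP/B.</p>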
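<p>The same 7-point update can also be written as a sequential loop, which is convenient for checking a kernel's output on small grids. This is only a sketch: the function name <code>stencil_reference</code> and the coefficient array <code>c</code> (standing in for <code>c0</code> to <code>c6</code>) are assumptions of this example and do not appear in the original code.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-c" data-lang="c">#include &lt;stddef.h&gt;

/* Sequential reference for the 7-point stencil.
   c[0]..c[6] play the role of c0..c6 in the kernel (assumed parameter). */
void stencil_reference(const float *in, float *out, unsigned int N, const float c[7]) {
    size_t row = N;                /* stride between neighbours along j */
    size_t plane = (size_t)N * N;  /* stride between neighbours along i */
    /* Only interior points are written, matching the kernel's boundary check. */
    for (unsigned int i = 1; i + 1 &lt; N; ++i) {
        for (unsigned int j = 1; j + 1 &lt; N; ++j) {
            for (unsigned int k = 1; k + 1 &lt; N; ++k) {
                size_t idx = i * plane + j * row + k;
                out[idx] = c[0] * in[idx]
                         + c[1] * in[idx - 1]
                         + c[2] * in[idx + 1]
                         + c[3] * in[idx - row]
                         + c[4] * in[idx + row]
                         + c[5] * in[idx - plane]
                         + c[6] * in[idx + plane];
            }
        }
    }
}
</code></pre></div>
<p>Boundary elements of <code>out</code> are left untouched, so the caller is expected to initialise them, just as with the kernel.</p>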
<h2 id="83-shared-memory-tiling-for-stencil-sweep-为模板扫描进行共享内存分块">8.3 Shared memory tiling for stencil sweep</h2>
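<p>Shared memory tiling for a stencil is similar to tiling for convolution, with a few subtle differences. The figure below shows the input elements that the stencil touches when computing one output tile. Unlike convolution, the four corners of the input tile are never used, which matters especially for register tiling. For shared memory tiling, the same property makes the optimization less effective than in the convolution case, because fewer elements are reused across threads.<br>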
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202410151129310.png?x-oss-process=image/quality,q_90/format,webp"></p>
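<p>The upper bound on the benefit of shared memory tiling falls further and further behind that of convolution as the dimensionality and the order (analogous to the radius in convolution) grow. For a 2D stencil, order 1 corresponds to a 3*3 convolution, with upper bounds of 2.5 OP/B for the stencil and 4.5 OP/B for the convolution; order 2 corresponds to a 5*5 convolution, 4.5 OP/B versus 12.5 OP/B; order 3 corresponds to a 7*7 convolution, 6.5 OP/B versus 24.5 OP/B. The effect is far more pronounced in 3D: a 3D order-3 stencil corresponds to a 7*7*7 convolution (radius 3), with upper bounds of 9.5 OP/B versus 171.5 OP/B.</p>
<p>The code after applying shared memory tiling is shown below:</p>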
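<p>The bounds quoted above are consistent with a simple count, sketched here under the assumption that each input point contributes one multiplication and one addition and that each point is loaded once as a 4-byte float. The symbols $d$ (dimensionality) and $n$ (order) are introduced here for illustration only. An order-$n$ stencil in $d$ dimensions touches $2dn+1$ points, while the matching convolution touches $(2n+1)^d$ points:</p>
<p>$$\text{bound}_{\text{stencil}} = \frac{2\,(2dn+1)}{4}\ \text{OP/B}, \qquad \text{bound}_{\text{conv}} = \frac{2\,(2n+1)^d}{4}\ \text{OP/B}$$</p>
<p>For $d=2$, $n=1$ this gives $\frac{2\times 5}{4} = 2.5$ OP/B and $\frac{2\times 9}{4} = 4.5$ OP/B; for $d=3$, $n=3$ it gives $\frac{2\times 19}{4} = 9.5$ OP/B and $\frac{2\times 343}{4} = 171.5$ OP/B.</p>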
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">stencil_kernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">in</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">N</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">z</span><span class="o">*</span><span class="n">OUT_TILE_DIM</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">z</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span><span class="o">*</span><span class="n">OUT_TILE_DIM</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">OUT_TILE_DIM</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">in_s</span><span class="p">[</span><span class="n">IN_TILE_DIM</span><span class="p">][</span><span class="n">IN_TILE_DIM</span><span class="p">][</span><span class="n">IN_TILE_DIM</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">i</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">in_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">z</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">in</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">i</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">N</span><span class="o">-</span><span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span><span class="o">-</span><span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">N</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">z</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">z</span> <span class="o">&lt;</span> <span class="n">IN_TILE_DIM</span><span class="o">-</span><span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">&gt;=</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">           <span class="o">&amp;&amp;</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="o">&lt;</span><span class="n">IN_TILE_DIM</span><span class="o">-</span><span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">&gt;=</span><span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">&lt;</span><span class="n">IN_TILE_DIM</span><span class="o">-</span><span class="mi">1</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">            <span class="n">out</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">c0</span><span class="o">*</span><span class="n">in_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">z</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                <span class="o">+</span> <span class="n">c1</span><span class="o">*</span><span class="n">in_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">z</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                <span class="o">+</span> <span class="n">c2</span><span class="o">*</span><span class="n">in_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">z</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">+</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                <span class="o">+</span> <span class="n">c3</span><span class="o">*</span><span class="n">in_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">z</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                <span class="o">+</span> <span class="n">c4</span><span class="o">*</span><span class="n">in_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">z</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="o">+</span><span class="mi">1</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                <span class="o">+</span> <span class="n">c5</span><span class="o">*</span><span class="n">in_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">z</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                <span class="o">+</span> <span class="n">c6</span><span class="o">*</span><span class="n">in_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">z</span><span class="o">+</span><span class="mi">1</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
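</div><p>In the code above, <code>i</code>, <code>j</code> and <code>k</code> give the index in <code>in</code> of the element this thread loads, and also the index in <code>out</code> of the element it computes. Threads that load halo or ghost cells do not compute an output element. The <code>if</code> on line 10 excludes threads whose output would lie on the boundary of the grid, and the <code>if</code> on line 11 excludes threads that only load halo or ghost cells.</p>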
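<p>The OP/B ratio of this kernel is derived as follows. Let T be the edge length of the input tile, so the output tile has edge length T-2. Each block computes (T-2)^3 output elements, for a total of 13*(T-2)^3 floating-point operations, and loads T^3 elements of 4 bytes each. The ratio is therefore $\frac{13}{4}\times (1-\frac{2}{T})^3$ OP/B.</p>
<p>The larger T is, the higher the ratio, with an asymptotic upper bound of 13/4 = 3.25 OP/B. Because a block is limited to 1024 threads, T is in practice at most 8 (8^3 = 512 threads, whereas 16^3 = 4096 exceeds the limit), and this is before taking the shared memory capacity into account. At T = 8 the ratio is only 1.37 OP/B, because halo elements make up a large share of a 3D tile and are reused far less than interior elements.</p>
<p>A small T has another drawback: it prevents memory coalescing from being fully exploited. With an 8×8×8 tile, the 32 threads of a warp span four different rows of the tile, so each warp loads elements from different rows of the input and its accesses cannot be fully coalesced.</p>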
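<p>To see how slowly the OP/B ratio of the tiled kernel approaches its 3.25 limit, the formula can be evaluated for different tile sizes. The helper below is a minimal sketch; the name <code>tiled_stencil_opb</code> is hypothetical, and it assumes 13 floating-point operations per output element and 4-byte loads, as in the derivation above.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-c" data-lang="c">/* OP/B of the shared-memory tiled 7-point stencil for an input tile of
   edge length T: 13 FLOPs for each of the (T-2)^3 outputs, divided by
   the 4 * T^3 bytes loaded by the block. */
double tiled_stencil_opb(int T) {
    double out_edge = (double)(T - 2);
    double in_edge = (double)T;
    double flops = 13.0 * out_edge * out_edge * out_edge;
    double bytes = 4.0 * in_edge * in_edge * in_edge;
    return flops / bytes;
}
</code></pre></div>
<p>For example, <code>tiled_stencil_opb(8)</code> evaluates to roughly 1.37, less than half of the asymptotic bound.</p>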
<h2 id="84-thread-coarsening-线程粗化">8.4 Thread coarsening</h2>
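<p>As noted in the previous section, shared memory tiling gives only a modest speedup for stencil sweep, because little data is reused across threads. This section applies thread coarsening so that each coarsened thread reuses more elements, which addresses that weakness.</p>
<p>Assume the input tile is 6×6×6, as shown on the left of the figure below (with the top, front, and left layers removed), and the output tile is 4×4×4, shown in green on the right.</p>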
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202410172002187.png?x-oss-process=image/quality,q_90/format,webp"></p>
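<p>Each block has as many threads as there are elements in one x-y plane of the input tile, that is 6*6 = 36 threads, of which the inner 4*4 = 16 compute output elements. The kernel is implemented as follows:</p>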
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">stencil_kernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">in</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">N</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">iStart</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">z</span><span class="o">*</span><span class="n">OUT_TILE_DIM</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span><span class="o">*</span><span class="n">OUT_TILE_DIM</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">OUT_TILE_DIM</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">inPrev_s</span><span class="p">[</span><span class="n">IN_TILE_DIM</span><span class="p">][</span><span class="n">IN_TILE_DIM</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">inCurr_s</span><span class="p">[</span><span class="n">IN_TILE_DIM</span><span class="p">][</span><span class="n">IN_TILE_DIM</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">inNext_s</span><span class="p">[</span><span class="n">IN_TILE_DIM</span><span class="p">][</span><span class="n">IN_TILE_DIM</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">iStart</span><span class="o">-</span><span class="mi">1</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">iStart</span><span class="o">-</span><span class="mi">1</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">inPrev_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">in</span><span class="p">[(</span><span class="n">iStart</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">iStart</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">iStart</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">in</span><span class="p">[</span><span class="n">iStart</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">iStart</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">iStart</span> <span class="o">+</span> <span class="n">OUT_TILE_DIM</span><span class="p">;</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="n">inNext_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">in</span><span class="p">[(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">i</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="k">if</span><span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">&lt;</span> <span class="n">IN_TILE_DIM</span> <span class="o">-</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">               <span class="o">&amp;&amp;</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">&lt;</span> <span class="n">IN_TILE_DIM</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="n">out</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">c0</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c1</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c2</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">+</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c3</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c4</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="o">+</span><span class="mi">1</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c5</span><span class="o">*</span><span class="n">inPrev_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c6</span><span class="o">*</span><span class="n">inNext_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="n">inPrev_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">inNext_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
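</div><p>In the code above, the kernel iterates along the z direction and uses three shared-memory arrays to hold the three layers (previous, current, and next) needed to compute the elements of the current z-layer.</p>
<p>Thread coarsening is achieved by merging the threads along the z axis into a single thread. This reduces the number of threads per block from T^3 to T^2, so T can take a larger value such as 32. The ratio then reaches 2.68 OP/B, and the shared-memory requirement drops from a full tile to just three layers of the tile.</p>
<h2 id="85-register-tiling-寄存器分片">8.5 Register tiling</h2>
<p>Looking at the formula for out in the previous section's code, each thread accesses only one element of the shared-memory arrays <code>inPrev_s</code> and <code>inNext_s</code>, so two register variables are enough to hold them. In addition, one more register variable holds <code>inCurr_s[threadIdx.y][threadIdx.x]</code>, which speeds up updating the other two registers.</p>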
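<p>As a rough check on the 2.68 OP/B figure, the following sketch evaluates an arithmetic-intensity estimate for the tiled seven-point stencil. It rests on assumptions not spelled out in this section: 13 floating-point operations per output point (7 multiplications and 6 additions), 4-byte floats, and a block that loads T^3 input elements in total while producing (T-2)^3 outputs. The function names <code>stencil_op_per_byte</code> and <code>print_stencil_intensity</code> are hypothetical.</p>

```c
#include <stdio.h>

/* Estimated arithmetic intensity (OP/B) of a tiled seven-point 3D stencil.
 * Assumptions (not taken from the code in this post):
 *   - 13 floating-point operations per output point (7 mul + 6 add)
 *   - each input element is a 4-byte float loaded once per block
 *   - an input tile of t^3 elements yields (t-2)^3 output points */
double stencil_op_per_byte(int t) {
    double ops_per_output = 13.0;
    double bytes_per_input = 4.0;
    double outputs = (double)(t - 2) * (t - 2) * (t - 2);
    double inputs = (double)t * t * t;
    return ops_per_output * outputs / (bytes_per_input * inputs);
}

/* Print the estimate for a small tile (T = 8) and the coarsened tile (T = 32). */
void print_stencil_intensity(void) {
    int tiles[] = {8, 32};
    for (int n = 0; n < 2; ++n) {
        printf("T = %2d: %.2f OP/B\n", tiles[n], stencil_op_per_byte(tiles[n]));
    }
}
```

<p>Under these assumptions the estimate is about 1.37 OP/B for T = 8 and about 2.68 OP/B for T = 32, which matches the value quoted above.</p>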
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">out</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">c0</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c1</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c2</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">+</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c3</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="o">+</span><span class="mi">1</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c4</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c5</span><span class="o">*</span><span class="n">inPrev_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c6</span><span class="o">*</span><span class="n">inNext_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">];</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The code after applying register tiling is shown below:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">stencil_kernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">in</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">N</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">iStart</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">z</span><span class="o">*</span><span class="n">OUT_TILE_DIM</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span><span class="o">*</span><span class="n">OUT_TILE_DIM</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">OUT_TILE_DIM</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="n">inPrev</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">inCurr_s</span><span class="p">[</span><span class="n">IN_TILE_DIM</span><span class="p">][</span><span class="n">IN_TILE_DIM</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="n">inCurr</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="n">inNext</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">iStart</span><span class="o">-</span><span class="mi">1</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">iStart</span><span class="o">-</span><span class="mi">1</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">inPrev</span> <span class="o">=</span> <span class="n">in</span><span class="p">[(</span><span class="n">iStart</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">iStart</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">iStart</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">inCurr</span> <span class="o">=</span> <span class="n">in</span><span class="p">[</span><span class="n">iStart</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">inCurr</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">iStart</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">iStart</span> <span class="o">+</span> <span class="n">OUT_TILE_DIM</span><span class="p">;</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="n">inNext</span> <span class="o">=</span> <span class="n">in</span><span class="p">[(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">i</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="k">if</span><span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">&lt;</span> <span class="n">IN_TILE_DIM</span> <span class="o">-</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">               <span class="o">&amp;&amp;</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="o">&amp;&amp;</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">&lt;</span> <span class="n">IN_TILE_DIM</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="n">out</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">N</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">j</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">c0</span><span class="o">*</span><span class="n">inCurr</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c1</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c2</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">+</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c3</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="o">+</span><span class="mi">1</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c4</span><span class="o">*</span><span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c5</span><span class="o">*</span><span class="n">inPrev</span>
</span></span><span class="line"><span class="cl">                    <span class="o">+</span> <span class="n">c6</span><span class="o">*</span><span class="n">inNext</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="n">inPrev</span> <span class="o">=</span> <span class="n">inCurr</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="n">inCurr</span> <span class="o">=</span> <span class="n">inNext</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="n">inCurr_s</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">][</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">inNext</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
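</div><p>Register tiling cuts shared-memory usage by two thirds, but it does not reduce the number of global-memory accesses.</p>
<h1 id="chapter-09-parallel-histogram-并行直方图">Chapter 09: Parallel histogram</h1>
<p>This chapter uses histogram computation as an example to introduce a computation pattern in which the output location depends on the data. It covers atomic operations and their trade-offs, then applies privatization, coarsening, and aggregation as optimizations.</p>
<h2 id="91-background-背景">9.1 Background</h2>
<p>The introduction to histograms 📊 is omitted here.</p>
<p>The sequential histogram code is shown below; it is fairly simple:</p>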
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">histogram_sequential</span><span class="p">(</span><span class="kt">char</span> <span class="o">*</span><span class="n">data</span><span class="p">,</span> <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">length</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                          <span class="kt">unsigned</span> <span class="kt">int</span> <span class="o">*</span><span class="n">histo</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span><span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">length</span><span class="p">;</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">alphabet_position</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="sc">&#39;a&#39;</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">alphabet_position</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">alphabet_position</span> <span class="o">&lt;</span> <span class="mi">26</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="n">histo</span><span class="p">[</span><span class="n">alphabet_position</span><span class="o">/</span><span class="mi">4</span><span class="p">]</span><span class="o">++</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
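</div><h2 id="92-atomic-operations-and-a-basic-histogram-kernel-原子操作和一个基本的直方图核函数">9.2 Atomic operations and a basic histogram kernel</h2>
<p>The simplest histogram kernel launches as many threads as there are input elements, with each thread classifying its own element. Multiple threads may then need to update the same output element; this conflict is called output interference. Handling it involves the concepts of race conditions and atomic operations.</p>
<p>A race condition occurs when multiple threads update a result at the same time, so that the outcome depends on the order in which those threads execute. An atomic operation performs a read-modify-write sequence exclusively, so no other thread can interleave with it. The book spends a long passage explaining race conditions and atomic operations; these concepts are covered in OS courses, so they are omitted here.</p>
<p>CUDA provides a set of built-in functions for atomic operations, named in the form <code>atomicXxx</code>.</p>
<p>Modern processors often provide special instructions for specific purposes, such as atomic operations or vectorization. Compilers expose them to the programmer as what look like library functions (intrinsics), but at compile time no actual function call takes place: the call is translated directly into the corresponding instruction.</p>
<p>The histogram kernel using atomic operations is shown below:</p>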
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">histo_kernel</span><span class="p">(</span><span class="kt">char</span> <span class="o">*</span><span class="n">data</span><span class="p">,</span> <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">length</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="o">*</span><span class="n">histo</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="k">if</span> <span class="p">(</span><span class="n">i</span> <span class="o">&lt;</span> <span class="n">length</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">alphabet_position</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="sc">&#39;a&#39;</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">alphabet_position</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">alphabet_position</span> <span class="o">&lt;</span> <span class="mi">26</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nf">atomicAdd</span><span class="p">(</span><span class="o">&amp;</span><span class="p">(</span><span class="n">histo</span><span class="p">[</span><span class="n">alphabet_position</span><span class="o">/</span><span class="mi">4</span><span class="p">]),</span> <span class="mi">1</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
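</div><h2 id="93-latency-and-throughput-of-atomic-operations-原子操作的延迟和吞吐量">9.3 Latency and throughput of atomic operations</h2>
<p>As earlier chapters showed, global-memory access is very slow, but with enough threads, zero-overhead context switching can hide this latency, so that performance is bounded by DRAM bandwidth instead. All of this presupposes that <strong>enough threads access memory in parallel</strong>. Unfortunately, when threads perform atomic operations on the same global-memory location, their read-modify-write sequences are serialized.</p>
<p>For example, consider a DRAM system with 8 channels, a 64-bit data interface, a 1 GHz clock, and an access latency of 200 clock cycles. Its peak throughput is 8 bytes × 2 (two transfers per cycle) × 1 GHz × 8 channels = 128 GB/s. With 4-byte elements, it can read or write 32G elements per second.</p>
<p>By contrast, an atomic operation on a single location consists of one read and one write, so it takes 400 clock cycles, which allows at most 2.5M atomic operations per second.</p>
<p>Of course, not all atomic operations modify the same location. But even if the data were uniformly distributed across the 7 bins, the theoretical upper bound would be 2.5M × 7 = 17.5M atomic operations per second. In practice, letters in words are not uniformly distributed, so the actual throughput falls short of that bound.</p>
<p>One way to increase atomic throughput is to reduce the latency of each individual memory access, which caches can provide. Atomic operations are therefore supported on the last-level cache, which is shared by all streaming multiprocessors. Access latency to the last-level cache is an order of magnitude lower than to DRAM.</p>
<h2 id="94-privatization-私有化">9.4 Privatization</h2>
<p>Privatization is another technique for increasing atomic throughput. Threads work on a private copy of a frequently accessed data structure and merge that copy back into the original once the computation finishes.</p>
<p>For the histogram, we can give each block its own private copy and merge the copies after the computation. The code is as follows:</p>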
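<p>The arithmetic in this example can be reproduced with a small host-side sketch. The function names are hypothetical, and the model is deliberately simple: it assumes that an atomic operation on one location costs exactly one read latency plus one write latency and that nothing overlaps.</p>

```c
#include <stdio.h>

/* Peak DRAM throughput in elements per second:
 * bytes per transfer * transfers per cycle * clock rate * channels,
 * divided by the element size in bytes. */
double dram_elements_per_second(double bytes_per_transfer,
                                double transfers_per_cycle,
                                double clock_hz,
                                double channels,
                                double bytes_per_element) {
    return bytes_per_transfer * transfers_per_cycle * clock_hz * channels
           / bytes_per_element;
}

/* Atomic operations per second on a single location, assuming each one
 * needs a full read latency followed by a full write latency and that
 * successive operations are fully serialized. */
double atomics_per_second(double clock_hz, double latency_cycles) {
    return clock_hz / (2.0 * latency_cycles);
}

/* Print the numbers for the configuration described in the text. */
void print_throughput_example(void) {
    double elems = dram_elements_per_second(8.0, 2.0, 1e9, 8.0, 4.0);
    double atomics = atomics_per_second(1e9, 200.0);
    printf("peak elements/s:            %.1f G\n", elems / 1e9);
    printf("atomics/s (one location):   %.1f M\n", atomics / 1e6);
    printf("atomics/s (7 uniform bins): %.1f M\n", 7.0 * atomics / 1e6);
}
```

<p>With the parameters from the text, the gap between 32G elements per second and 2.5M atomic operations per second is more than four orders of magnitude, which is why contention on atomics dominates the kernel's running time.</p>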
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">histo_private_kernel</span><span class="p">(</span><span class="kt">char</span> <span class="o">*</span><span class="n">data</span><span class="p">,</span> <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">length</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                     <span class="kt">unsigned</span> <span class="kt">int</span> <span class="o">*</span><span class="n">histo</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">i</span> <span class="o">&lt;</span> <span class="n">length</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kt">int</span> <span class="n">alphabet_position</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="sc">&#39;a&#39;</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="n">alphabet_position</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">alphabet_position</span> <span class="o">&lt;</span> <span class="mi">26</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="nf">atomicAdd</span><span class="p">(</span><span class="o">&amp;</span><span class="p">(</span><span class="n">histo</span><span class="p">[</span><span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">NUM_BINS</span> <span class="o">+</span> <span class="n">alphabet_position</span><span class="o">/</span><span class="mi">4</span><span class="p">]),</span> <span class="mi">1</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span><span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">bin</span><span class="o">=</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="n">bin</span><span class="o">&lt;</span><span class="n">NUM_BINS</span><span class="p">;</span> <span class="n">bin</span> <span class="o">+=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">binValue</span> <span class="o">=</span> <span class="n">histo</span><span class="p">[</span><span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">NUM_BINS</span> <span class="o">+</span> <span class="n">bin</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">            <span class="k">if</span><span class="p">(</span><span class="n">binValue</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="nf">atomicAdd</span><span class="p">(</span><span class="o">&amp;</span><span class="p">(</span><span class="n">histo</span><span class="p">[</span><span class="n">bin</span><span class="p">]),</span> <span class="n">binValue</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
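</div><p>Privatizing at block granularity has a practical benefit: when synchronization is needed (before merging, every thread that updates the same private copy must have finished), we can simply call <code>__syncthreads</code>. In addition, if the histogram is small enough, the private copy can be declared in shared memory.</p>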
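<h2 id="95-coarsening-粗化">9.5 Coarsening</h2>
<p>On a CPU, a coarsened thread usually walks through a contiguous chunk of data, which makes good use of the CPU cache.</p>
<p>On a GPU, because of memory coalescing, a single thread should not walk through contiguous data sequentially. Instead, the threads of a warp should together touch one contiguous range on each access. This layout is called interleaved partitioning.</p>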
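<p>As a rough CPU-side illustration of the two layouts (a sketch only; the helper names <code>contiguous_index</code> and <code>interleaved_index</code> are made up for this note), the functions below compute which element a thread touches on its <code>k</code>-th access under each scheme:</p>

```c
#include <stddef.h>

/* Contiguous (CPU-style) partitioning: each thread owns one solid chunk of
 * items_per_thread elements and walks through it in order. */
size_t contiguous_index(size_t tid, size_t k, size_t items_per_thread) {
    return tid * items_per_thread + k;
}

/* Interleaved (GPU-style) partitioning: on access k, thread tid touches the
 * element right next to the one touched by thread tid - 1, so the accesses
 * of neighbouring threads form one contiguous range. */
size_t interleaved_index(size_t tid, size_t k, size_t total_threads) {
    return tid + k * total_threads;
}
```

<p>With 4 threads and 2 elements per thread, the first access (<code>k = 0</code>) touches indices 0, 2, 4, 6 under contiguous partitioning and 0, 1, 2, 3 under interleaved partitioning. Only the latter is a single contiguous range that the hardware can coalesce.</p>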
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">histo_private_kernel</span><span class="p">(</span><span class="kt">char</span><span class="o">*</span> <span class="n">data</span><span class="p">,</span> <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">length</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                     <span class="kt">unsigned</span> <span class="kt">int</span><span class="o">*</span> <span class="n">histo</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// Initialize privatized bins
</span></span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">histo_s</span><span class="p">[</span><span class="n">NUM_BINS</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">bin</span><span class="o">=</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="n">bin</span><span class="o">&lt;</span><span class="n">NUM_BINS</span><span class="p">;</span> <span class="n">bin</span> <span class="o">+=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">histo_s</span><span class="p">[</span><span class="n">bin</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0u</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// Histogram
</span></span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">tid</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">tid</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">length</span><span class="p">;</span> <span class="n">i</span> <span class="o">+=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">gridDim</span><span class="p">.</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kt">int</span> <span class="n">alphabet_position</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="sc">&#39;a&#39;</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">alphabet_position</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">alphabet_position</span> <span class="o">&lt;</span> <span class="mi">26</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="nf">atomicAdd</span><span class="p">(</span><span class="o">&amp;</span><span class="p">(</span><span class="n">histo_s</span><span class="p">[</span><span class="n">alphabet_position</span><span class="o">/</span><span class="mi">4</span><span class="p">]),</span> <span class="mi">1</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// Commit to global memory
</span></span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">bin</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="n">bin</span><span class="o">&lt;</span><span class="n">NUM_BINS</span><span class="p">;</span> <span class="n">bin</span> <span class="o">+=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">binValue</span> <span class="o">=</span> <span class="n">histo_s</span><span class="p">[</span><span class="n">bin</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">binValue</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="nf">atomicAdd</span><span class="p">(</span><span class="o">&amp;</span><span class="p">(</span><span class="n">histo</span><span class="p">[</span><span class="n">bin</span><span class="p">]),</span> <span class="n">binValue</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
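</div><h2 id="96-aggregation-聚合">9.6 Aggregation</h2>
<p>The data may contain long local runs of identical values, in which case many threads end up issuing atomic operations on the same location at the same time. To mitigate this, each thread can aggregate such runs locally: it keeps one variable for the current bin and another for that bin's running count, and only when it encounters a different bin does it add the previous bin's count to the shared histogram. This folds the many updates produced by a run of repeated values into a single transaction and reduces how often the shared bins are accessed.</p>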
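<p>To make the saving concrete, here is a sequential CPU-side sketch of what a single thread does under aggregation. It is not the GPU kernel: the function name <code>histo_aggregated</code> and the returned update counter are assumptions added for illustration, and each counted update stands in for one <code>atomicAdd</code>.</p>

```c
#include <stddef.h>

#define NUM_BINS 7  /* 26 letters, 4 letters per bin */

/* Mimics one thread of the aggregated kernel on the CPU. Runs of equal bins
 * are folded into a single update. Returns the number of updates issued to
 * histo (each one would be an atomicAdd on the GPU). */
size_t histo_aggregated(const char *data, size_t length,
                        unsigned int histo[NUM_BINS]) {
    size_t updates = 0;
    unsigned int accumulator = 0;
    int prev_bin = -1;
    for (size_t i = 0; i < length; ++i) {
        int pos = data[i] - 'a';
        if (pos < 0 || pos >= 26) continue;   /* skip non-lowercase input */
        int bin = pos / 4;
        if (bin == prev_bin) {
            ++accumulator;                    /* same run: just count */
        } else {
            if (accumulator > 0) {            /* run ended: flush it once */
                histo[prev_bin] += accumulator;
                ++updates;
            }
            accumulator = 1;
            prev_bin = bin;
        }
    }
    if (accumulator > 0) {                    /* flush the final run */
        histo[prev_bin] += accumulator;
        ++updates;
    }
    return updates;
}
```

<p>For the input <code>"aaaabbbbzz"</code>, the ten characters fall into only two runs of bins (bin 0, then bin 6), so two updates are issued instead of ten.</p>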
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">histo_private_kernel</span><span class="p">(</span><span class="kt">char</span><span class="o">*</span> <span class="n">data</span><span class="p">,</span> <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">length</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                     <span class="kt">unsigned</span> <span class="kt">int</span><span class="o">*</span> <span class="n">histo</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c1">// Initialize privatized bins
</span></span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">histo_s</span><span class="p">[</span><span class="n">NUM_BINS</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">bin</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="n">bin</span> <span class="o">&lt;</span> <span class="n">NUM_BINS</span><span class="p">;</span> <span class="n">bin</span> <span class="o">+=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">        <span class="n">histo_s</span><span class="p">[</span><span class="n">bin</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0u</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// Histogram
</span></span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">accumulator</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">prevBinIdx</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">tid</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">tid</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">length</span><span class="p">;</span> <span class="n">i</span> <span class="o">+=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">gridDim</span><span class="p">.</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kt">int</span> <span class="n">alphabet_position</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="sc">&#39;a&#39;</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">alphabet_position</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">alphabet_position</span> <span class="o">&lt;</span> <span class="mi">26</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="kt">int</span> <span class="n">bin</span> <span class="o">=</span> <span class="n">alphabet_position</span><span class="o">/</span><span class="mi">4</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">            <span class="k">if</span><span class="p">(</span><span class="n">bin</span> <span class="o">==</span> <span class="n">prevBinIdx</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="o">++</span><span class="n">accumulator</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="k">if</span><span class="p">(</span><span class="n">accumulator</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                    <span class="nf">atomicAdd</span><span class="p">(</span><span class="o">&amp;</span><span class="p">(</span><span class="n">histo_s</span><span class="p">[</span><span class="n">prevBinIdx</span><span class="p">]),</span> <span class="n">accumulator</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">                <span class="p">}</span>
</span></span><span class="line"><span class="cl">                <span class="n">accumulator</span> <span class="o">=</span> <span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">                <span class="n">prevBinIdx</span> <span class="o">=</span> <span class="n">bin</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">accumulator</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nf">atomicAdd</span><span class="p">(</span><span class="o">&amp;</span><span class="p">(</span><span class="n">histo_s</span><span class="p">[</span><span class="n">prevBinIdx</span><span class="p">]),</span> <span class="n">accumulator</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// Commit to global memory
</span></span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">bin</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="n">bin</span><span class="o">&lt;</span><span class="n">NUM_BINS</span><span class="p">;</span> <span class="n">bin</span> <span class="o">+=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">binValue</span> <span class="o">=</span> <span class="n">histo_s</span><span class="p">[</span><span class="n">bin</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">binValue</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="nf">atomicAdd</span><span class="p">(</span><span class="o">&amp;</span><span class="p">(</span><span class="n">histo</span><span class="p">[</span><span class="n">bin</span><span class="p">]),</span> <span class="n">binValue</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
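</div><h1 id="chapter-10-reduction-归约">Chapter 10: Reduction</h1>
<p>This chapter walks through a series of optimizations for reduction kernels: minimizing control divergence, minimizing memory divergence, minimizing global memory accesses, and thread coarsening.</p>
<h2 id="101-background-背景">10.1 Background</h2>
<p>The general introduction to reduction is skipped here. The section does introduce the term identity value (GPT renders it in Chinese as 单位值), which is analogous to an identity element: since every reduction is a binary operation, combining any number with the identity value leaves that number unchanged.</p>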
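<p>A small sequential sketch may help pin down the role of the identity value. The helper names <code>op_sum</code>, <code>op_max</code> and <code>reduce</code> are hypothetical, and using <code>-FLT_MAX</code> as the identity for max is an assumption that holds for finite inputs:</p>

```c
#include <float.h>
#include <stddef.h>

float op_sum(float a, float b) { return a + b; }
float op_max(float a, float b) { return a > b ? a : b; }

/* Sequential reduction that starts from the operator's identity value, so
 * an empty input (or a padding slot) never changes the result. */
float reduce(const float *x, size_t n, float identity,
             float (*op)(float, float)) {
    float acc = identity;
    for (size_t i = 0; i < n; ++i) {
        acc = op(acc, x[i]);
    }
    return acc;
}
```

<p>For example, <code>reduce(x, n, 0.0f, op_sum)</code> computes a sum and <code>reduce(x, n, -FLT_MAX, op_max)</code> computes a maximum; with <code>n = 0</code> each call simply returns the identity it was given.</p>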
<h2 id="102-reduction-tree-归约树">10.2 Reduction tree</h2>
<p>Taking the max operator as an example, the reduction can be described by the following reduction tree:<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202410191415104.png?x-oss-process=image/quality,q_90/format,webp"></p>
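<p>Reducing in the order given by this tree requires the operator to be associative. The optimization used in the next section additionally requires it to be commutative.</p>
<h2 id="103-a-simple-reduction-kernel-一个简单的归约核函数">10.3 A simple reduction kernel</h2>
<p>Because threads need to exchange data during a reduction, we start with a single block. A block holds at most 1024 threads and each thread starts by combining two elements, so at most 2048 elements can be processed.<br>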
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202410191436804.png?x-oss-process=image/quality,q_90/format,webp"></p>
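<p>The sum-reduction kernel is shown above. <code>i</code> is the index of the element assigned to the current thread, and only that thread may write to this element. <code>stride</code> is the distance from element <code>i</code> to the other operand of its reduction step. In round 1, <code>stride</code> is 1, and every element <code>i</code> (that is, every even index) is added to element <code>i+1</code>; in round 2, <code>stride</code> is 2, and only elements whose <code>i</code> is a multiple of 4 are added to element <code>i+2</code>; &hellip;; in round <code>n</code>, <code>stride</code> is 2^(n-1), and only threads whose <code>i</code> is a multiple of 2^n, i.e. whose <code>threadIdx.x</code> is a multiple of <code>stride</code>, perform a reduction step, with <code>i+stride</code> as the other operand. The reduction tree looks like this:</p>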
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202410191537773.png?x-oss-process=image/quality,q_90/format,webp"></p>
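<p>The round-by-round description above can be replayed on the CPU. The following is only a sequential simulation under stated assumptions, not the kernel itself: the function name <code>simple_sum_reduction</code> is made up, <code>block_dim</code> is assumed to be a power of two, and <code>input</code> is assumed to hold <code>2 * block_dim</code> elements.</p>

```c
#include <stddef.h>

/* CPU simulation of the simple reduction tree for one block. The inner loop
 * plays the role of the block's threads; finishing one outer iteration plays
 * the role of __syncthreads(). The sum ends up in input[0]. */
float simple_sum_reduction(float *input, unsigned int block_dim) {
    for (unsigned int stride = 1; stride <= block_dim; stride *= 2) {
        for (unsigned int t = 0; t < block_dim; ++t) {
            unsigned int i = 2 * t;        /* element owned by thread t */
            if (t % stride == 0) {         /* only these threads are active */
                input[i] += input[i + stride];
            }
        }
    }
    return input[0];
}
```

<p>Within a round, the elements written (multiples of <code>2*stride</code>) and the elements read (odd multiples of <code>stride</code>) never overlap, which is why running the threads one after another gives the same result as running them in parallel.</p>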
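<h2 id="104-minimizing-control-divergence-最小化控制流分歧">10.4 Minimizing control divergence</h2>
<p>The implementation in the previous section suffers from severe control divergence: in the later iterations, only threads whose index is a multiple of an ever larger power of two remain active. Control divergence leads to poor hardware utilization. The idea for reducing it is to keep the active threads in each iteration packed as closely together as possible, as shown below:</p>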
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202410191553796.png?x-oss-process=image/quality,q_90/format,webp"><br>
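It is easy to see that the number of active threads halves in each iteration (rounding up). If n threads are active in an iteration, the <code>i</code>-th of them reduces the elements at indices <code>i</code> and <code>n+i</code>. Based on this pattern, we can write a version of the reduction kernel that minimizes control divergence:</p>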
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">ConvergentSumReductionKernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">input</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">output</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">stride</span> <span class="o">=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="n">stride</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">;</span> <span class="n">stride</span> <span class="o">/=</span> <span class="mi">2</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">&lt;</span> <span class="n">stride</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="n">input</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">input</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="n">stride</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="o">*</span><span class="n">output</span> <span class="o">=</span> <span class="n">input</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
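</div><h2 id="105-minimizing-memory-divergence-最小化内存分歧">10.5 Minimizing memory divergence</h2>
<p>The code in section 10.3 has another flaw: its memory accesses diverge, so memory coalescing cannot kick in. The previous section fixed this as a side effect, since in every iteration all active threads access contiguous memory.</p>
<h2 id="106-minimizing-global-memory-accesses-最小化全局内存访问">10.6 Minimizing global memory accesses</h2>
<p>Shared memory lets us avoid repeated global memory accesses, and the code is straightforward:</p>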
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">SharedMemorySumReductionKernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">input</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">output</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">input_s</span><span class="p">[</span><span class="n">BLOCK_DIM</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">t</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">input_s</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">input</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">+</span> <span class="n">input</span><span class="p">[</span><span class="n">t</span> <span class="o">+</span> <span class="n">BLOCK_DIM</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">stride</span> <span class="o">=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="o">/</span><span class="mi">2</span><span class="p">;</span> <span class="n">stride</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">;</span> <span class="n">stride</span> <span class="o">/=</span> <span class="mi">2</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">&lt;</span> <span class="n">stride</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="n">input_s</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">+=</span> <span class="n">input_s</span><span class="p">[</span><span class="n">t</span> <span class="o">+</span> <span class="n">stride</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="o">*</span><span class="n">output</span> <span class="o">=</span> <span class="n">input_s</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
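</div><h2 id="107-hierarchical-reduction-for-arbitrary-input-length-任意输入长度的分层归约">10.7 Hierarchical reduction for arbitrary input length</h2>
<p>Up to this point we assumed that the input can be handled by a single block (at most twice the number of threads in a block), because threads can only be synchronized within a block. When the input does not fit in one block, it has to be split across multiple blocks. Since there is no synchronization mechanism between blocks, each block performs its reduction independently and then folds its partial result into the global result with an atomic operation, as shown below:<br>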
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202410200849301.png?x-oss-process=image/quality,q_90/format,webp"><br>
The corresponding code is:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">SegmentedSumReductionKernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">input</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">output</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">input_s</span><span class="p">[</span><span class="n">BLOCK_DIM</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">segment</span> <span class="o">=</span> <span class="mi">2</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">segment</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">t</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">input_s</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">input</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">input</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="n">BLOCK_DIM</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">stride</span> <span class="o">=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="o">/</span><span class="mi">2</span><span class="p">;</span> <span class="n">stride</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">;</span> <span class="n">stride</span> <span class="o">/=</span> <span class="mi">2</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="n">t</span> <span class="o">&lt;</span> <span class="n">stride</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="n">input_s</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">+=</span> <span class="n">input_s</span><span class="p">[</span><span class="n">t</span> <span class="o">+</span> <span class="n">stride</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">t</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nf">atomicAdd</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">input_s</span><span class="p">[</span><span class="mi">0</span><span class="p">]);</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
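</div><h2 id="108-thread-coarsening-for-reduced-overhead-线程粗化">10.8 Thread coarsening for reduced overhead</h2>
<p>All the kernels so far give each thread the minimum amount of work and launch many blocks. If the hardware does not have enough resources to run them all at once, these threads and blocks are scheduled warp by warp and executed <strong>sequentially</strong>. As noted earlier, within each block more and more threads go idle as the reduction iterations proceed, and in the later iterations control divergence inside each warp becomes increasingly significant. If all blocks really ran in parallel, this overhead would be hard to avoid. But if the blocks end up being executed sequentially anyway, paying it is entirely unnecessary: each thread can first sum several elements serially and only then join the tree reduction.</p>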
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">CoarsenedSumReductionKernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">input</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">output</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">input_s</span><span class="p">[</span><span class="n">BLOCK_DIM</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">segment</span> <span class="o">=</span> <span class="n">COARSE_FACTOR</span><span class="o">*</span><span class="mi">2</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">segment</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">t</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="n">sum</span> <span class="o">=</span> <span class="n">input</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">tile</span> <span class="o">=</span> <span class="mi">1</span><span class="p">;</span> <span class="n">tile</span> <span class="o">&lt;</span> <span class="n">COARSE_FACTOR</span><span class="o">*</span><span class="mi">2</span><span class="p">;</span> <span class="o">++</span><span class="n">tile</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">sum</span> <span class="o">+=</span> <span class="n">input</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="n">tile</span><span class="o">*</span><span class="n">BLOCK_DIM</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="n">input_s</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">sum</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">stride</span> <span class="o">=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="o">/</span><span class="mi">2</span><span class="p">;</span> <span class="n">stride</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">;</span> <span class="n">stride</span> <span class="o">/=</span> <span class="mi">2</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">        <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="n">t</span> <span class="o">&lt;</span> <span class="n">stride</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="n">input_s</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">+=</span> <span class="n">input_s</span><span class="p">[</span><span class="n">t</span> <span class="o">+</span> <span class="n">stride</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">t</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nf">atomicAdd</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">input_s</span><span class="p">[</span><span class="mi">0</span><span class="p">]);</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>As the code above shows, in the initial step each thread adds up <code>2*COARSE_FACTOR</code> elements instead of 2; the rest is essentially unchanged.</p>
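<p>To make the control flow of the two kernels easier to follow, here is a minimal host-side sketch in Python that mimics them step by step. It is not CUDA: the names <code>block_reduce</code> and <code>coarsened_sum</code> are made up for this illustration, the "threads" are plain loops, <code>block_dim</code> is assumed to be a power of two, and the input length is assumed to be a multiple of <code>2*coarse_factor*block_dim</code>. With <code>coarse_factor=1</code> it behaves like the segmented kernel of section 10.7.</p>

```python
def block_reduce(values):
    """Tree-reduce one block's shared-memory array, mirroring the stride loop."""
    shared = list(values)
    stride = len(shared) // 2
    while stride >= 1:
        # Threads t < stride write [0, stride) and read [stride, 2*stride),
        # so reads and writes of one iteration never overlap.
        for t in range(stride):
            shared[t] += shared[t + stride]
        stride //= 2
    return shared[0]


def coarsened_sum(data, block_dim, coarse_factor):
    """Host-side imitation of the coarsened reduction (hypothetical helper)."""
    per_block = 2 * coarse_factor * block_dim  # elements covered by one block
    assert len(data) % per_block == 0, "sketch assumes no partial block"
    output = 0.0  # the real output buffer must also be zeroed before launch
    for block in range(len(data) // per_block):
        segment = block * per_block
        shared = []
        for t in range(block_dim):
            # Serial phase: each thread sums 2*coarse_factor strided elements.
            total = data[segment + t]
            for tile in range(1, 2 * coarse_factor):
                total += data[segment + t + tile * block_dim]
            shared.append(total)
        # One partial result per block; this addition stands in for atomicAdd.
        output += block_reduce(shared)
    return output


if __name__ == "__main__":
    data = [float(v) for v in range(64)]
    print(coarsened_sum(data, block_dim=4, coarse_factor=2))
    print(coarsened_sum(data, block_dim=4, coarse_factor=1))
```

<p>A larger <code>coarse_factor</code> means fewer blocks and fewer tree-reduction steps overall, at the price of a longer serial loop per thread.</p>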
<h1 id="chapter-11-prefix-sum-scan-前缀和">Chapter 11: Prefix sum (scan)</h1>
<h2 id="111-background-背景">11.1 Background</h2>
<p>Mathematically, an inclusive scan takes an associative operator $\oplus$ and a vector $[x_0, x_1, ..., x_{n-1}]$ of length n, and outputs:</p>


<div>$$

[x_0, x_0\oplus x_1,...,(x_0\oplus x_1\oplus...\oplus x_{n-1})]

$$</div>

<p>It is called inclusive because each output element includes the input element at the same position. Its counterpart is the exclusive scan, whose output is:</p>


<div>$$

[i, x_0, x_0\oplus x_1,...,(x_0\oplus x_1\oplus...\oplus x_{n-2})]

$$</div>

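<p>Here $i$ is the identity element of $\oplus$ (0 for addition). The two forms are clearly easy to convert into each other, so this chapter works with the inclusive scan.</p>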
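<p>As a reference point for the parallel versions, the two definitions can be written as a short sequential sketch in Python. The function names are chosen for this illustration only; <code>itertools.accumulate</code> from the standard library does the running reduction, and the exclusive scan is obtained by shifting the inclusive result right by one and placing the identity in front.</p>

```python
from itertools import accumulate
from operator import add


def inclusive_scan(xs, op=add):
    """out[k] = xs[0] (+) xs[1] (+) ... (+) xs[k]."""
    return list(accumulate(xs, op))


def exclusive_scan(xs, identity=0, op=add):
    """out[k] = identity (+) xs[0] (+) ... (+) xs[k-1]."""
    if not xs:
        return []
    # Shift the inclusive result right by one position.
    return [identity] + inclusive_scan(xs, op)[:-1]


if __name__ == "__main__":
    xs = [3, 1, 7, 0, 4, 1, 6, 3]
    print(inclusive_scan(xs))  # [3, 4, 11, 11, 15, 16, 22, 25]
    print(exclusive_scan(xs))  # [0, 3, 4, 11, 11, 15, 16, 22]
```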
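<h2 id="112-parallel-scan-with-the-kogge-stone-algorithm-使用-kogge-stone-算法的并行扫描">11.2 Parallel scan with the Kogge-Stone algorithm</h2>
<p>In a parallel algorithm, if each thread simply computes one output element on its own, it is no faster than the sequential algorithm: the thread that computes the last element still has to perform the whole prefix sum, which amounts to running the sequential algorithm once. If the hardware cannot run all threads concurrently, this parallel algorithm is even slower than the sequential one, with a time complexity of $O(n^2)$.</p>
<p>To speed up the parallel computation, threads must share intermediate results. This is what the Kogge-Stone algorithm does, as illustrated below:<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202410221112153.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>The figure shows an in-place algorithm: $y_i$ and $x_i$ in the figure refer to the same location. The algorithm operates on an array <code>XY</code>. Before the first iteration, <code>XY</code> holds all <code>x[i]</code>; after <code>k</code> iterations, <code>XY[i]</code> holds the sum of <code>x[i]</code> and the elements before it, at most <code>2^k</code> elements in total. For example, after 2 iterations, <code>XY[i] = x[i]+x[i-1]+x[i-2]+x[i-3]</code>.</p>
<p>The whole algorithm proceeds as follows. Initially, <code>XY[i] = x[i]</code>. In the first iteration, <code>XY[0]</code> is already final and needs no work; every element other than element 0 adds its predecessor, i.e. <code>XY[i] = XY[i]+XY[i-1]</code>. In the second iteration, <code>XY[0,1]</code> are already final; every other element adds the element at distance 2, i.e. <code>XY[i]=XY[i]+XY[i-2]</code>. In the <code>k</code>-th iteration, the first <code>2^(k-1)</code> elements are already final, and every other element adds the element at distance <code>2^(k-1)</code>, i.e. <code>XY[i] = XY[i]+XY[i-2^(k-1)]</code>. The corresponding kernel is:</p>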
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">Kogge_Stone_scan_kernel</span><span class="p">(</span><span class="kt">float</span> <span class="o">*</span><span class="n">X</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">Y</span><span class="p">,</span> <span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">XY</span><span class="p">[</span><span class="n">SECTION_SIZE</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">	<span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="o">+</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span><span class="p">(</span><span class="n">i</span><span class="o">&lt;</span><span class="n">N</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="n">XY</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">		<span class="n">XY</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.0f</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl">	<span class="k">for</span><span class="p">(</span><span class="kt">unsigned</span> <span class="kt">int</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">;</span> <span class="n">stride</span> <span class="o">&lt;</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="n">stride</span><span class="o">*=</span><span class="mi">2</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">		<span class="kt">float</span> <span class="n">temp</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="k">if</span><span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">&gt;=</span> <span class="n">stride</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">			<span class="n">temp</span> <span class="o">=</span> <span class="n">XY</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">+</span> <span class="n">XY</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="o">-</span><span class="n">stride</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">		<span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">		<span class="k">if</span><span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">&gt;=</span> <span class="n">stride</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">			<span class="n">XY</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">temp</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span><span class="p">(</span><span class="n">i</span><span class="o">&lt;</span><span class="n">N</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="n">Y</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">XY</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
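</div><p>The update <code>XY[i] = XY[i]+XY[i-stride]</code> has an obvious race condition: one thread may overwrite <code>XY[i]</code> before another thread has read its old value. The code above therefore stores the result in a temporary variable <code>temp</code> and uses a barrier to make sure every thread has finished computing before any thread writes.</p>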
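<p>The read-then-write discipline of the kernel can be imitated on the host with a small Python sketch. This is only an illustration under simplifying assumptions: <code>kogge_stone_scan</code> is a made-up name, a single block covers the whole input, and the list <code>temp</code> plays the role of the per-thread <code>temp</code> registers, so that every read of an iteration happens before any write.</p>

```python
def kogge_stone_scan(xs):
    """Inclusive prefix sum following the Kogge-Stone iteration pattern."""
    xy = list(xs)
    n = len(xy)
    stride = 1
    while stride < n:
        # Phase 1: every active "thread" t >= stride reads the OLD values.
        temp = [xy[t] + xy[t - stride] for t in range(stride, n)]
        # Phase 2 (after the barrier): all results are written back at once.
        xy[stride:] = temp
        stride *= 2
    return xy


if __name__ == "__main__":
    print(kogge_stone_scan([3, 1, 7, 0, 4, 1, 6, 3]))
    # [3, 4, 11, 11, 15, 16, 22, 25]
```

<p>Updating <code>xy</code> element by element inside a single loop would let later elements pick up values already modified in the same iteration, which is exactly the race the second barrier in the kernel prevents.</p>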
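<h2 id="113-speed-and-work-efficiency-consideration-速度和任务效率考量">11.3 Speed and work efficiency consideration</h2>
<p>One performance metric for a parallel algorithm is work efficiency, which measures how close the amount of work done by the parallel algorithm is to the theoretical minimum. For example, a prefix sum needs at least N-1 additions, i.e. O(n).</p>
<p>The work of the Kogge-Stone algorithm can be counted as follows: there are at most $\log_2 N$ iterations, and each iteration performs $N-\text{stride}$ additions, giving a total of:</p>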


<div>$$

\text{work efficiency} = \sum_{\text{stride}} (N-\text{stride}),\ \ \text{for stride}=1,2,4,...,N//2

$$</div>

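<p>The first term does not depend on the summation variable <code>stride</code> and sums to $N\log_2 N$. The second term is a geometric series, $1+2+4+\dots+N/2 = N-1$ when $N$ is a power of two. The total work is the difference of the two, $N\log_2 N - (N-1)$.</p>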
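<p>The count is easy to check numerically. The following Python sketch (the function name is invented for this example) simply adds up $N-\text{stride}$ over the strides 1, 2, 4, ... and compares the total with the closed form for power-of-two sizes.</p>

```python
import math


def kogge_stone_additions(n):
    """Total number of additions performed by Kogge-Stone on n elements."""
    total, stride = 0, 1
    while stride < n:
        total += n - stride  # threads with index >= stride each do one add
        stride *= 2
    return total


if __name__ == "__main__":
    for n in (8, 64, 1024):
        closed_form = n * int(math.log2(n)) - (n - 1)
        # Columns: n, counted additions, closed form, sequential minimum n-1.
        print(n, kogge_stone_additions(n), closed_form, n - 1)
```

<p>For $N=1024$ this gives 9217 additions, compared with 1023 for the sequential algorithm.</p>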
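<p>The good news is that this is better than the naive $O(n^2)$ algorithm; the bad news is that it falls short of the sequential algorithm's $O(n)$. Although it performs more operations than the sequential algorithm, it needs only $\log_2 N$ iterations to finish, whereas the sequential algorithm needs $N$. In practice, because of control divergence within warps, the theoretical figure is not reached, and the actual work is about $N\log_2 N$.</p>
<p>We can compare parallel algorithms by their number of iterations, but an algorithm with fewer iterations is not necessarily faster: depending on the resources the algorithm consumes and the limits of the hardware, a single iteration may not be able to execute fully in parallel.</p>
<p>The drawback of the Kogge-Stone algorithm is that it is inefficient when hardware resources are limited, and since the number of additions is not minimal, the extra additions also cost power. Its advantage is high performance when hardware resources are plentiful.</p>
<h2 id="114-parallel-scan-with-the-brent-kung-algorithm-使用-brent-kung-算法的并行扫描">11.4 Parallel scan with the Brent-Kung algorithm</h2>
<p>This post is no longer being updated.</p>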
]]></content:encoded>
    </item>
    <item>
      <title>Deriving and implementing the gradients of 2D convolution</title>
      <link>https://www.zhouxin.space/notes/2d-convolution-gradient-derivation-and-implementation/</link>
      <pubDate>Wed, 11 Sep 2024 16:04:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/2d-convolution-gradient-derivation-and-implementation/</guid>
      <description>&lt;h1 id=&#34;符号说明&#34;&gt;Notation&lt;/h1&gt;


&lt;div&gt;$$

\begin{align*}
X &amp;amp;: \text{convolution input, shape } [b,h,w,c_{in}]\\
W &amp;amp;: \text{convolution kernel, shape } [a,a,c_{in},c_{out}]\\
s &amp;amp;: \text{stride}\\
f &amp;amp;: \text{convolution output, shape } [b,(h-a)/s&amp;#43;1,(w-a)/s&amp;#43;1,c_{out}]\\
loss &amp;amp;: \text{loss function, } loss = g(f)
\end{align*}

$$&lt;/div&gt;

&lt;p&gt;By convention, all tensor indices start from 0.&lt;/p&gt;</description>
      <content:encoded><![CDATA[<h1 id="符号说明">Notation</h1>


<div>$$

\begin{align*}
X &amp;: \text{convolution input, shape } [b,h,w,c_{in}]\\
W &amp;: \text{convolution kernel, shape } [a,a,c_{in},c_{out}]\\
s &amp;: \text{stride}\\
f &amp;: \text{convolution output, shape } [b,(h-a)/s&#43;1,(w-a)/s&#43;1,c_{out}]\\
loss &amp;: \text{loss function, } loss = g(f)
\end{align*}

$$</div>

<p>By convention, all tensor indices start from 0.</p>
<h1 id="卷积运算">Convolution</h1>
<p>For an element f[i,j,k,l] of the output, the region it is computed from (its receptive field) is:</p>


<div>$$

X[i,js:js&#43;a,ks:ks&#43;a,:]

$$</div>

<p>The convolution can then be written as:</p>


<div>$$

\begin{align*}
f[i,j,k,l] &amp;= \sum_{m=0}^{a-1} \sum_{n=0}^{a-1} \sum_{p=0}^{c_{in}-1}(X[i,m&#43;js,n&#43;ks,p]\cdot w[m,n,p,l])\\
&amp;=\vec{x_{vec}}^T  \vec{w_{vec}}
\end{align*}

$$</div>

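<p>With the im2col technique, the convolution becomes a vector inner product: $\vec{x_{vec}}$ is the flattened receptive field and $\vec{w_{vec}}$ is the flattened kernel for output channel l.</p>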
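<p>The forward formula can be spelled out directly with nested lists. The sketch below is plain Python rather than any framework's API: <code>conv2d</code> is a name chosen for this example, tensors are nested lists in the layout of the notation section (X as [b,h,w,c_in], W as [a,a,c_in,c_out]), and padding is assumed to be 0. Each output element is computed as the inner product of a flattened receptive field and a flattened filter, which is the im2col view.</p>

```python
def conv2d(X, W, s=1):
    """f[i,j,k,l] = sum over m, n, p of X[i, m+j*s, n+k*s, p] * W[m, n, p, l]."""
    b, h, w, c_in = len(X), len(X[0]), len(X[0][0]), len(X[0][0][0])
    a, c_out = len(W), len(W[0][0][0])
    out_h, out_w = (h - a) // s + 1, (w - a) // s + 1
    f = [[[[0.0] * c_out for _ in range(out_w)] for _ in range(out_h)]
         for _ in range(b)]
    for i in range(b):
        for j in range(out_h):
            for k in range(out_w):
                # im2col view: flatten the receptive field into one vector ...
                x_vec = [X[i][m + j * s][n + k * s][p]
                         for m in range(a) for n in range(a) for p in range(c_in)]
                for l in range(c_out):
                    # ... and the l-th filter in the same (m, n, p) order.
                    w_vec = [W[m][n][p][l]
                             for m in range(a) for n in range(a) for p in range(c_in)]
                    f[i][j][k][l] = sum(x * y for x, y in zip(x_vec, w_vec))
    return f


if __name__ == "__main__":
    # One 3x3 single-channel image holding 1..9, and a 2x2 kernel of ones.
    X = [[[[float(3 * r + c + 1)] for c in range(3)] for r in range(3)]]
    W = [[[[1.0]], [[1.0]]], [[[1.0]], [[1.0]]]]
    print(conv2d(X, W))  # window sums: 12, 16, 24, 28
```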
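<h2 id="损失函数对-w-的梯度">Gradient of the loss with respect to W</h2>
<p>In the expression above, w[m,n,p,l] contributes to f[i,j,k,l] through a single term, with coefficient X[i,m+js,n+ks,p]. The index into X must be valid, which gives the following constraints:</p>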


<div>$$

\begin{cases}
0\leq i &lt; b\\
0\leq m&#43;js &lt; h\\
0\leq n&#43;ks &lt; w \\
0\leq p &lt;c_{in}
\end{cases}

$$</div>

<p>Simplifying, the constraints on i, j, k, l are:</p>


<div>$$

\begin{cases}
0\leq i &lt; b\\
j&lt;(h-m)/s\\
k&lt;(w-n)/s
\end{cases}

$$</div>

<p>By the chain rule:</p>


<div>$$

\begin{align*}
\frac{\partial  loss}{\partial w[m,n,p,l]} 
&amp;= \sum_{i=0}^{b-1}\sum_{j=0}^{\lfloor{(h-m)/s-1\rfloor}}\sum_{k=0}^{\lfloor{(w-n)/s-1\rfloor}} \frac{\partial loss}{\partial f[i,j,k,l]}\frac{\partial f[i,j,k,l]}{\partial w[m,n,p,l]}\\
&amp;=\sum_{i=0}^{b-1}\sum_{j=0}^{\lfloor{(h-m)/s-1\rfloor}}\sum_{k=0}^{\lfloor{(w-n)/s-1\rfloor}} \frac{\partial loss}{\partial f[i,j,k,l]} X[i, m&#43;js, n&#43;ks, p]
\end{align*}

$$</div>

<p>Here $\partial{loss} /\partial f$ has already been obtained during backpropagation, and $\partial{loss} /(\partial {f[i,j,k,l]})$ equals $(\partial{loss} /\partial {f})[i,j,k,l]$. We denote $\partial{loss} /\partial f$ by outgrad.</p>
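<p>The weight-gradient formula can be sanity-checked numerically. The Python sketch below is a simplified setting chosen for brevity, not the general NHWC case: a batch of single-channel images X[b][h][w], one a×a kernel, no padding, and helper names (<code>conv</code>, <code>w_grad</code>, <code>loss</code>) invented for this example. The loss is taken to be linear in f so that its derivative with respect to f is exactly <code>outgrad</code>.</p>

```python
def conv(X, W, s):
    """Single-channel forward pass: f[i][j][k] = sum_{m,n} X[i][m+js][n+ks] * W[m][n]."""
    a, h, w = len(W), len(X[0]), len(X[0][0])
    oh, ow = (h - a) // s + 1, (w - a) // s + 1
    return [[[sum(Xi[m + j * s][n + k * s] * W[m][n]
                  for m in range(a) for n in range(a))
              for k in range(ow)] for j in range(oh)] for Xi in X]


def w_grad(X, outgrad, a, s):
    """dloss/dW[m][n] = sum_{i,j,k} outgrad[i][j][k] * X[i][m+js][n+ks]."""
    return [[sum(G[j][k] * Xi[m + j * s][n + k * s]
                 for Xi, G in zip(X, outgrad)
                 for j in range(len(G)) for k in range(len(G[0])))
             for n in range(a)] for m in range(a)]


def loss(X, W, outgrad, s):
    """A loss that is linear in f, so that dloss/df equals outgrad."""
    f = conv(X, W, s)
    return sum(g * v
               for F, G in zip(f, outgrad)
               for f_row, g_row in zip(F, G)
               for v, g in zip(f_row, g_row))


if __name__ == "__main__":
    X = [[[float(25 * i + 5 * r + c) for c in range(5)] for r in range(5)]
         for i in range(2)]
    W = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    G = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
    dW = w_grad(X, G, a=3, s=2)
    # Because the loss is linear in W, a unit perturbation gives the exact slope.
    W2 = [row[:] for row in W]
    W2[1][2] += 1.0
    print(dW[1][2], loss(X, W2, G, 2) - loss(X, W, G, 2))
```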
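<p>Compare this gradient with the convolution expression derived earlier; the two look very much alike. The summation indices j and k do not appear in the result, so the sums over them act as the spatial sums of a convolution. The summation index i does not appear in the result either, so that sum is an inner product along the batch axis. The index l does appear in the result and plays the role of the output channel. Concretely, the expression can be treated as a convolution whose kernel is the derivative of the loss with respect to f (outgrad) and whose input is X, with the batch axis of X moved to the last position, where it is summed over, and the c_in axis moved to the first position.</p>
<p>Comparing the two expressions, the kernel is outgrad, and a single receptive field has holes in it: along both height and width, neighbouring sampled pixels are s-1 positions apart. This is a dilated convolution, as shown below, where red marks the sampled positions.</p>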


<div>$$

\left[ \begin{matrix}
	{\color[RGB]{240, 0, 0} 1}&amp;		2&amp;		{\color[RGB]{240, 0, 0} 3}&amp;		4\\
	5&amp;		6&amp;		7&amp;		8\\
	{\color[RGB]{240, 0, 0} 9}&amp;		10&amp;		{\color[RGB]{240, 0, 0} 11}&amp;		12\\
	13&amp;		14&amp;		15&amp;		16\\
\end{matrix} \right]

$$</div>

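<p>How can this dilated convolution be implemented? We can dilate the kernel outgrad, i.e. insert s-1 zeros after every row and every column. For s=2 this expands a 2×2 kernel into a 4×4 kernel, which is then convolved with stride 1:</p>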


<div>$$

\left[ \begin{matrix}
	w_1&amp;		w_2\\
	w_3&amp;		w_4\\
\end{matrix} \right] \,\,\Longrightarrow \left[ \begin{matrix}
	w_1&amp;		0&amp;		w_2&amp;		0\\
	0&amp;		0&amp;		0&amp;		0\\
	w_3&amp;		0&amp;		w_4&amp;		0\\
	0&amp;		0&amp;		0&amp;		0\\
\end{matrix} \right]

$$</div>
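<p>A minimal Python sketch of this dilation on a 2D nested list follows. The name <code>dilate2d</code> and its <code>gap</code> parameter are made up for this example; as in the matrix above, zeros are also inserted after the last row and column.</p>

```python
def dilate2d(kernel, gap):
    """Insert `gap` zeros after every element along both axes."""
    step = gap + 1
    rows, cols = len(kernel), len(kernel[0])
    out = [[0] * (cols * step) for _ in range(rows * step)]
    for r in range(rows):
        for c in range(cols):
            # Original entries land on a grid with spacing `step`.
            out[r * step][c * step] = kernel[r][c]
    return out


if __name__ == "__main__":
    for row in dilate2d([[1, 2], [3, 4]], gap=1):
        print(row)
```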

<p>At this point the gradient of the loss with respect to the weights can be written down:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">X</span> <span class="c1"># input [b, h, w, c_in]</span>
</span></span><span class="line"><span class="cl"><span class="n">W</span> <span class="c1"># kernel [a, a, c_in, c_out]</span>
</span></span><span class="line"><span class="cl"><span class="n">outgrad</span> <span class="c1"># gradient of the loss w.r.t. the output</span>
</span></span><span class="line"><span class="cl"><span class="n">stride</span> <span class="c1"># convolution stride</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">outgrad_dilated</span> <span class="o">=</span> <span class="n">dilate</span><span class="p">(</span><span class="n">outgrad</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">stride</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># [b, *, *, c_out]</span>
</span></span><span class="line"><span class="cl"><span class="n">outgrad_dilated_permuted</span> <span class="o">=</span> <span class="n">permute</span><span class="p">(</span><span class="n">outgrad_dilated</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="c1"># [*, *, b, c_out]</span>
</span></span><span class="line"><span class="cl"><span class="n">X_permuted</span> <span class="o">=</span> <span class="n">permute</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span> <span class="c1"># [c_in, h, w, b]</span>
</span></span><span class="line"><span class="cl"><span class="n">W_grad_</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">X_permuted</span><span class="p">,</span> <span class="n">outgrad_dilated_permuted</span><span class="p">)</span> <span class="c1"># [c_in, a, a, c_out]</span>
</span></span><span class="line"><span class="cl"><span class="n">W_grad</span> <span class="o">=</span> <span class="n">permute</span><span class="p">(</span><span class="n">W_grad_</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="c1"># [a, a, c_in, c_out]</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>When padding is not 0, we can reason directly from the shapes. In the forward pass we may treat padding as 0 and take the padded tensor as the new input. Under this equivalence, in <code>conv(X_permuted, outgrad_dilated_permuted)</code> the tensor X_permuted comes from the real, unpadded X, whereas outgrad was produced from the padded X. Used as a kernel, outgrad is therefore too large for the unpadded X by 2·padding along each spatial axis, so this convolution has to put the padding back: padding rows and columns on each side, 2·padding in total per axis:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">X</span> <span class="c1"># input [b, h, w, c_in]</span>
</span></span><span class="line"><span class="cl"><span class="n">W</span> <span class="c1"># kernel [a, a, c_in, c_out]</span>
</span></span><span class="line"><span class="cl"><span class="n">outgrad</span> <span class="c1"># gradient of the loss w.r.t. the output</span>
</span></span><span class="line"><span class="cl"><span class="n">stride</span> <span class="c1"># convolution stride</span>
</span></span><span class="line"><span class="cl"><span class="n">padding</span> <span class="c1"># forward padding, applied on each side</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">outgrad_dilated</span> <span class="o">=</span> <span class="n">dilate</span><span class="p">(</span><span class="n">outgrad</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">stride</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># [b, *, *, c_out]</span>
</span></span><span class="line"><span class="cl"><span class="n">outgrad_dilated_permuted</span> <span class="o">=</span> <span class="n">permute</span><span class="p">(</span><span class="n">outgrad_dilated</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="c1"># [*, *, b, c_out]</span>
</span></span><span class="line"><span class="cl"><span class="n">X_permuted</span> <span class="o">=</span> <span class="n">permute</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span> <span class="c1"># [c_in, h, w, b]</span>
</span></span><span class="line"><span class="cl"><span class="n">W_grad_</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">X_permuted</span><span class="p">,</span> <span class="n">outgrad_dilated_permuted</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="n">padding</span><span class="p">)</span> <span class="c1"># padding per side, assuming conv pads symmetrically; [c_in, a, a, c_out]</span>
</span></span><span class="line"><span class="cl"><span class="n">W_grad</span> <span class="o">=</span> <span class="n">permute</span><span class="p">(</span><span class="n">W_grad_</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="c1"># [a, a, c_in, c_out]</span>
</span></span></code></pre></td></tr></table>
</div>
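</div><h2 id="损失函数对-x-的梯度">Gradient of the loss with respect to X</h2>
<p>With the groundwork above, we turn to the gradient of the loss with respect to X. First, one observation: if X[i,j,k,l] is multiplied by w[m,n,l,p], the product is a term of the output f[i,(j-m)/s,(k-n)/s,p], that is:</p>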


<div>$$

f[i,(j-m)/s,(k-n)/s,p] = \cdots &#43; w[m,n,l,p]\cdot X[i,j,k,l] &#43; \cdots

$$</div>

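<p>So for the gradient of the loss with respect to X[i,j,k,l], only the outputs f[i,(j-m)/s,(k-n)/s,p] contribute, each with coefficient w[m,n,l,p]; terms for which (j-m)/s or (k-n)/s is not a valid integer index of f are simply absent.</p>
<p>We can now derive the gradient of the loss with respect to X[i,j,k,l]:</p>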


<div>$$

\begin{align*}
\frac{\partial loss}{\partial X\left[ i,j,k,l \right]}
&amp;=\sum_{m=0}^{a-1}{\sum_{n=0}^{a-1}{\sum_{p=0}^{c_{out}-1}{\frac{\partial loss}{\partial f[i,(j-m)/s,(k-n)/s,p]}\cdot \frac{\partial f[i,(j-m)/s,(k-n)/s,p]}{\partial X\left[ i,j,k,l \right]}}}}
\\
&amp;=\sum_{m=0}^{a-1}{\sum_{n=0}^{a-1}{\sum_{p=0}^{c_{out}-1}{\frac{\partial loss}{\partial f[i,(j-m)/s,(k-n)/s,p]}w\left[ m,n,l,p \right]}}}
\end{align*}

$$</div>

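<p>This looks familiar again, and with the experience above the analysis is much easier. The kernel is W and the convolved tensor is outgrad; along the last axis of outgrad a linear map takes c_out to c_in, and the batch axis is the first axis of outgrad. Along height and width, the kernel is traversed with step -1 inside the receptive field: as m grows, the outgrad index (j-m)/s shrinks, so the first kernel element meets the last element of the window. Flipping the kernel takes care of this. As you have surely noticed, the receptive field is not contiguous: consecutive elements are separated by s-1 positions, so outgrad also has to be dilated with s-1 zeros.</p>
<p>Psyduck frowns: it is not quite that simple. In theory this gradient should have the same shape as X, but outgrad is already smaller than X and gets even smaller after a convolution. How can that be? Look at j=0, k=0: substituting into the formula above, the indices into outgrad become negative. The fix is to pad outgrad with a-1 zeros on every side.</p>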
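<p>The claim that flipping, dilating and padding reproduce the input gradient can be checked on a small case. The Python sketch below works under simplifying assumptions: one single-channel image, one a×a kernel, forward padding 0, and function names invented for this example. <code>x_grad_direct</code> scatters each outgrad element back to the inputs it touched; <code>x_grad_conv</code> builds the dilated outgrad with a-1 zeros in front and enough zeros behind to yield an h×w result (exactly a-1 when h-a is divisible by s), then runs a stride-1 convolution with the flipped kernel.</p>

```python
def x_grad_direct(outgrad, W, h, w, s):
    """Scatter form: f[j][k] sends outgrad[j][k] * W[m][n] back to X[m+js][n+ks]."""
    a = len(W)
    dX = [[0.0] * w for _ in range(h)]
    for j, row in enumerate(outgrad):
        for k, g in enumerate(row):
            for m in range(a):
                for n in range(a):
                    dX[m + j * s][n + k * s] += g * W[m][n]
    return dX


def x_grad_conv(outgrad, W, h, w, s):
    """Same gradient via a stride-1 convolution with the flipped kernel."""
    a = len(W)
    pad = a - 1
    # Dilated outgrad (spacing s) placed at offset `pad`, zeros everywhere else.
    big = [[0.0] * (w + pad) for _ in range(h + pad)]
    for j, row in enumerate(outgrad):
        for k, g in enumerate(row):
            big[pad + j * s][pad + k * s] = g
    flipped = [r[::-1] for r in W[::-1]]  # flip both spatial axes
    return [[sum(big[y + m][x + n] * flipped[m][n]
                 for m in range(a) for n in range(a))
             for x in range(w)] for y in range(h)]


if __name__ == "__main__":
    W = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    G = [[1.0, 2.0], [3.0, 4.0]]  # outgrad for a 5x5 input, a=3, s=2
    print(x_grad_conv(G, W, 5, 5, 2) == x_grad_direct(G, W, 5, 5, 2))
```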
<p>At this point the gradient of the loss with respect to the input can be written down:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">X</span> <span class="c1"># input [b, h, w, c_in]</span>
</span></span><span class="line"><span class="cl"><span class="n">W</span> <span class="c1"># kernel [a, a, c_in, c_out]</span>
</span></span><span class="line"><span class="cl"><span class="n">outgrad</span> <span class="c1"># gradient of the loss w.r.t. the output</span>
</span></span><span class="line"><span class="cl"><span class="n">stride</span> <span class="c1"># convolution stride</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">W_flipped</span> <span class="o">=</span> <span class="n">flip</span><span class="p">(</span><span class="n">W</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="mi">1</span><span class="p">))</span> <span class="c1"># flip along the first two dimensions</span>
</span></span><span class="line"><span class="cl"><span class="n">W_flipped_permuted</span> <span class="o">=</span> <span class="n">permute</span><span class="p">(</span><span class="n">W_flipped</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="mi">1</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">2</span><span class="p">))</span> <span class="c1"># [a, a, c_out, c_in]</span>
</span></span><span class="line"><span class="cl"><span class="n">outgrad_dilated</span> <span class="o">=</span> <span class="n">dilate</span><span class="p">(</span><span class="n">outgrad</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">dilation</span><span class="o">=</span><span class="n">stride</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># insert stride-1 zeros between elements</span>
</span></span><span class="line"><span class="cl"><span class="n">X_grad</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">outgrad_dilated</span><span class="p">,</span> <span class="n">W_flipped_permuted</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="n">a</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># [b, h, w, c_in]</span>
</span></span></code></pre></td></tr></table>
</div>
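</div><p>When the forward padding is non-zero, reason from the shapes again. The result of <code>conv(outgrad_dilated, W_flipped_permuted, padding=a-1)</code> then has the shape of the padded input, which is 2·padding larger than X along each spatial dimension, while W is unaffected. The padding passed to conv is therefore reduced by padding on each side:</p>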
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">X</span> <span class="c1"># input [b, h, w, c_in]</span>
</span></span><span class="line"><span class="cl"><span class="n">W</span> <span class="c1"># kernel [a, a, c_in, c_out]</span>
</span></span><span class="line"><span class="cl"><span class="n">outgrad</span> <span class="c1"># gradient of the loss w.r.t. the output</span>
</span></span><span class="line"><span class="cl"><span class="n">stride</span> <span class="c1"># convolution stride</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">W_flipped</span> <span class="o">=</span> <span class="n">flip</span><span class="p">(</span><span class="n">W</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="mi">1</span><span class="p">))</span> <span class="c1"># flip along the first two dimensions</span>
</span></span><span class="line"><span class="cl"><span class="n">W_flipped_permuted</span> <span class="o">=</span> <span class="n">permute</span><span class="p">(</span><span class="n">W_flipped</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="mi">1</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">2</span><span class="p">))</span> <span class="c1"># [a, a, c_out, c_in]</span>
</span></span><span class="line"><span class="cl"><span class="n">outgrad_dilated</span> <span class="o">=</span> <span class="n">dilate</span><span class="p">(</span><span class="n">outgrad</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">dilation</span><span class="o">=</span><span class="n">stride</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># insert stride-1 zeros between elements</span>
</span></span><span class="line"><span class="cl"><span class="n">X_grad</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">outgrad_dilated</span><span class="p">,</span> <span class="n">W_flipped_permuted</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="n">a</span><span class="o">-</span><span class="mi">1</span><span class="o">-</span><span class="n">padding</span><span class="p">)</span> <span class="c1"># [b, h, w, c_in]</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="参考文档">References</h1>
<p><a href="https://johnwlambert.github.io/conv-backprop/">Backpropagation through a Conv Layer</a></p>
]]></content:encoded>
    </item>
    <item>
      <title>Configuring mirrors and proxies for common software</title>
      <link>https://www.zhouxin.space/notes/dev-tools-source-and-proxy-configuration/</link>
      <pubDate>Sat, 31 Aug 2024 14:23:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/dev-tools-source-and-proxy-configuration/</guid>
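      <description>&lt;p&gt;This post records how to configure mirrors and proxies for commonly used software and development tools on Windows and Linux. Mirrors default to the USTC mirror where available, and the proxy is assumed to be a local one listening on port 7890.&lt;/p&gt;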
&lt;h1 id=&#34;windows&#34;&gt;Windows&lt;/h1&gt;
&lt;h2 id=&#34;winget&#34;&gt;Winget&lt;/h2&gt;
&lt;p&gt;Point winget at the USTC mirror &lt;sup id=&#34;fnref:1&#34;&gt;&lt;a href=&#34;#fn:1&#34; class=&#34;footnote-ref&#34; role=&#34;doc-noteref&#34;&gt;1&lt;/a&gt;&lt;/sup&gt; (administrator privileges required):&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;div class=&#34;chroma&#34;&gt;
&lt;table class=&#34;lntable&#34;&gt;&lt;tr&gt;&lt;td class=&#34;lntd&#34;&gt;
&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code&gt;&lt;span class=&#34;lnt&#34;&gt;1
&lt;/span&gt;&lt;span class=&#34;lnt&#34;&gt;2
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class=&#34;lntd&#34;&gt;
&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-bash&#34; data-lang=&#34;bash&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;winget &lt;span class=&#34;nb&#34;&gt;source&lt;/span&gt; remove winget
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;winget &lt;span class=&#34;nb&#34;&gt;source&lt;/span&gt; add winget https://mirrors.ustc.edu.cn/winget-source --trust-level trusted
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h2 id=&#34;git&#34;&gt;git&lt;/h2&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;div class=&#34;chroma&#34;&gt;
&lt;table class=&#34;lntable&#34;&gt;&lt;tr&gt;&lt;td class=&#34;lntd&#34;&gt;
&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code&gt;&lt;span class=&#34;lnt&#34;&gt;1
&lt;/span&gt;&lt;span class=&#34;lnt&#34;&gt;2
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class=&#34;lntd&#34;&gt;
&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-bash&#34; data-lang=&#34;bash&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;git config --global http.proxy http://127.0.0.1:7890
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;git config --global https.proxy http://127.0.0.1:7890
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h2 id=&#34;wsl&#34;&gt;WSL&lt;/h2&gt;
&lt;p&gt;The default path of the WSL configuration file is &lt;code&gt;%UserProfile%\.wslconfig&lt;/code&gt;. Change its contents to the following and WSL will use the Windows proxy.&lt;/p&gt;</description>
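      <content:encoded><![CDATA[<p>This post records how to configure mirrors and proxies for commonly used software and development tools on Windows and Linux. Mirrors default to the USTC mirror where available, and the proxy is assumed to be a local one listening on port 7890.</p>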
<h1 id="windows">Windows</h1>
<h2 id="winget">Winget</h2>
<p>Point winget at the USTC mirror <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup> (administrator privileges required):</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-bash" data-lang="bash"><span class="line"><span class="cl">winget <span class="nb">source</span> remove winget
</span></span><span class="line"><span class="cl">winget <span class="nb">source</span> add winget https://mirrors.ustc.edu.cn/winget-source --trust-level trusted
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="git">git</h2>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-bash" data-lang="bash"><span class="line"><span class="cl">git config --global http.proxy http://127.0.0.1:7890
</span></span><span class="line"><span class="cl">git config --global https.proxy http://127.0.0.1:7890
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="wsl">WSL</h2>
<p>The default path of the WSL configuration file is <code>%UserProfile%\.wslconfig</code>. Change its contents to the following and WSL will use the Windows proxy.</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-ini" data-lang="ini"><span class="line"><span class="cl"><span class="k">[experimental]</span>
</span></span><span class="line"><span class="cl"><span class="na">networkingMode</span><span class="o">=</span><span class="s">mirrored</span>
</span></span><span class="line"><span class="cl"><span class="na">autoProxy</span><span class="o">=</span><span class="s">true</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="anaconda-miniconda">Anaconda/Miniconda</h2>
<h1 id="linux">Linux</h1>
<h2 id="git-1">git</h2>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-bash" data-lang="bash"><span class="line"><span class="cl">git config --global http.proxy http://127.0.0.1:7890
</span></span><span class="line"><span class="cl">git config --global https.proxy http://127.0.0.1:7890
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="apt">apt</h2>
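<p><strong>Changing the mirror:</strong></p>
<p>Starting with <code>Ubuntu 24.04</code>, the system APT sources on a default install are no longer configured in the traditional <code>/etc/apt/sources.list</code>. They use the new DEB822 format and live in <code>/etc/apt/sources.list.d/ubuntu.sources</code>. Edit that file (sudo required) so that it contains:</p>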
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-fallback" data-lang="fallback"><span class="line"><span class="cl">Types: deb
</span></span><span class="line"><span class="cl">URIs: https://mirrors.ustc.edu.cn/ubuntu
</span></span><span class="line"><span class="cl">Suites: noble noble-updates noble-backports
</span></span><span class="line"><span class="cl">Components: main restricted universe multiverse
</span></span><span class="line"><span class="cl">Signed-By: /usr/share/keyrings/ubuntu-archive-keyring.gpg
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">Types: deb
</span></span><span class="line"><span class="cl">URIs: https://mirrors.ustc.edu.cn/ubuntu
</span></span><span class="line"><span class="cl">Suites: noble-security
</span></span><span class="line"><span class="cl">Components: main restricted universe multiverse
</span></span><span class="line"><span class="cl">Signed-By: /usr/share/keyrings/ubuntu-archive-keyring.gpg
</span></span></code></pre></td></tr></table>
</div>
</div><p>Then run <code>sudo apt update</code> to refresh the package index.</p>
<p><strong>Setting the proxy:</strong><br>
Add the following to <code>/etc/apt/apt.conf</code>:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-fallback" data-lang="fallback"><span class="line"><span class="cl">Acquire::http::Proxy &#34;http://127.0.0.1:7890&#34;;
</span></span><span class="line"><span class="cl">Acquire::https::Proxy &#34;http://127.0.0.1:7890&#34;;
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="参考文档">References</h1>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p><a href="https://mirrors.ustc.edu.cn/help/winget-source.html">WinGet - USTC Mirror Help</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>Joint debugging of CUDA and Python code in VSCode</title>
      <link>https://www.zhouxin.space/notes/joint-debgugging-of-cuda-and-python-in-vscode/</link>
      <pubDate>Sat, 24 Aug 2024 19:29:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/joint-debgugging-of-cuda-and-python-in-vscode/</guid>
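      <description>&lt;p&gt;While implementing the matrix multiplication operator at the end of cmu10414 hw3, I could not get it working by eyeballing the code and using printf, so I looked into how to debug CUDA and Python code together in VSCode. These are my notes.&lt;/p&gt;
&lt;h1 id=&#34;项目准备&#34;&gt;Project setup&lt;/h1&gt;
&lt;p&gt;In the original project the CUDA code is compiled into a .so shared library that Python calls, with cmake as the build system. Here we build a minimal example to debug.&lt;/p&gt;</description>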
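      <content:encoded><![CDATA[<p>While implementing the matrix multiplication operator at the end of cmu10414 hw3, I could not get it working by eyeballing the code and using printf, so I looked into how to debug CUDA and Python code together in VSCode. These are my notes.</p>
<h1 id="项目准备">Project setup</h1>
<p>In the original project the CUDA code is compiled into a .so shared library that Python calls, with cmake as the build system. Here we build a minimal example to debug.</p>
<p>The directory tree of the whole project is:</p>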
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-text" data-lang="text"><span class="line"><span class="cl">.
</span></span><span class="line"><span class="cl">├── CMakeLists.txt
</span></span><span class="line"><span class="cl">├── python
</span></span><span class="line"><span class="cl">│   └── test_cuda_hello.py
</span></span><span class="line"><span class="cl">└── src
</span></span><span class="line"><span class="cl">    ├── cuda_hello.cu
</span></span><span class="line"><span class="cl">    └── pybind_wrapper.cpp
</span></span></code></pre></td></tr></table>
</div>
</div><p>Here <code>cuda_hello.cu</code> is the CUDA code to be debugged. It defines one kernel and one host-side entry point that launches it:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&lt;stdio.h&gt;</span><span class="cp">
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">cuda_hello_kernel</span><span class="p">()</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">printf</span><span class="p">(</span><span class="s">&#34;Hello from CUDA kernel!</span><span class="se">\n</span><span class="s">&#34;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">extern</span> <span class="s">&#34;C&#34;</span> <span class="kt">void</span> <span class="n">launch_cuda_hello</span><span class="p">()</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">cuda_hello_kernel</span><span class="o">&lt;&lt;&lt;</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="o">&gt;&gt;&gt;</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="n">cudaDeviceSynchronize</span><span class="p">();</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p><code>pybind_wrapper.cpp</code> uses pybind11 to expose this function to Python:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&lt;pybind11/pybind11.h&gt;</span><span class="cp">
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">extern</span> <span class="s">&#34;C&#34;</span> <span class="kt">void</span> <span class="n">launch_cuda_hello</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">PYBIND11_MODULE</span><span class="p">(</span><span class="n">cuda_hello</span><span class="p">,</span> <span class="n">m</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">m</span><span class="p">.</span><span class="n">def</span><span class="p">(</span><span class="s">&#34;hello&#34;</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">launch_cuda_hello</span><span class="p">,</span> <span class="s">&#34;A function that launches a CUDA kernel to print Hello&#34;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
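</div><p>In <code>test_cuda_hello.py</code>, we import the <code>cuda_hello</code> module from the shared library and call its <code>hello</code> function, which is the name under which <code>launch_cuda_hello</code> was registered:</p>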
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">sys</span>
</span></span><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">os</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># Add the build directory to the Python search path</span>
</span></span><span class="line"><span class="cl"><span class="n">sys</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="vm">__file__</span><span class="p">),</span> <span class="s1">&#39;../build&#39;</span><span class="p">)))</span>
</span></span><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">cuda_hello</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">cuda_hello</span><span class="o">.</span><span class="n">hello</span><span class="p">()</span>
</span></span></code></pre></td></tr></table>
</div>
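</div><p>Note that the compiled shared library lives in the <code>build</code> directory, so that directory must be added to Python's search path before the import.</p>
<p>The contents of <code>CMakeLists.txt</code> are as follows; the comments explain each part:</p>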
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cmake" data-lang="cmake"><span class="line"><span class="cl"><span class="c"># Minimum required CMake version
</span></span></span><span class="line"><span class="cl"><span class="nb">cmake_minimum_required</span><span class="p">(</span><span class="s">VERSION</span> <span class="s">3.18</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># Build in Debug mode
</span></span></span><span class="line"><span class="cl"><span class="nb">set</span><span class="p">(</span><span class="s">CMAKE_BUILD_TYPE</span> <span class="s">Debug</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># Use g++ as the CUDA host compiler
</span></span></span><span class="line"><span class="cl"><span class="nb">set</span><span class="p">(</span><span class="s">CMAKE_CUDA_HOST_COMPILER</span> <span class="s">/usr/bin/g++</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># Project name and enabled languages
</span></span></span><span class="line"><span class="cl"><span class="nb">project</span><span class="p">(</span><span class="s">CudaHello</span> <span class="s">CUDA</span> <span class="s">CXX</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># Use the C++14 standard for C++
</span></span></span><span class="line"><span class="cl"><span class="nb">set</span><span class="p">(</span><span class="s">CMAKE_CXX_STANDARD</span> <span class="s">14</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># Use the C++14 standard for CUDA
</span></span></span><span class="line"><span class="cl"><span class="nb">set</span><span class="p">(</span><span class="s">CMAKE_CUDA_STANDARD</span> <span class="s">14</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># Enable CUDA language support
</span></span></span><span class="line"><span class="cl"><span class="nb">enable_language</span><span class="p">(</span><span class="s">CUDA</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># CUDA architecture (adjust this value to match your GPU)
</span></span></span><span class="line"><span class="cl"><span class="nb">set</span><span class="p">(</span><span class="s">CUDA_ARCHITECTURES</span> <span class="s">89</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># Find the Python interpreter and development components
</span></span></span><span class="line"><span class="cl"><span class="nb">find_package</span><span class="p">(</span><span class="s">Python</span> <span class="s">COMPONENTS</span> <span class="s">Interpreter</span> <span class="s">Development</span> <span class="s">REQUIRED</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># Find the pybind11 package
</span></span></span><span class="line"><span class="cl"><span class="nb">find_package</span><span class="p">(</span><span class="s">pybind11</span> <span class="s">CONFIG</span> <span class="s">REQUIRED</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># Build the CUDA source into a shared library
</span></span></span><span class="line"><span class="cl"><span class="nb">add_library</span><span class="p">(</span><span class="s">cuda_functions</span> <span class="s">SHARED</span> <span class="s">src/cuda_hello.cu</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># Set the CUDA architecture on the target
</span></span></span><span class="line"><span class="cl"><span class="nb">set_target_properties</span><span class="p">(</span><span class="s">cuda_functions</span> <span class="s">PROPERTIES</span> <span class="s">CUDA_ARCHITECTURES</span> <span class="o">${</span><span class="nv">CUDA_ARCHITECTURES</span><span class="o">}</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># In Debug mode, pass debugging flags to the CUDA compiler
</span></span></span><span class="line"><span class="cl"><span class="nb">if</span><span class="p">(</span><span class="s">CMAKE_BUILD_TYPE</span> <span class="s">STREQUAL</span> <span class="s2">&#34;Debug&#34;</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl">    <span class="nb">target_compile_options</span><span class="p">(</span><span class="s">cuda_functions</span> <span class="s">PRIVATE</span> <span class="o">$&lt;</span><span class="nv">$&lt;COMPILE_LANGUAGE:CUDA</span><span class="o">&gt;</span><span class="s">:-G</span> <span class="s">-g&gt;</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="nb">endif</span><span class="p">()</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># Create the pybind11 module
</span></span></span><span class="line"><span class="cl"><span class="nb">pybind11_add_module</span><span class="p">(</span><span class="s">cuda_hello</span> <span class="s">src/pybind_wrapper.cpp</span><span class="p">)</span><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="err">
</span></span></span><span class="line"><span class="cl"><span class="c"># Link the CUDA library into the pybind11 module
</span></span></span><span class="line"><span class="cl"><span class="nb">target_link_libraries</span><span class="p">(</span><span class="s">cuda_hello</span> <span class="s">PRIVATE</span> <span class="s">cuda_functions</span><span class="p">)</span><span class="err">
</span></span></span></code></pre></td></tr></table>
</div>
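</div><p>A few points to note. The argument of <code>set(CUDA_ARCHITECTURES 89)</code> should be the compute capability (CC) of your own GPU model; the CC of each GPU is listed on the NVIDIA site: <a href="https://developer.nvidia.com/cuda-gpus">CUDA GPUs - Compute Capability | NVIDIA Developer</a>. <code>target_compile_options(cuda_functions PRIVATE $&lt;$&lt;COMPILE_LANGUAGE:CUDA&gt;:-G -g&gt;)</code> passes the flags <code>-g -G</code> to nvcc so that both the host code and the device code it compiles carry debug information.</p>
<p>With these files in place, run the following commands to build the shared library:</p>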
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-bash" data-lang="bash"><span class="line"><span class="cl">mkdir build
</span></span><span class="line"><span class="cl"><span class="nb">cd</span> build
</span></span><span class="line"><span class="cl">cmake ..
</span></span><span class="line"><span class="cl">make
</span></span></code></pre></td></tr></table>
</div>
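</div><p>When the build finishes, the <code>build</code> folder should contain a shared library with a name like <code>cuda_hello.cpython-3x-x86_64-linux-gnu.so</code> (the suffix is <code>.pyd</code> on Windows), which indicates that the build succeeded.</p>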
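<p>One quick way to confirm that the build produced the extension module before importing it is to look for the file from Python. The sketch below is illustrative only: <code>find_built_module</code> is a hypothetical helper, and the <code>build</code> directory and the <code>cuda_hello</code> module name are the ones assumed by this example project.</p>

```python
from pathlib import Path

def find_built_module(build_dir, name="cuda_hello"):
    """Return compiled extension files named like name*.so or name*.pyd."""
    build = Path(build_dir)
    if not build.is_dir():
        return []
    matches = []
    for suffix in (".so", ".pyd"):
        # e.g. cuda_hello.cpython-311-x86_64-linux-gnu.so
        matches.extend(build.glob(f"{name}*{suffix}"))
    return sorted(matches)

if __name__ == "__main__":
    found = find_built_module("build")
    if found:
        print("built module:", found[0].name)
    else:
        print("no cuda_hello module found; re-run cmake and make")
```

<p>The helper only matches files whose names start with the module name, so the separate CUDA library built by <code>add_library</code> (which CMake prefixes with <code>lib</code> on Linux) is not reported.</p>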
<p>Then run <code>test_cuda_hello.py</code>; you should see the output from the GPU, <code>Hello from CUDA kernel!</code>.</p>
<p>Everything is ready, so let's start debugging!</p>
<h1 id="手动调试">Manual debugging</h1>
<p>NVIDIA provides the cuda-gdb tool for debugging CUDA code. The procedure is:</p>
<ol>
<li>In a terminal, run <code>cuda-gdb python --quiet</code> to start the cuda-gdb debugger with the Python interpreter as the program being debugged;</li>
<li>At the cuda-gdb prompt, set breakpoints, for example <code>break cuda_hello_kernel</code> to break on the <code>cuda_hello_kernel</code> function, or <code>break src/cuda_hello.cu:4</code> to break at line 4 of <code>cuda_hello.cu</code>;</li>
<li>At the prompt, enter <code>run python/test_cuda_hello.py</code> to start the Python interpreter with the py file passed as its argument. After a short wait the program stops at the breakpoint and reports: <code>CUDA thread hit Breakpoint 1, cuda_hello_kernel&lt;&lt;&lt;(1,1,1),(1,1,1)&gt;&gt;&gt; ()</code></li>
</ol>
<p>From here on, debugging works as in a regular gdb session.</p>
<h1 id="配置-vscode-进行调试">Debugging with VSCode</h1>
<p>The steps above debug with cuda-gdb directly, but I am not very familiar with gdb and only know how to use GUI-based debuggers. Next we configure VSCode so that it can debug CUDA and Python code together.</p>
<p>First install the <a href="https://marketplace.visualstudio.com/items?itemName=NVIDIA.nsight-vscode-edition">Nsight Visual Studio Code Edition</a> extension, which NVIDIA develops to support debugging CUDA code in VSCode <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup>.</p>
<p>Edit <code>.vscode/launch.json</code> with the following content, and set the Python interpreter path to the correct value:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;version&#34;</span><span class="p">:</span> <span class="s2">&#34;0.2.0&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;configurations&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">        <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;Python: Launch&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;python&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;request&#34;</span><span class="p">:</span> <span class="s2">&#34;launch&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;program&#34;</span><span class="p">:</span> <span class="s2">&#34;${file}&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;console&#34;</span><span class="p">:</span> <span class="s2">&#34;integratedTerminal&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="p">},</span>
</span></span><span class="line"><span class="cl">        <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;CUDA GDB Server: Launch&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;cuda-gdb&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;request&#34;</span><span class="p">:</span> <span class="s2">&#34;launch&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;program&#34;</span><span class="p">:</span> <span class="s2">&#34;path/to/python&#34;</span><span class="p">,</span> <span class="c1">// change to your Python interpreter path
</span></span></span><span class="line"><span class="cl">            <span class="nt">&#34;args&#34;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&#34;${file}&#34;</span><span class="p">],</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;debuggerPath&#34;</span><span class="p">:</span> <span class="s2">&#34;/usr/local/cuda/bin/cuda-gdb&#34;</span><span class="p">,</span> <span class="c1">// 确认cuda-gdb路径正确
</span></span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">],</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;compounds&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;Python and CUDA&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;configurations&#34;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&#34;Python: Launch&#34;</span><span class="p">,</span> <span class="s2">&#34;CUDA GDB Server: Launch&#34;</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
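</div><p>The file has three parts: the first configures the Python debugger, the second configures the cuda-gdb debugger, and the third uses compounds to combine the two configurations into one, so that both debuggers start together when debugging.</p>
<p>Next, switch to the Run and Debug panel in VSCode and select the <code>Python and CUDA</code> configuration, as shown below:</p>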
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202408242150909.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>Then set breakpoints in the py and CUDA files and press <code>F5</code> <strong>in the py file</strong> to start debugging. Execution stops at the Python breakpoint:<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202408242155725.png?x-oss-process=image/quality,q_90/format,webp"><br>
Continue execution and it stops at the CUDA breakpoint:<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202408242158812.png?x-oss-process=image/quality,q_90/format,webp"></p>
<h1 id="参考文档">References</h1>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p><a href="https://docs.nvidia.com/nsight-visual-studio-code-edition/">NVIDIA Nsight VSCE Documentation</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>Programming Massively Parallel Processors A Hands-on Approach 4th Edition Study Notes Part 1</title>
      <link>https://www.zhouxin.space/notes/note-on-programming-massively-parallel-processors-a-hands-on-approach-4th-edition-part-1/</link>
      <pubDate>Mon, 12 Aug 2024 22:46:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/note-on-programming-massively-parallel-processors-a-hands-on-approach-4th-edition-part-1/</guid>
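      <description>&lt;p&gt;These are study notes on Part 1 of &lt;em&gt;Programming Massively Parallel Processors A Hands-on Approach 4th Edition&lt;/em&gt; (Chinese title: 大规模并行处理器编程实战), covering the first six chapters of the book.&lt;/p&gt;
&lt;p&gt;Part 1 mainly covers the CUDA architecture, an introduction to CUDA C programming, and an overview of CUDA optimization techniques.&lt;/p&gt;</description>
      <content:encoded><![CDATA[<p>These are study notes on Part 1 of <em>Programming Massively Parallel Processors A Hands-on Approach 4th Edition</em> (Chinese title: 大规模并行处理器编程实战), covering the first six chapters of the book.</p>
<p>Part 1 mainly covers the CUDA architecture, an introduction to CUDA C programming, and an overview of CUDA optimization techniques.</p>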
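<h1 id="chapter-1-introduction-简介">Chapter 1: Introduction</h1>
<p>The computing power that applications demand and the computing power that CPUs can supply have always driven each other forward. In the 1980s and 1990s, steadily raising the single-core clock frequency and the amount of work done per clock cycle pushed computing power up to the TFLOPS level. In the 21st century, however, power and heat-dissipation limits made it hard to gain further performance by raising frequency, and multicore CPUs emerged in response. A multicore CPU executes several instruction sequences at once, so an application has to split its work into parts that can run on several cores simultaneously. A program that is not written with multiple cores in mind gains little from the extra cores.</p>
<p>Programs that do benefit from the performance of multiple cores are called parallel programs.</p>
<h2 id="11-heterogeneous-parallel-computing-异构并行计算">1.1 Heterogeneous parallel computing</h2>
<p>In 2003, processor evolution reached a fork in the road.</p>
<p>One path focuses on multicore: each core is a complete single-core CPU, which gives the modern multicore CPU. Recent Intel processors, for example, often have a dozen or more cores, each supporting hyperthreading and implementing the full x86 instruction set.</p>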
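<p>The other path focuses on many-thread execution: running a very large number of threads at once, usually with very high floating-point throughput. This is the modern GPU. For example, NVIDIA's A100 GPU delivers 9.7 TFLOPS in double precision, whereas a contemporary 24-core Intel processor reaches only 0.33 TFLOPS in double precision (0.66 TFLOPS in single precision).</p>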
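<p>As the figure below shows, this gap comes from a difference in design philosophy. To run sequential instruction streams quickly, CPUs are designed to minimize the latency of arithmetic operations, provide large last-level caches for fast access to large amounts of data, and use sophisticated branch prediction and execution control logic to reduce the latency of branch instructions. These techniques consume a great deal of chip area and power; this approach is called latency-oriented design. GPUs follow the opposite philosophy, throughput-oriented design. Their rapid development was initially driven by video games, where rendering every frame requires a huge number of floating-point calculations, so GPUs maximize the resources devoted to floating-point units.<br>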
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202408130000613.png?x-oss-process=image/quality,q_90/format,webp"></p>
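<p>For a GPU, performing many floating-point calculations at once matters, but so does performing many memory accesses at once: a GPU must be able to move large amounts of data in and out of memory quickly. GPUs can generally accept a relaxed memory model <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup>.</p>
<blockquote>
<p>Chip designers, relentless in squeezing out more performance, relaxed the memory consistency model even further than PSO: besides store-load and store-store reordering, load-load and load-store reordering are also allowed. As long as the instructions access unrelated addresses, all loads and stores may be reordered. This is the relaxed memory order (RMO) model.</p>
</blockquote>
<p>A CPU, as a general-purpose processor, has to satisfy the requirements of all kinds of applications, legacy operating systems, and I/O devices, so it cannot be as aggressive with memory. GPU memory bandwidth is typically about 10 times that of a CPU.</p>
<p>Generally speaking, reducing latency is harder than increasing throughput: doubling the number of compute units is enough to double throughput. To raise throughput, GPUs accept longer latency in their arithmetic units and memory.</p>
<p>A GPU application needs a large number of parallel threads: while some threads wait for data from memory, the hardware can find work to do among the other threads. This is what throughput-oriented design means.</p>
<p>GPUs have very high execution throughput but are not good at the tasks CPUs excel at, so the CUDA model that NVIDIA introduced in 2007 supports joint CPU-GPU execution.</p>
<p>Before CUDA, the interfaces to the GPU were OpenGL and Direct3D, both APIs for drawing pixels. Even when the GPU was used for computation, the work still went through these pixel-oriented interfaces underneath. This technique was called GPGPU, general-purpose GPU.</p>
<p>With CUDA, GPU computing no longer goes through the graphics interface but through a dedicated general-purpose computing interface.</p>
<h2 id="12-why-more-speed-or-parallelism-为什么要并行化">1.2 Why more speed or parallelism?</h2>
<p>Ordinary applications already run fast enough, so why parallelize? In fact, speed is still the bottleneck in many tasks. Thanks to the rapid development of GPUs, scientific computing, video, video games, deep learning, and other fields have flourished.</p>
<p>All of these applications share one trait: they have large amounts of data to process. In such cases the data-processing work can be executed in parallel to markedly improve efficiency.</p>
<h2 id="13-speeding-up-real-applications-加速实际应用">1.3 Speeding up real applications</h2>
<p>How do we measure the speedup from parallelization? Simply compare the run time before and after: if the run time drops from 200 seconds to 10 seconds, the speedup is 20×.</p>
<p>The speedup of an application depends on the fraction of the program that can be parallelized. For example, if 30% of a program can be accelerated by 100×, its execution time drops by at most 29.7%, for an overall speedup of 1.42×. The speedup of a whole system is severely limited by the fraction that can be accelerated; this is known as Amdahl's law.</p>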
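<p>The 30% / 100× example can be checked with a few lines of C. This is only a sketch of Amdahl's law; the function name <code>amdahl_speedup</code> and its parameters are illustrative and do not come from the book.</p>

```c
/* Overall speedup predicted by Amdahl's law.
 * parallel_fraction: share of the original run time that can be accelerated (0..1)
 * parallel_speedup:  speedup achieved on that share (must be > 0) */
double amdahl_speedup(double parallel_fraction, double parallel_speedup) {
    double serial_part = 1.0 - parallel_fraction;                 /* time that stays unchanged */
    double parallel_part = parallel_fraction / parallel_speedup;  /* time left after acceleration */
    return 1.0 / (serial_part + parallel_part);
}
```

<p>Calling <code>amdahl_speedup(0.30, 100.0)</code> evaluates 1 / (0.7 + 0.003), which is roughly 1.42. Even with an unlimited speedup on the parallel part, the result can never exceed 1 / 0.7.</p>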
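<p>Another factor that limits speedup is memory bandwidth, so an important part of parallel programming is to minimize accesses to host memory and to access GPU device memory instead.</p>
<h2 id="14-challenges-in-parallel-programming-并行编程中的挑战">1.4 Challenges in parallel programming</h2>
<p>Writing parallel programs can be hard; some parallel programs end up doing so much extra work that they run slower than the original version. The main difficulties are the following.</p>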
<ul>
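<li>Designing a parallel algorithm requires a way of thinking that is completely different from the familiar sequential one.</li>
<li>Parallel algorithms easily run into memory bandwidth bottlenecks.</li>
<li>Parallel algorithms are more sensitive to the characteristics of the input data.</li>
<li>Threads in a parallel algorithm may need to cooperate, and synchronization between them adds overhead.</li>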
</ul>
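<h2 id="15-related-parallel-programming-interfaces-相关并行编程接口">1.5 Related parallel programming interfaces</h2>
<p>Many parallel programming languages and models have been proposed over the past few decades. For shared-memory multiprocessor systems the most widely used is OpenMP; for scalable cluster computing it is the Message Passing Interface (MPI).</p>
<p>OpenMP consists of a compiler and a runtime. The programmer annotates the code with directives and pragmas, from which the compiler generates parallel code; the runtime supports parallel execution by managing threads and resources. By automating compilation and providing runtime support, OpenMP spares programmers the details of parallel programming and makes code easier to port across systems and architectures.</p>
<p>In MPI, the compute nodes of a cluster do not share memory; all data and information is exchanged through message passing. MPI suits very large HPC clusters (more than 100,000 nodes). Because memory is not shared, most of the work of partitioning input and output falls to the programmer. CUDA, by contrast, provides shared memory.</p>
<p>In 2009, several industry giants, including Apple, Intel, AMD, and NVIDIA, jointly developed a standardized programming model, OpenCL.</p>
<h2 id="16-overarching-goals-首要目标">1.6 Overarching goals</h2>
<p>The primary goal is to achieve high performance in massively parallel programming. The book builds an intuitive understanding of the hardware architecture and some computational thinking, that is, thinking about problems in ways that suit execution on massively parallel processors.</p>
<p>The second goal is correct functionality and reliability in parallel programming. CUDA provides a set of tools for debugging both the functionality and the performance bottlenecks of code.</p>
<p>The third goal is scalability to future, more powerful hardware. This scalability is achieved by regularizing and localizing memory accesses, which reduces the consumption of critical resources and the conflicts that arise when updating data structures.</p>
<h2 id="17-organization-of-the-book-本书的架构">1.7 Organization of the book</h2>
<p>Omitted.</p>
<h1 id="chapter-2-heterogeneous-data-parallel-computing-异构数据并行计算">Chapter 2: Heterogeneous data parallel computing</h1>
<h2 id="21-data-parallelism-数据并行化">2.1 Data parallelism</h2>
<p>Data parallelism rests on data items being independent of one another. By reorganizing a computation around such independent data, it can be parallelized for a substantial speedup. Take converting pixels to grayscale as an example, where the gray value is computed with the following formula:</p>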


<div>$$

L = 0.21 \times R&#43;0.72\times G&#43;0.07 \times B

$$</div>

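<p>In this formula, the gray value at a position depends only on the RGB values at the same position, so the conversions at different positions are clearly independent of one another and can be parallelized.</p>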
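<p>As a point of comparison, here is a sequential C sketch of the grayscale conversion. It assumes interleaved 8-bit RGB input and uses the weights 0.21, 0.72, and 0.07, which sum to 1; the function name <code>rgb_to_gray</code> and the data layout are illustrative choices, not taken from the book.</p>

```c
#include <stddef.h>

/* Sequential reference: convert interleaved RGB bytes to grayscale.
 * rgb holds 3 * num_pixels bytes laid out as r, g, b, r, g, b, ...
 * gray receives num_pixels bytes. */
void rgb_to_gray(const unsigned char *rgb, unsigned char *gray, size_t num_pixels) {
    for (size_t i = 0; i < num_pixels; ++i) {
        unsigned char r = rgb[3 * i];
        unsigned char g = rgb[3 * i + 1];
        unsigned char b = rgb[3 * i + 2];
        /* Each iteration reads and writes only pixel i, so iterations are independent. */
        gray[i] = (unsigned char)(0.21f * r + 0.72f * g + 0.07f * b);
    }
}
```

<p>No iteration of the loop depends on any other, which is exactly the property that lets each pixel be handed to its own GPU thread.</p>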
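<h2 id="22-cuda-c-program-structure-cuda-c-程序结构">2.2 CUDA C program structure</h2>
<p>CUDA C extends ANSI C with new syntax and library functions so that programmers can target heterogeneous computing systems that contain both CPUs and GPUs.</p>
<p>The structure of a CUDA C program reflects the coexistence of a host (CPU) and a device (GPU) in one computer. A CUDA C source file can mix host and device code; a plain C file can be regarded as a CUDA C file that contains only host code.</p>
<p>The execution of a CUDA program is shown in the figure below. It starts in host code, which then calls device code. A kernel is executed by a large number of threads; the set of all threads launched by one kernel call is called a grid. When all threads have finished, execution returns to host code until the program ends or another kernel is called.</p>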
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409131009864.png?x-oss-process=image/quality,q_90/format,webp"></p>
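<p>Note that the figure is a simplified model; in many heterogeneous applications, CPU and GPU execution overlap.</p>
<p>In the grayscale example, one thread may handle one pixel, so the larger the image, the more threads are used. Thanks to efficient hardware support, developers can assume that creating and scheduling threads takes only a few clock cycles, whereas for CPU threads it takes thousands of clock cycles.</p>
<h2 id="23-a-vector-addition-kernel-向量加法核函数">2.3 A vector addition kernel</h2>
<p>Vector addition is to parallel programming what Hello World is to sequential programming. In sequential code, vector addition is implemented with a loop.</p>
<p>Vector addition on the device has three parts: copying data from host to device, computing, and copying the result from device to host. In principle, if the transfers were handled by device code, the computation would be fully transparent from the device's point of view; in practice, the host code is responsible for them.</p>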
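<p>The sequential loop mentioned above can be sketched in plain C as follows. The name <code>vec_add_seq</code> is an illustrative choice; it is the host-only baseline that the device version replaces.</p>

```c
/* Sequential vector addition: c[i] = a[i] + b[i] for i in [0, n). */
void vec_add_seq(const float *a, const float *b, float *c, int n) {
    for (int i = 0; i < n; ++i) {
        c[i] = a[i] + b[i];  /* one loop iteration per output element */
    }
}
```

<p>In the parallel version each loop iteration becomes the work of one thread, and the loop counter <code>i</code> is replaced by an index computed from the thread's position in the grid.</p>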
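<h2 id="24-device-global-memory-and-data-transfer-设备全局内存和数据搬运">2.4 Device global memory and data transfer</h2>
<p>A device usually has its own RAM, called global memory. As mentioned above, data must be moved from host memory to global memory before the device computes, and back afterwards. These transfers are done through APIs provided by the CUDA runtime running on the host.</p>
<p>Two APIs allocate and free memory. <code>cudaMalloc</code> allocates memory; its arguments are the address of a pointer and the size in bytes, and the start address of the allocated memory is written into that pointer. <code>cudaFree</code> frees memory. Host code must not dereference pointers to device memory; doing so causes exceptions or other runtime errors.</p>
<p>Once memory has been allocated, data can be copied from host memory to global memory with <code>cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind )</code>, which takes four arguments: destination address, source address, number of bytes, and kind. The kind argument specifies the direction of the copy; there are four directions, from host or device to host or device.</p>
<h2 id="25-kernel-functions-and-threading-核函数和线程">2.5 Kernel functions and threading</h2>
<p>A kernel function is the code that GPU threads execute in parallel, a typical SPMD (single program, multiple data) pattern. When the host calls a kernel, all threads are organized into a two-level hierarchy: a kernel is executed by a grid, a grid contains multiple blocks, and a block contains multiple threads. Every block has the same number of threads, at most 1024.</p>
<p>Every thread has access to a built-in variable <code>blockDim</code>, maintained by the runtime, with three fields <code>x,y,z</code> that record the number of threads in a block. The three fields mean that the threads of a block can be arranged in up to three dimensions to better match the data being processed. For performance reasons, the number of threads in each dimension should be a multiple of 32.</p>
<p>Two more built-in variables, <code>threadIdx</code> and <code>blockIdx</code>, give the index of a thread within its block and the index of a block within the grid. The formula <code>int i = blockDim.x * blockIdx.x + threadIdx.x</code> computes the global index of each thread. If each thread computes one element of the vector sum, n threads can add vectors of length up to n. The corresponding kernel is:</p>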
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">vecAddKernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span> <span class="kt">int</span> <span class="n">n</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span> <span class="p">(</span><span class="n">i</span><span class="o">&lt;</span><span class="n">n</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="n">C</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">B</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
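</div><p>Note the qualifier <code>__global__</code>, which declares the function as a kernel: it runs on the device and can be called from the host as well as from the device. CUDA C introduces two more keywords, <code>__host__</code> and <code>__device__</code>. The former is the default and means the function runs on the host and can only be called from the host; the latter means the function runs on the device and can only be called from a device function or a kernel. A device function does not itself create any new threads.</p>
<p>In addition, a function can be marked with both <code>__host__</code> and <code>__device__</code>, in which case the compiler generates separate versions for the host and the device.</p>
<h2 id="26-calling-kernel-functions-调用核函数">2.6 Calling kernel functions</h2>
<p>The complete calling sequence is as follows:</p>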
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C" data-lang="C"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">vecAdd</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span> <span class="kt">int</span> <span class="n">n</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="o">*</span><span class="n">A_d</span><span class="p">,</span> <span class="o">*</span><span class="n">B_d</span><span class="p">,</span> <span class="o">*</span><span class="n">C_d</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">size</span> <span class="o">=</span> <span class="n">n</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="nf">cudaMalloc</span><span class="p">((</span><span class="kt">void</span> <span class="o">**</span><span class="p">)</span> <span class="o">&amp;</span><span class="n">A_d</span><span class="p">,</span> <span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="nf">cudaMalloc</span><span class="p">((</span><span class="kt">void</span> <span class="o">**</span><span class="p">)</span> <span class="o">&amp;</span><span class="n">B_d</span><span class="p">,</span> <span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="nf">cudaMalloc</span><span class="p">((</span><span class="kt">void</span> <span class="o">**</span><span class="p">)</span> <span class="o">&amp;</span><span class="n">C_d</span><span class="p">,</span> <span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="nf">cudaMemcpy</span><span class="p">(</span><span class="n">A_d</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">cudaMemcpyHostToDevice</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="nf">cudaMemcpy</span><span class="p">(</span><span class="n">B_d</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">cudaMemcpyHostToDevice</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">vecAddKernel</span><span class="o">&lt;&lt;&lt;</span><span class="nf">ceil</span><span class="p">(</span><span class="n">n</span><span class="o">/</span><span class="mf">256.0</span><span class="p">),</span> <span class="mi">256</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">A_d</span><span class="p">,</span> <span class="n">B_d</span><span class="p">,</span> <span class="n">C_d</span><span class="p">,</span> <span class="n">n</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="nf">cudaMemcpy</span><span class="p">(</span><span class="n">C</span><span class="p">,</span> <span class="n">C_d</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">cudaMemcpyDeviceToHost</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="nf">cudaFree</span><span class="p">(</span><span class="n">A_d</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="nf">cudaFree</span><span class="p">(</span><span class="n">B_d</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="nf">cudaFree</span><span class="p">(</span><span class="n">C_d</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
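</div><p>The values between <code>&lt;&lt;&lt;</code> and <code>&gt;&gt;&gt;</code> specify the number of blocks and the number of threads per block for the kernel launch; rounding up ensures that no element of the vector is left out.</p>
<p>During execution, block scheduling is transparent to the programmer; it depends on the size and speed of the GPU. Blocks are independent of one another.</p>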
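<p>To see how rounding up and the <code>i &lt; n</code> guard work together, the launch can be emulated on the host in plain C. This is a sketch only: the two loops stand in for <code>blockIdx.x</code> and <code>threadIdx.x</code>, the function names are hypothetical, and a real GPU runs the blocks in no particular order.</p>

```c
/* Number of blocks needed so that blocks * threads_per_block >= n (integer ceiling). */
int blocks_for(int n, int threads_per_block) {
    return (n + threads_per_block - 1) / threads_per_block;
}

/* Host-side emulation of a 1D launch of vector addition.
 * Returns how many launched threads were skipped by the bounds guard. */
int emulate_vec_add(const float *a, const float *b, float *c, int n, int block_dim) {
    int idle = 0;
    int grid_dim = blocks_for(n, block_dim);
    for (int block_idx = 0; block_idx < grid_dim; ++block_idx) {
        for (int thread_idx = 0; thread_idx < block_dim; ++thread_idx) {
            int i = block_dim * block_idx + thread_idx;  /* global thread index */
            if (i < n) {
                c[i] = a[i] + b[i];
            } else {
                ++idle;  /* thread exists but has no element to process */
            }
        }
    }
    return idle;
}
```

<p>For n = 1000 and 256 threads per block, <code>blocks_for</code> returns 4, so 1024 threads are launched and the guard keeps the last 24 from touching memory past the end of the vectors.</p>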
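<h2 id="27-compilation-编译">2.7 Compilation</h2>
<p>The compilation and execution flow of CUDA C is shown below. NVCC first separates host code from device code. Host code is compiled and linked by the host C compiler; device code is compiled into PTX, a virtual binary format, which the device's JIT compiler then compiles again at run time before execution.<br>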
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409171854831.png?x-oss-process=image/quality,q_90/format,webp"></p>
<h1 id="chapter-3-multidimensional-girds-and-data-多维网格和数据">Chapter 3: Multidimensional grids and data</h1>
<h2 id="31-multidimensional-grid-organization-多维网格组织">3.1 Multidimensional grid organization</h2>
<p>As mentioned earlier, both grids and blocks can be organized in multiple dimensions. Their sizes are bounded: the maximum of <code>gridDim.x</code> is <code>2^31-1</code>, and the maximum of <code>gridDim.y</code> and <code>gridDim.z</code> is <code>2^16-1</code>. For a block, the constraint is that the total number of threads must not exceed 1024.</p>
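<p>These limits can be written down as two small checks in C. The sketch below encodes only the bounds stated above; the helper names are hypothetical, and real devices impose further per-dimension limits on blocks that are not modeled here (they can be queried with <code>cudaGetDeviceProperties</code>).</p>

```c
#include <stdbool.h>

/* Grid limits from the text: x <= 2^31 - 1, y and z <= 2^16 - 1. */
bool grid_dims_ok(long long x, long long y, long long z) {
    return x >= 1 && x <= 2147483647LL
        && y >= 1 && y <= 65535
        && z >= 1 && z <= 65535;
}

/* Block limit from the text: total threads per block <= 1024. */
bool block_dims_ok(int x, int y, int z) {
    if (x < 1 || y < 1 || z < 1) {
        return false;
    }
    return (long long)x * y * z <= 1024;
}
```

<p>For instance, a 32×32×1 block has exactly 1024 threads and passes, while 32×32×2 has 2048 threads and is rejected.</p>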
<p>The <code>dim3</code> type defines variables that describe the shape of a grid or a block, with fields given in x, y, z order:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C" data-lang="C"><span class="line"><span class="cl"><span class="n">dim3</span> <span class="nf">dimGrid</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="n">dim3</span> <span class="nf">dimBlock</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">1</span><span class="p">);</span>
</span></span></code></pre></td></tr></table>
</div>
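</div><p>When we index a block or thread with subscripts, however, the order is z, y, x: for example, block(1, 0) denotes the block with x equal to 0 and y equal to 1. This ordering makes it easier to describe the mapping between threads and data.</p>
<h2 id="32-mapping-threads-to-multidimensional-data-将线程映射到多维数据">3.2 Mapping threads to multidimensional data</h2>
<p>How threads are organized depends on the inherent structure of the data. For image data, for example, a two-dimensional thread organization makes pixel processing convenient.</p>
<p>For example, to process the pixels of a 62×76 image, we can organize 16×16 threads into a block and 4×5 blocks into a grid. The grid then covers the original image (the shaded area) as shown below.<br>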
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409181411198.png?x-oss-process=image/quality,q_90/format,webp"><br>
In CUDA C, data is stored in memory in row-major order.</p>
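<p>The 62×76 example and the row-major layout can be reproduced with a little integer arithmetic. The following C sketch assumes 62 rows and 76 columns; the struct and function names are illustrative, not part of CUDA.</p>

```c
/* Number of blocks along y and x needed to cover an image. */
struct Grid2D {
    int y;
    int x;
};

/* Round up in each dimension so the grid covers every pixel. */
struct Grid2D grid_for_image(int height, int width, int block_y, int block_x) {
    struct Grid2D g;
    g.y = (height + block_y - 1) / block_y;
    g.x = (width + block_x - 1) / block_x;
    return g;
}

/* Row-major linearization: all elements of row 0 first, then row 1, and so on. */
int row_major_index(int row, int col, int width) {
    return row * width + col;
}
```

<p>With 16×16 blocks, <code>grid_for_image(62, 76, 16, 16)</code> gives 4 blocks along y and 5 along x, so 64×80 = 5120 threads cover 62×76 = 4712 pixels, and the pixel at row 1, column 0 sits at linear offset 76.</p>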
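<h2 id="image-blur-a-more-complex-kernel-一个更复杂的核函数图片模糊">3.3 Image blur: a more complex kernel</h2>
<p>This section implements a more complex kernel: image blur. Each output pixel is a weighted average of the pixels in a surrounding region that includes the pixel itself. In this section all weights are 1; in practice the weights usually depend on the distance from the center.</p>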
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C" data-lang="C"><span class="line"><span class="cl"><span class="n">__global__</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">blurKernel</span><span class="p">(</span><span class="kt">unsigned</span> <span class="kt">char</span> <span class="o">*</span><span class="n">in</span><span class="p">,</span> <span class="kt">unsigned</span> <span class="kt">char</span> <span class="o">*</span><span class="n">out</span><span class="p">,</span> <span class="kt">int</span> <span class="n">w</span><span class="p">,</span> <span class="kt">int</span> <span class="n">h</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">col</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">row</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span><span class="o">*</span><span class="n">blockDim</span><span class="p">.</span><span class="n">y</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">col</span> <span class="o">&lt;</span> <span class="n">w</span> <span class="o">&amp;&amp;</span> <span class="n">row</span> <span class="o">&lt;</span> <span class="n">h</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kt">int</span> <span class="n">pixVal</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="kt">int</span> <span class="n">pixels</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="c1">// Get average of the surrounding BLUR_SIZE x BLUR_SIZE box
</span></span></span><span class="line"><span class="cl">        <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">blurRow</span><span class="o">=-</span><span class="n">BLUR_SIZE</span><span class="p">;</span> <span class="n">blurRow</span><span class="o">&lt;</span><span class="n">BLUR_SIZE</span><span class="o">+</span><span class="mi">1</span><span class="p">;</span> <span class="o">++</span><span class="n">blurRow</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">blurCol</span><span class="o">=-</span><span class="n">BLUR_SIZE</span><span class="p">;</span> <span class="n">blurCol</span><span class="o">&lt;</span><span class="n">BLUR_SIZE</span><span class="o">+</span><span class="mi">1</span><span class="p">;</span> <span class="o">++</span><span class="n">blurCol</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">                <span class="kt">int</span> <span class="n">curRow</span> <span class="o">=</span> <span class="n">row</span> <span class="o">+</span> <span class="n">blurRow</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">                <span class="kt">int</span> <span class="n">curCol</span> <span class="o">=</span> <span class="n">col</span> <span class="o">+</span> <span class="n">blurCol</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">                <span class="c1">// Verify we have a valid image pixel
</span></span></span><span class="line"><span class="cl">                <span class="k">if</span><span class="p">(</span><span class="n">curRow</span><span class="o">&gt;=</span><span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">curRow</span><span class="o">&lt;</span><span class="n">h</span> <span class="o">&amp;&amp;</span> <span class="n">curCol</span><span class="o">&gt;=</span><span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">curCol</span><span class="o">&lt;</span><span class="n">w</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                    <span class="n">pixVal</span> <span class="o">+=</span> <span class="n">in</span><span class="p">[</span><span class="n">curRow</span><span class="o">*</span><span class="n">w</span> <span class="o">+</span> <span class="n">curCol</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">                    <span class="o">++</span><span class="n">pixels</span><span class="p">;</span> <span class="c1">// Keep track of number of pixels in the avg
</span></span></span><span class="line"><span class="cl">                <span class="p">}</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="c1">// Write our new pixel value out
</span></span></span><span class="line"><span class="cl">        <span class="n">out</span><span class="p">[</span><span class="n">row</span><span class="o">*</span><span class="n">w</span> <span class="o">+</span> <span class="n">col</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="kt">unsigned</span> <span class="kt">char</span><span class="p">)(</span><span class="n">pixVal</span><span class="o">/</span><span class="n">pixels</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
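</div><h2 id="34-matrix-multiplication-矩阵乘法">3.4 Matrix multiplication</h2>
<p>Matrix multiplication is one of the fundamental algorithms in linear algebra, so its definition is not repeated here. Following the earlier idea of having one thread compute one output position, the most naive matrix multiplication kernel can be written as follows:</p>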
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C" data-lang="C"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">MatrixMulKernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">M</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">N</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                <span class="kt">float</span><span class="o">*</span> <span class="n">P</span><span class="p">,</span> <span class="kt">int</span> <span class="n">Width</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">row</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">y</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">col</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">((</span><span class="n">row</span> <span class="o">&lt;</span> <span class="n">Width</span><span class="p">)</span> <span class="o">&amp;&amp;</span> <span class="p">(</span><span class="n">col</span> <span class="o">&lt;</span> <span class="n">Width</span><span class="p">))</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kt">float</span> <span class="n">Pvalue</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">Width</span><span class="p">;</span> <span class="o">++</span><span class="n">k</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="n">Pvalue</span> <span class="o">+=</span> <span class="n">M</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">Width</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span><span class="p">[</span><span class="n">k</span> <span class="o">*</span> <span class="n">Width</span> <span class="o">+</span> <span class="n">col</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="n">P</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">Width</span> <span class="o">+</span> <span class="n">col</span><span class="p">]</span> <span class="o">=</span> <span class="n">Pvalue</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
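</div><h1 id="chapter-4-compute-architecture-and-scheduling-计算架构和调度">Chapter 4: Compute architecture and scheduling</h1>
<p>Starting from the architecture of a modern GPU, this chapter covers how threads are scheduled during execution and which factors limit occupancy.</p>
<h2 id="41-architecture-of-a-modern-gpu-现代-gpu-架构">4.1 Architecture of a modern GPU</h2>
<p>The figure below shows the GPU architecture from the programmer's point of view. It consists of an array of streaming multiprocessors (SMs), and each SM contains multiple streaming processors, also called CUDA cores. The cores inside one SM share control logic and on-chip memory.<br>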
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409191109611.png?x-oss-process=image/quality,q_90/format,webp"></p>
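<h2 id="42-block-scheduling-block-调度">4.2 Block scheduling</h2>
<p>When a kernel is launched, the CUDA runtime creates a grid of threads and assigns them to SMs block by block. One SM can hold several blocks, but because hardware resources are limited, the number of blocks per SM is bounded. Since both the number of SMs and the number of blocks per SM are finite, so is the number of threads that can execute simultaneously, which means that some threads inevitably do not run in parallel with the others.</p>
<h2 id="43-synchronization-and-transparent-scalability-同步和透明可扩展性">4.3 Synchronization and transparent scalability</h2>
<p>CUDA lets the threads of a block coordinate (synchronize) with one another through <code>__syncthreads()</code>. A thread that calls this function blocks at the call site until every thread in its block has reached it. This technique is called barrier synchronization.</p>
<p>When <code>__syncthreads()</code> is used, every thread in the block must execute it. If the call sits inside a conditional branch, either all threads of the block take that branch or none of them do. "Partial synchronization", where only some threads reach the barrier, is undefined behavior.</p>
<p>Threads of the same block should reach the barrier at roughly the same time, so the CUDA runtime assigns execution resources to all threads of a block as a unit, which ensures that they begin executing together.</p>
<p>The barrier mechanism reveals a design trade-off in CUDA: by forbidding synchronization between threads of different blocks, CUDA is free to schedule blocks in any order. This freedom is in turn the basis of transparent scalability: a device with few SMs can execute a small number of blocks at a time, while a device with many SMs may be able to execute all blocks at once. The ability to run the same code unchanged on devices with different amounts of execution resources is called transparent scalability.</p>
<h2 id="44-warps-and-simd-hardware-线程束和-simd-硬件">4.4 Warps and SIMD hardware</h2>
<p>How threads are scheduled inside an SM depends on the hardware implementation. On most current devices, the threads of a block are grouped into warps of 32 threads, and the SM schedules execution in units of warps. If the last group has fewer than 32 threads, it is padded with inactive threads to form a full warp.</p>
<p>For a multidimensional block, the threads are first linearized into one dimension (<code>threadIdx.x</code> varies fastest, then <code>threadIdx.y</code>, then <code>threadIdx.z</code>) and then partitioned into warps.</p>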
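<p>The linearization and warp partitioning described above can be sketched in plain C on the host. This is only an illustration: the helper names <code>linear_thread_id</code>, <code>warp_of_thread</code> and <code>warps_per_block</code> are made up for this note, and the warp size of 32 is hard-coded as an assumption rather than queried from the device.</p>

```c
#include <assert.h>

/* Assumed warp size; on real hardware this should be queried, not hard-coded. */
enum { WARP_SIZE = 32 };

/* Linear index of thread (x, y, z) inside a block of size (dim_x, dim_y, dim_z).
   x varies fastest, then y, then z. */
int linear_thread_id(int x, int y, int z, int dim_x, int dim_y, int dim_z) {
    assert(x >= 0 && x < dim_x);
    assert(y >= 0 && y < dim_y);
    assert(z >= 0 && z < dim_z);
    return (z * dim_y + y) * dim_x + x;
}

/* Index of the warp that a linearized thread falls into. */
int warp_of_thread(int linear_id) {
    return linear_id / WARP_SIZE;
}

/* Number of warps a block occupies; a partially filled last warp still counts,
   because it is padded with inactive threads. */
int warps_per_block(int dim_x, int dim_y, int dim_z) {
    int threads = dim_x * dim_y * dim_z;
    return (threads + WARP_SIZE - 1) / WARP_SIZE;
}
```

<p>For a 16×16 block, the rows with <code>threadIdx.y</code> equal to 0 and 1 form warp 0, so thread (0, 2, 0) is the first thread of warp 1. A 10×10 block has 100 threads and therefore occupies 4 warps, the last of which is only partly active.</p>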
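<p>Warps in an SM follow the SIMD model: one instruction is fetched at a time and executed by all threads of the warp. The threads of a warp share one control unit and progress in lockstep. This model is known as SIMT (single instruction, multiple threads).</p>
<h2 id="45-control-divergence-控制流分歧">4.5 Control divergence</h2>
<p>If all threads in a warp follow the same control flow, the warp simply executes the instructions in order. If they take different branches, the warp passes through every branch that is needed, one after another; in each pass the threads that belong to that branch are active and the others stay idle.</p>
<p>When threads in one warp follow different control-flow paths, the situation is called control divergence. By activating different threads in different passes, the hardware preserves the full control-flow semantics of each thread, at the cost of executing all the needed paths one after another.</p>
<p>On Pascal and earlier architectures, the divergent paths are always executed one after another. From Volta onward, the passes may be executed concurrently, with the execution of one path interleaved with another. This feature is called independent thread scheduling.</p>
<p>Besides plain <code>if</code> statements, any control-flow construct can cause divergence, for example the following <code>for</code> loop:</p>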
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">N</span> <span class="o">=</span> <span class="n">a</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">N</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="c1">//...
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
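</div><p>For a loop, the exit condition tells you whether divergence is possible: if the condition depends, directly or indirectly, on <code>threadIdx</code>, the loop may diverge.</p>
<p>A major use of control statements is handling boundary conditions. Data is typically partitioned among threads, but the data size is not necessarily a multiple of the thread count, so a boundary check is needed to prevent out-of-bounds accesses.</p>
<p>As the data size grows, the relative cost of this divergence shrinks. With 100 elements (4 warps), 1 of 4 warps diverges at the boundary check; with 1000 elements (32 warps), only 1 of 32 warps diverges.</p>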
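<p>The 1/4 and 1/32 figures can be reproduced with a few lines of host-side C. The sketch assumes one thread per element, a warp size of 32, and a launch that rounds the thread count up to whole warps only; the function names are hypothetical.</p>

```c
#include <assert.h>

/* Assumed warp size. */
enum { WARP_SIZE = 32 };

/* Warps needed to cover n elements with one thread per element. */
int warps_needed(int n) {
    assert(n > 0);
    return (n + WARP_SIZE - 1) / WARP_SIZE;
}

/* Number of warps that diverge at a boundary check of the form `i < n`:
   only the last warp can straddle the boundary, and only when n is not a
   multiple of the warp size. */
int divergent_warps(int n) {
    assert(n > 0);
    return (n % WARP_SIZE != 0) ? 1 : 0;
}
```

<p>The divergent share is <code>divergent_warps(n) / warps_needed(n)</code>: 1 of 4 for 100 elements and 1 of 32 for 1000 elements. When the element count is an exact multiple of 32, the boundary check causes no divergence at all.</p>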
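<h2 id="46-wrap-scheduling-and-latency-tolerance-线程束调度和延迟容忍">4.6 Warp scheduling and latency tolerance</h2>
<p>An SM is usually assigned more threads than it has cores, so only a subset of them can execute at any instant. In early GPU designs, an SM could execute one instruction of a single warp at a time; in recent designs, an SM can execute instructions from several warps at once. Either way, an SM executes only a subset of its resident warps at any given moment.</p>
<p>Why assign an SM more warps than it can execute at once? Because this is how the GPU hides the cost of long-latency operations such as global memory accesses.</p>
<p>When a warp has to wait a long time for a memory access to complete, it is stalled and the scheduler picks another warp that is ready to execute. This technique is called latency tolerance or latency hiding.</p>
<p>It is worth noting that switching between warps costs essentially nothing in CUDA, which is known as zero-overhead thread scheduling. Thanks to latency tolerance, a GPU does not need to place large caches next to its compute units the way a CPU does, and can spend that area on floating-point units instead.</p>
<p>Effective latency tolerance requires an SM to hold many warps so that a ready warp is available at any time. On the A100 GPU, an SM with 64 cores can have up to 2048 threads assigned to it.</p>
<h2 id="47-resource-partitioning-and-occupancy-资源分配与占用">4.7 Resource partitioning and occupancy</h2>
<p>The ratio of the number of threads actually assigned to an SM to the maximum number it supports is called occupancy. The higher the occupancy, the higher the SM's potential execution efficiency. One factor that keeps occupancy from reaching its maximum is how resources are partitioned.</p>
<p>Execution resources include registers, shared memory, thread block slots, and thread slots. They are dynamically partitioned among threads to support their execution. For example, the A100 supports at most 32 blocks per SM, 64 warps per SM, and 1024 threads per block. If a kernel is launched with 1024 threads per block, each block contains 32 warps, so the limit of 64 warps per SM allows only 2 blocks per SM.</p>
<p>Dynamic partitioning lets an SM run either a few blocks with many threads each or many blocks with few threads each, but it can also lead to low occupancy. In the example above, if each block has only one warp (32 threads), the block-slot limit caps the SM at 32 warps, an occupancy of 50%. Likewise, if the maximum number of threads per SM is not divisible by the block size, say a block has 768 threads (24 warps), the SM can hold only 2 blocks, that is 48 warps, an occupancy of 75%.</p>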
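<p>The three occupancy figures above (100%, 50%, 75%) follow from simple integer arithmetic. Below is a minimal sketch in C that considers only block slots and warp slots and ignores registers and shared memory; the struct <code>SmLimits</code> and both function names are invented for this note, and the A100 limits are the ones quoted in the text.</p>

```c
#include <assert.h>

/* Per-SM limits relevant to this simplified model. */
typedef struct {
    int max_blocks_per_sm;
    int max_warps_per_sm;
    int warp_size;
} SmLimits;

static int min_int(int a, int b) { return a < b ? a : b; }

/* Warps resident on one SM for a given block size, limited by block slots
   and warp slots only. Blocks are assigned whole, never in part. */
int resident_warps(SmLimits lim, int threads_per_block) {
    assert(threads_per_block > 0);
    int warps_per_block = (threads_per_block + lim.warp_size - 1) / lim.warp_size;
    int blocks_by_warp_slots = lim.max_warps_per_sm / warps_per_block;
    int blocks = min_int(lim.max_blocks_per_sm, blocks_by_warp_slots);
    return blocks * warps_per_block;
}

/* Occupancy as an integer percentage of the warp-slot limit. */
int occupancy_percent(SmLimits lim, int threads_per_block) {
    return 100 * resident_warps(lim, threads_per_block) / lim.max_warps_per_sm;
}
```

<p>With <code>SmLimits a100 = {32, 64, 32}</code>, block sizes of 1024, 32 and 768 threads give 64, 32 and 48 resident warps. A real occupancy calculation would additionally take the minimum over the register and shared-memory limits.</p>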
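<p>Registers are a limited resource as well; the details are discussed in the next chapter.</p>
<h2 id="48-querying-device-properties-查询设备属性">4.8 Querying device properties</h2>
<p>How can device parameters be queried at run time? The amount of resources in each SM is specified by the device's compute capability. CUDA C provides the <code>cudaGetDeviceProperties</code> API, which fills in a caller-supplied structure of type <code>cudaDeviceProp</code> with detailed information about the GPU. The most relevant fields are:</p>
<ul>
<li><code>maxThreadsPerBlock</code>: maximum number of threads per block</li>
<li><code>multiProcessorCount</code>: number of SMs on the device</li>
<li><code>clockRate</code>: clock frequency; together with the previous field it indicates the device's throughput</li>
<li><code>maxThreadsDim</code>: upper limit of a block's size in each dimension</li>
<li><code>maxGridSize</code>: upper limit of a grid's size, in blocks, in each dimension</li>
<li><code>regsPerBlock</code>: register limit per block; in practice this field gives the number of registers available in one SM, not the number each block can use at 100% occupancy</li>
<li><code>warpSize</code>: number of threads in a warp</li>
</ul>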
<h1 id="chapter-5-memory-architecture-and-data-locality-内存架构和数据索引">Chapter 5: Memory architecture and data locality</h1>
<p>This chapter gives a systematic overview of the CUDA memory architecture, explains how memory access efficiency affects compute efficiency, and improves the matrix multiplication kernel with tiling, raising its computational intensity by 16x.</p>
<h2 id="51-importance-of-memory-access-efficiency-访存效率的重要性">5.1 Importance of memory access efficiency</h2>
<p>Take the kernel from <a href="/notes/note-on-programming-massively-parallel-processors-a-hands-on-approach-4th-edition-part-1/#3.4-matrix-multiplication-%E7%9F%A9%E9%98%B5%E4%B9%98%E6%B3%95">3.4 Matrix multiplication</a> as an example: each thread computes one element of the result, that is, the inner product of two vectors.</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">MatrixMulKernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">M</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">N</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                <span class="kt">float</span><span class="o">*</span> <span class="n">P</span><span class="p">,</span> <span class="kt">int</span> <span class="n">Width</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">row</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">y</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">col</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">((</span><span class="n">row</span> <span class="o">&lt;</span> <span class="n">Width</span><span class="p">)</span> <span class="o">&amp;&amp;</span> <span class="p">(</span><span class="n">col</span> <span class="o">&lt;</span> <span class="n">Width</span><span class="p">))</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kt">float</span> <span class="n">Pvalue</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">Width</span><span class="p">;</span> <span class="o">++</span><span class="n">k</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="n">Pvalue</span> <span class="o">+=</span> <span class="n">M</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">Width</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span><span class="p">[</span><span class="n">k</span> <span class="o">*</span> <span class="n">Width</span> <span class="o">+</span> <span class="n">col</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="n">P</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">Width</span> <span class="o">+</span> <span class="n">col</span><span class="p">]</span> <span class="o">=</span> <span class="n">Pvalue</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
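</div><p>As shown above, each iteration of the inner-product loop performs two global memory accesses and two floating-point operations (one multiplication and one addition). This motivates a metric called computational intensity: the number of floating-point operations (FLOP) divided by the number of bytes accessed in global memory. For the code above, it is 2 FLOP / 8 B = 0.25 FLOP/B.</p>
<p>Computational intensity indicates whether a CUDA program can make full use of the cores' compute capability. On the A100, for example, global memory bandwidth is 1555 GB/s. Multiplying it by the computational intensity gives the highest throughput this kernel can reach, about 389 GFLOPS, far below the A100's peak of 19500 GFLOPS, let alone the 156000 GFLOPS offered by its dedicated tensor cores.</p>
<p>Programs held back by memory bandwidth in this way are called memory-bound programs. From the A100's bandwidth and peak throughput, the computational intensity needed to fully use its compute capability is at least 19500 / 1555 ≈ 12.5 FLOP/B.</p>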
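<p>The arithmetic above is a small roofline calculation, and it can be written down as a sketch in C. The function names are hypothetical, and the peak and bandwidth values passed in are simply the A100 figures quoted in these notes.</p>

```c
#include <assert.h>

/* Attainable throughput in GFLOPS: the smaller of the compute peak and the
   memory-bound ceiling (bandwidth in GB/s times intensity in FLOP/B). */
double attainable_gflops(double peak_gflops, double bandwidth_gb_s,
                         double flop_per_byte) {
    assert(peak_gflops > 0.0 && bandwidth_gb_s > 0.0 && flop_per_byte > 0.0);
    double memory_bound = bandwidth_gb_s * flop_per_byte;
    return memory_bound < peak_gflops ? memory_bound : peak_gflops;
}

/* Computational intensity (FLOP/B) at which a kernel stops being memory bound. */
double ridge_intensity(double peak_gflops, double bandwidth_gb_s) {
    assert(peak_gflops > 0.0 && bandwidth_gb_s > 0.0);
    return peak_gflops / bandwidth_gb_s;
}
```

<p>Calling <code>attainable_gflops(19500.0, 1555.0, 0.25)</code> gives 388.75, the roughly 389 GFLOPS ceiling of the naive kernel, and <code>ridge_intensity(19500.0, 1555.0)</code> gives about 12.54 FLOP/B.</p>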
<h2 id="52-cuda-memory-types-cuda-内存类型">5.2 CUDA memory types</h2>
<p>A CUDA device provides several memory types that help raise computational intensity, as shown below.<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409231026587.png?x-oss-process=image/quality,q_90/format,webp"></p>
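<p>At the bottom are global memory and constant memory. The host can read and write both. The device can read and write global memory; constant memory is read-only for the device, but supports low-latency, high-bandwidth reads.</p>
<p>Local memory is actually placed in global memory and has similar latency and bandwidth. Each thread gets a private region of global memory as its local memory, which holds data that cannot be kept in registers, such as statically allocated arrays, spilled registers, and the thread's call stack.</p>
<p>Registers and shared memory are on-chip memories whose contents can be accessed in parallel at very high speed. Registers are private to each thread and hold the variables that thread uses frequently. Shared memory is shared by all threads of a block.</p>
<p>By choosing among these memory types, the programmer controls the access speed and visibility of each variable.</p>
<p>Beyond latency and bandwidth, register accesses are faster for another reason: instruction count. Adding two floating-point numbers that are already in registers takes a single floating-point add instruction, whereas operands that are not in registers require extra instructions to load them into registers and to store the result back to memory. Executing those extra instructions takes additional time.</p>
<p>Although registers and shared memory are both on-chip, shared memory is part of the memory hierarchy: its data must also be loaded into registers before being operated on, so it has higher latency and lower throughput than registers. The term scratchpad memory refers to this kind of on-chip memory.</p>
<p>Variables declared in different ways differ in where they are stored, their scope, and their lifetime. The table below summarizes the correspondence:</p>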
<table>
  <thead>
      <tr>
          <th>Variable declaration</th>
          <th>Memory</th>
          <th>Scope</th>
          <th>Lifetime</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Automatic variables other than arrays</td>
          <td>Register</td>
          <td>thread</td>
          <td>Grid</td>
      </tr>
      <tr>
          <td>Automatic array variables</td>
          <td>Local</td>
          <td>thread</td>
          <td>Grid</td>
      </tr>
      <tr>
          <td><code>__device__ __shared__ int SharedVar;</code></td>
          <td>Shared</td>
          <td>block</td>
          <td>Grid</td>
      </tr>
      <tr>
          <td><code>__device__ int GlobalVar;</code></td>
          <td>Global</td>
          <td>grid</td>
          <td>Application</td>
      </tr>
      <tr>
          <td><code>__device__ __constant__ int ConstVar;</code></td>
          <td>Constant</td>
          <td>grid</td>
          <td>Application</td>
      </tr>
  </tbody>
</table>
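<h2 id="53-tiling-for-reduced-memory-traffic-通过分块减少访存">5.3 Tiling for reduced memory traffic</h2>
<p>Partitioning the data into tiles small enough to fit in shared memory reduces global memory accesses. Tiling requires that the computation on each tile can be carried out independently; not every data structure, and not every kernel, can be tiled.</p>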
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409251414411.png?x-oss-process=image/quality,q_90/format,webp"></p>
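<p>As the figure above shows, in the earlier matrix multiplication each thread computes one element independently. The first block consists of four threads, and these threads read the same global memory locations repeatedly; for example, P00 and P01 both read the first row of M. Loading these elements into shared memory once halves the number of global memory accesses.</p>
<p>For matrix multiplication, the reduction depends on the block size: if the threads of a block are organized as n×n, global memory accesses drop to 1/n of the original.</p>
<p>Note that shared memory is limited. If a block has too many threads or the matrix dimension is too large, shared memory may not hold all the data the block needs. In that case the data can be divided into smaller tiles that are loaded into shared memory one after another.</p>
<p>For example, with M and N partitioned into 2×2 tiles, a 4×4 matrix multiplication is completed in two phases. For block00, phase one loads <code>M[0:2, 0:2]</code> and <code>N[0:2, 0:2]</code> into shared memory and multiplies them; phase two loads <code>M[0:2, 2:4]</code> and <code>N[2:4, 0:2]</code>, multiplies them, and accumulates into the previous result. The work done by each thread is shown in the table below:<br>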
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409251455300.png?x-oss-process=image/quality,q_90/format,webp"></p>
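<p>The phased accumulation can be checked on the CPU before looking at the GPU kernel. The following C sketch assumes square row-major matrices whose width is a multiple of the tile size; <code>matmul_naive</code> and <code>matmul_phased</code> are names chosen for this note. There is no shared memory here, so the code only mirrors the order in which partial products are accumulated, not the memory traffic savings.</p>

```c
#include <assert.h>

/* Reference: P = M * N for square row-major matrices of size width x width. */
void matmul_naive(const float *M, const float *N, float *P, int width) {
    for (int row = 0; row < width; ++row) {
        for (int col = 0; col < width; ++col) {
            float acc = 0.0f;
            for (int k = 0; k < width; ++k)
                acc += M[row * width + k] * N[k * width + col];
            P[row * width + col] = acc;
        }
    }
}

/* Same product computed tile by tile. The (by, bx) loops play the role of
   thread blocks, and the ph loop walks over the phases: in each phase one
   tile of M and one tile of N contribute a partial sum to the output tile.
   Requires width % tile == 0. */
void matmul_phased(const float *M, const float *N, float *P, int width, int tile) {
    assert(tile > 0 && width % tile == 0);
    for (int i = 0; i < width * width; ++i)
        P[i] = 0.0f;
    for (int by = 0; by < width; by += tile)
        for (int bx = 0; bx < width; bx += tile)
            for (int ph = 0; ph < width; ph += tile)
                for (int row = by; row < by + tile; ++row)
                    for (int col = bx; col < bx + tile; ++col)
                        for (int k = ph; k < ph + tile; ++k)
                            P[row * width + col] +=
                                M[row * width + k] * N[k * width + col];
}
```

<p>For a 4×4 problem with 2×2 tiles, the <code>ph</code> loop runs twice per output tile, which corresponds to the two phases described above.</p>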
<h2 id="54-a-tiled-matrix-multiplication-kernel-分块矩乘核函数">5.4 A tiled matrix multiplication kernel</h2>
<p>The tiled matrix multiplication kernel is as follows:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="cp">#define TILE_WIDTH 16
</span></span></span><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">matrixMulKernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">M</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">N</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">P</span><span class="p">,</span> <span class="kt">int</span> <span class="n">Width</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">Mds</span><span class="p">[</span><span class="n">TILE_WIDTH</span><span class="p">][</span><span class="n">TILE_WIDTH</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">Nds</span><span class="p">[</span><span class="n">TILE_WIDTH</span><span class="p">][</span><span class="n">TILE_WIDTH</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">bx</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="kt">int</span> <span class="n">by</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">tx</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="kt">int</span> <span class="n">ty</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c1">// Identify the row and column of the P element to work on
</span></span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">Row</span> <span class="o">=</span> <span class="n">by</span> <span class="o">*</span> <span class="n">TILE_WIDTH</span> <span class="o">+</span> <span class="n">ty</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">Col</span> <span class="o">=</span> <span class="n">bx</span> <span class="o">*</span> <span class="n">TILE_WIDTH</span> <span class="o">+</span> <span class="n">tx</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c1">// Loop over the M and N tiles required to compute P element
</span></span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="n">Pvalue</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">ph</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">ph</span> <span class="o">&lt;</span> <span class="n">Width</span><span class="o">/</span><span class="n">TILE_WIDTH</span><span class="p">;</span> <span class="o">++</span><span class="n">ph</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="c1">// Collaborative loading of M and N tiles into shared memory
</span></span></span><span class="line"><span class="cl">        <span class="n">Mds</span><span class="p">[</span><span class="n">ty</span><span class="p">][</span><span class="n">tx</span><span class="p">]</span> <span class="o">=</span> <span class="n">M</span><span class="p">[</span><span class="n">Row</span><span class="o">*</span><span class="n">Width</span> <span class="o">+</span> <span class="n">ph</span><span class="o">*</span><span class="n">TILE_WIDTH</span> <span class="o">+</span> <span class="n">tx</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="n">Nds</span><span class="p">[</span><span class="n">ty</span><span class="p">][</span><span class="n">tx</span><span class="p">]</span> <span class="o">=</span> <span class="n">N</span><span class="p">[(</span><span class="n">ph</span><span class="o">*</span><span class="n">TILE_WIDTH</span> <span class="o">+</span> <span class="n">ty</span><span class="p">)</span><span class="o">*</span><span class="n">Width</span> <span class="o">+</span> <span class="n">Col</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">TILE_WIDTH</span><span class="p">;</span> <span class="o">++</span><span class="n">k</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="n">Pvalue</span> <span class="o">+=</span> <span class="n">Mds</span><span class="p">[</span><span class="n">ty</span><span class="p">][</span><span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">Nds</span><span class="p">[</span><span class="n">k</span><span class="p">][</span><span class="n">tx</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="n">P</span><span class="p">[</span><span class="n">Row</span><span class="o">*</span><span class="n">Width</span> <span class="o">+</span> <span class="n">Col</span><span class="p">]</span> <span class="o">=</span> <span class="n">Pvalue</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
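</div><p>As in the analysis above, the kernel first declares two shared-memory arrays that hold the data needed in the current phase. In each iteration of the phase loop, the threads first cooperatively fetch the tiles into shared memory and then multiply them. Two barriers are used: the first prevents threads from reading tiles that are not fully loaded, and the second prevents threads from overwriting the tiles with the next phase's data before all threads have finished computing with them.</p>
<p>Lines 16-26 demonstrate a technique called strip-mining: a long loop is split into phases, and each phase contains an inner loop that executes a few consecutive iterations of the original loop.</p>
<p>In each phase a thread loads 2 floats (8 B) from global memory and performs 2 × 16 = 32 floating-point operations, so tiling raises the computational intensity of the kernel from 0.25 FLOP/B to 4 FLOP/B, a 16x improvement. That is still far from the 12.5 FLOP/B the A100 needs. Further optimizations are discussed later.</p>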
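<h2 id="55-boundary-check-边界检查">5.5 Boundary check</h2>
<p>In my view this section does not need to stand on its own. The core point is that both the tile loads and the computation must be guarded by boundary checks so that nothing is accessed out of bounds. This section is skipped.</p>
<h2 id="56-impact-of-memory-usage-on-occupancy-内存使用对使用率的影响">5.6 Impact of memory usage on occupancy</h2>
<p>As mentioned in Chapter 4, excessive use of registers and shared memory limits the number of threads that can be assigned to each SM. On the A100, for example, each SM has 164 KB of shared memory; at the maximum of 2048 threads per SM, each thread may use on average no more than 164 KB / 2048 = 82 B of shared memory. In the tiled kernel above, each thread loads 2 floats, that is 8 B, on average, which is below 82 B, so shared memory usage does not limit occupancy for that kernel.</p>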
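<p>The budget calculation can be sketched in C as follows. The function names are invented for this note, the 164 KB and 2048-thread figures are the A100 values quoted above, and the sketch assumes a 4-byte <code>float</code>.</p>

```c
#include <assert.h>

/* Average shared-memory budget per thread, in bytes, if the SM is to stay
   at its maximum thread count. */
int shared_budget_per_thread(int shared_bytes_per_sm, int max_threads_per_sm) {
    assert(shared_bytes_per_sm > 0 && max_threads_per_sm > 0);
    return shared_bytes_per_sm / max_threads_per_sm;
}

/* Shared memory per thread used by the tiled kernel: two tiles of
   tile_width x tile_width floats, shared by tile_width x tile_width threads. */
int tiled_kernel_bytes_per_thread(int tile_width) {
    assert(tile_width > 0);
    int threads = tile_width * tile_width;
    int bytes = 2 * threads * (int)sizeof(float);
    return bytes / threads;
}

/* Threads per SM once the shared-memory limit is taken into account. */
int threads_limited_by_shared(int shared_bytes_per_sm, int max_threads_per_sm,
                              int bytes_per_thread) {
    assert(bytes_per_thread > 0);
    int by_shared = shared_bytes_per_sm / bytes_per_thread;
    return by_shared < max_threads_per_sm ? by_shared : max_threads_per_sm;
}
```

<p>With 164 × 1024 bytes per SM, the budget is 82 B per thread and the tiled kernel needs 8 B, so all 2048 threads fit. A hypothetical kernel that used 164 B per thread would be capped at 1024 threads per SM, that is 50% occupancy.</p>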
<p>The amount of shared memory can be adapted to the hardware platform. This relies on dynamically allocated shared memory in CUDA, declared with <code>extern __shared__</code>:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="k">extern</span> <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">Mds_Nds</span><span class="p">[];</span>
</span></span></code></pre></td></tr></table>
</div>
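</div><p>There is only one such dynamic array. If several variables share it, the programmer has to manage the boundaries between them.</p>
<p>When launching the kernel, the third execution configuration parameter passes the shared memory size in bytes. The kernel can also take extra parameters that give the length of each variable stored in shared memory.</p>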
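<p>Managing the boundaries inside a single buffer comes down to pointer arithmetic. The plain C sketch below shows the idea on the host; <code>TilePair</code> and <code>partition_tiles</code> are hypothetical names, and in a real kernel the buffer would be the dynamic shared-memory array and the lengths would arrive as kernel parameters.</p>

```c
#include <assert.h>
#include <stddef.h>

/* Two logical arrays carved out of one contiguous buffer. */
typedef struct {
    float *Mds;
    float *Nds;
} TilePair;

/* Split `buffer` (buffer_len floats) into an M tile of mds_len floats
   followed directly by an N tile of nds_len floats. */
TilePair partition_tiles(float *buffer, size_t buffer_len,
                         size_t mds_len, size_t nds_len) {
    assert(buffer != NULL);
    assert(mds_len + nds_len <= buffer_len);
    TilePair tiles;
    tiles.Mds = buffer;            /* first region starts at the beginning */
    tiles.Nds = buffer + mds_len;  /* second region starts right after it  */
    return tiles;
}
```

<p>For two 16×16 tiles the buffer needs 512 floats, so the size passed at launch would be <code>2 * 16 * 16 * sizeof(float)</code> bytes.</p>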
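<h1 id="chapter-06-performance-considerations-性能考量">Chapter 06: Performance considerations</h1>
<p>This chapter introduces several performance optimization techniques: memory coalescing, hiding memory latency, and thread coarsening. It also summarizes the optimizations covered in Part I of the book.</p>
<h2 id="61-memory-coalescing-内存合并访问">6.1 Memory coalescing</h2>
<p>The previous chapter introduced shared memory to relieve the global memory bandwidth bottleneck. This chapter introduces memory coalescing, which moves data between global memory and shared memory more efficiently.</p>
<p>The physical structure of DRAM gives it burst access: when one location is accessed, the consecutive locations around it are read along with it.</p>
<p>To exploit bursts, CUDA automatically coalesces the accesses that the threads of a warp make with the same memory instruction: if threads 0-31 of a warp access 32 consecutive locations in global memory with the same instruction, that instruction is served as a burst access.</p>
<p>For example, in the matrix multiplication below, each thread reads one column of the second matrix. In each iteration of the main loop, a thread reads one element of its column; from the warp's point of view, the threads of a warp access consecutive locations in every iteration.<br>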
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409270939652.png?x-oss-process=image/quality,q_90/format,webp"><br>
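Conversely, if the second matrix is stored in column-major order, the accesses of each individual thread are consecutive across iterations, but the accesses of different threads in a warp are not, so they cannot be coalesced.<br>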
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409270944173.png?x-oss-process=image/quality,q_90/format,webp"></p>
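<p>There are several remedies for this situation: change the mapping between threads and data, change the data layout, or load the data into shared memory with coalesced accesses. The third technique is called corner turning.</p>
<p>Specifically, corner turning loads data that could not be accessed in a coalesced manner during the computation into shared memory using a coalesced access pattern. Shared memory is built from SRAM, so accessing it along rows or along columns is equally fast.</p>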
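<p>One way to see why the row-major case coalesces and the column-major case does not is to count how many memory segments a warp touches in one load. The C sketch below is a simplified host-side model: the 128-byte segment size and the name <code>segments_touched</code> are assumptions made for illustration, not a description of any particular GPU.</p>

```c
#include <assert.h>
#include <stddef.h>

/* Assumed warp size and segment (burst) size for this simplified model. */
enum { WARP_SIZE = 32, SEGMENT_BYTES = 128 };

/* Count the distinct segments touched when lane i of a warp reads the float
   at element index base_elem + i * stride_elems. */
int segments_touched(size_t base_elem, size_t stride_elems) {
    size_t seen[WARP_SIZE];
    int count = 0;
    for (int lane = 0; lane < WARP_SIZE; ++lane) {
        size_t byte_addr = (base_elem + (size_t)lane * stride_elems) * sizeof(float);
        size_t segment = byte_addr / SEGMENT_BYTES;
        int found = 0;
        for (int i = 0; i < count; ++i) {
            if (seen[i] == segment) {
                found = 1;
                break;
            }
        }
        if (!found)
            seen[count++] = segment;
    }
    assert(count >= 1 && count <= WARP_SIZE);
    return count;
}
```

<p>With a stride of 1 element (threads reading neighbouring elements of a row-major row) an aligned warp touches a single segment. With a stride equal to the matrix width, for example 1024, every lane lands in its own segment, so the load needs 32 separate bursts in this model.</p>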
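<h2 id="62-hiding-memory-latency-隐藏内存延迟">6.2 Hiding memory latency</h2>
<p>Bursting alone, as one form of parallel memory access, cannot satisfy the memory bandwidth demands of CPUs and GPUs; multiple channels are needed as well.</p>
<p>A processor has multiple channels, each with its own memory controller, and each channel connects through a separate bus to different memory, which allows reads to proceed in parallel.</p>
<p>The data transfer bandwidth of a bus is determined by its clock frequency and width. A modern DDR bus transfers data on both the rising and the falling clock edge, so a 64-bit bus clocked at 1 GHz has a bandwidth of 8 B × 2 × 1 GHz = 16 GB/s.</p>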
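<p>The bandwidth formula is simple enough to state as code. This is a minimal C sketch with a made-up function name, taking 1 GB as 10^9 bytes.</p>

```c
#include <assert.h>

/* Peak bus bandwidth in GB/s: bytes per transfer times transfers per second.
   transfers_per_cycle is 2 for a DDR bus (rising and falling edge), 1 otherwise. */
double bus_bandwidth_gb_s(int bus_width_bits, double clock_ghz,
                          int transfers_per_cycle) {
    assert(bus_width_bits > 0 && bus_width_bits % 8 == 0);
    assert(clock_ghz > 0.0 && transfers_per_cycle > 0);
    return (bus_width_bits / 8) * clock_ghz * transfers_per_cycle;
}
```

<p><code>bus_bandwidth_gb_s(64, 1.0, 2)</code> reproduces the 16 GB/s figure; the same bus with single data rate would deliver 8 GB/s.</p>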
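<p>Besides the burst transfer itself, a single bank spends a large share of its time (much more than the transfer time) preparing the data. A channel is therefore typically connected to several banks, which are pipelined to make the best use of the bus bandwidth.</p>
<p>The rest of the section appears to cover interleaved memory. Reading computer organization material in English gives me a bit of a headache, so these notes skip it.</p>
<h2 id="63-thread-coarsening-线程粗化">6.3 Thread coarsening</h2>
<p>In the earlier examples, work was mostly assigned to threads at the finest granularity, for example one output element per thread. Parallelization has a cost, though, so more threads are not always better. In addition, because of hardware limits, some threads may have to wait for others to finish before they can be scheduled; in that case the cost of parallelization buys nothing. Thread coarsening addresses this by assigning several units of work to each thread.</p>
<p>As the figure below shows, in matrix multiplication each thread can be made responsible for two adjacent output tiles. The rows of M that both tiles need are then read only once, which reduces memory accesses.<br>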
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202410100910179.png?x-oss-process=image/quality,q_90/format,webp"></p>
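<p>Thread coarsening can improve performance substantially, but it has pitfalls. First, it only helps when parallelization actually carries a cost, such as redundant data loads, redundant computation, or synchronization overhead; otherwise coarsening brings no improvement. Second, it can lower hardware utilization, which relies on a high degree of parallelism that coarsening reduces. Third, it can lower occupancy, because the coarsened kernel may need more registers and shared memory.</p>
<h2 id="64-a-checklist-of-optimizations-优化清单">6.4 A checklist of optimizations</h2>
<p>This chapter concludes Part I of the book. The optimization strategies introduced in Part I are summarized below:</p>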
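<table>
  <thead>
      <tr>
          <th>Optimization</th>
          <th>Benefit to compute cores</th>
          <th>Benefit to memory</th>
          <th>Strategies</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Maximizing occupancy</td>
          <td>More work to hide pipeline latency</td>
          <td>More parallel memory accesses to hide DRAM latency</td>
          <td>Tune the usage of SM resources such as threads per block, shared memory per block, and registers per thread</td>
      </tr>
      <tr>
          <td>Enabling coalesced global memory accesses</td>
          <td>Fewer pipeline stalls waiting for global memory accesses</td>
          <td>Less global memory traffic and better utilization of bursts and cache lines</td>
          <td>Transfer data between global memory and shared memory in a coalesced manner and perform uncoalesced accesses in shared memory (for example, corner turning)</td>
      </tr>
      <tr>
          <td>Minimizing control divergence</td>
          <td>High SIMD efficiency (fewer idle cores during SIMD execution)</td>
          <td>N/A</td>
          <td>Rearrange the mapping of threads to work and/or data</td>
      </tr>
      <tr>
          <td>Tiling of reused data</td>
          <td>Fewer pipeline stalls waiting for global memory accesses</td>
          <td>Less global memory traffic</td>
          <td>Place data that is reused within a block in shared memory or registers so that it is transferred between global memory and the SM only once</td>
      </tr>
      <tr>
          <td>Privatization (covered later)</td>
          <td>Fewer pipeline stalls waiting for atomic updates</td>
          <td>Less contention and serialization of atomic updates</td>
          <td>Apply partial updates to a private copy of the data and then update the universal copy when done</td>
      </tr>
      <tr>
          <td>Thread coarsening</td>
          <td>Less redundant work, divergence, or synchronization</td>
          <td>Less redundant global memory traffic</td>
          <td>Assign multiple units of parallelism to each thread to reduce the price of unnecessary parallelism</td>
      </tr>
  </tbody>
</table>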
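<p>The next two parts of the book apply the techniques in this table to optimize parallel programs, so that they can be understood through practice.</p>
<h2 id="65-knowing-your-computations-bottleneck-了解性能瓶颈">6.5 Knowing your computation&rsquo;s bottleneck</h2>
<p>Optimization must target the actual bottleneck of a program; otherwise it may achieve very little. Unfortunately, the book does not seem to cover bottleneck identification in much depth. The author points to NVIDIA's official profiler guide: <a href="https://docs.nvidia.com/cuda/profiler-users-guide/">Profiler</a>.</p>
<h1 id="参考">References</h1>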
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p><a href="https://gfjiangly.github.io/cpu_parallel/memory_consistency_model.html#RMO%EF%BC%9A%E5%AE%BD%E6%9D%BE%E5%86%85%E5%AD%98%E6%A8%A1%E5%9E%8B">内存一致性模型 | jiang</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>Deriving the gradient of LogSumExp</title>
      <link>https://www.zhouxin.space/notes/gradient-of-log-sum-exp/</link>
      <pubDate>Sat, 20 Jul 2024 11:08:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/gradient-of-log-sum-exp/</guid>
      <description>&lt;h1 id=&#34;前言&#34;&gt;Preface&lt;/h1&gt;
&lt;p&gt;The second homework of CMU 10-414/714 Deep Learning Systems includes a small task: derive the gradient of the numerically stable form of LogSumExp. I went through a fair amount of material &lt;sup id=&#34;fnref:1&#34;&gt;&lt;a href=&#34;#fn:1&#34; class=&#34;footnote-ref&#34; role=&#34;doc-noteref&#34;&gt;1&lt;/a&gt;&lt;/sup&gt; and it took me a while to work it out, so I am writing it down here.&lt;/p&gt;
&lt;h1 id=&#34;推导过程&#34;&gt;Derivation&lt;/h1&gt;
&lt;h2 id=&#34;符号说明&#34;&gt;Notation&lt;/h2&gt;
&lt;p&gt;The derivation uses the following notation:&lt;/p&gt;


&lt;div&gt;$$

\begin{align*}
 z &amp;amp;\in \mathbb{R}^n\\
 z_k &amp;amp;= \max{z}\\
 \hat{z} &amp;amp;= z - \max{z}\\
 f &amp;amp;= \log{\sum_{i=1}^n{\exp{(z_i - \max{z})}}&amp;#43;\max{z}}\\
 &amp;amp;=\log{\sum_{i=1}^n\exp\hat{z}_i}&amp;#43;z_k
\end{align*}

$$&lt;/div&gt;

&lt;h2 id=&#34;非最大情况推导&#34;&gt;Case 1: $z_j$ is not the maximum&lt;/h2&gt;
&lt;p&gt;When $z_j\neq z_k$, $\frac{\partial{f}}{\partial{z_j}}$ is derived as follows:&lt;/p&gt;</description>
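      <content:encoded><![CDATA[<h1 id="前言">Preface</h1>
<p>The second homework of CMU 10-414/714 Deep Learning Systems includes a small task: derive the gradient of the numerically stable form of LogSumExp. I went through a fair amount of material <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup> and it took me a while to work it out, so I am writing it down here.</p>
<h1 id="推导过程">Derivation</h1>
<h2 id="符号说明">Notation</h2>
<p>The derivation uses the following notation:</p>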


<div>$$

\begin{align*}
 z &amp;\in \mathbb{R}^n\\
 z_k &amp;= \max{z}\\
 \hat{z} &amp;= z - \max{z}\\
 f &amp;= \log\left(\sum_{i=1}^n\exp(z_i - \max{z})\right)&#43;\max{z}\\
 &amp;=\log\left(\sum_{i=1}^n\exp\hat{z}_i\right)&#43;z_k
\end{align*}

$$</div>
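<p>To make the reason for the shift by $\max z$ concrete, here is a minimal sketch in plain Python. The function names <code>logsumexp_naive</code> and <code>logsumexp_stable</code> are illustrative and do not come from the course code. It compares the direct formula with the shifted form $f$ defined above.</p>

```python
import math


def logsumexp_naive(z):
    # Direct evaluation; math.exp overflows for large inputs.
    return math.log(sum(math.exp(v) for v in z))


def logsumexp_stable(z):
    # Subtract the maximum first, so the largest term is exp(0) = 1.
    z_max = max(z)
    return math.log(sum(math.exp(v - z_max) for v in z)) + z_max


small = [1.0, 2.0, 3.0]
large = [1000.0, 1001.0, 1002.0]

# Both forms agree when nothing overflows.
small_gap = abs(logsumexp_naive(small) - logsumexp_stable(small))

# The naive form fails on large inputs, while the shifted form still works.
try:
    logsumexp_naive(large)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

stable_large = logsumexp_stable(large)
```

<p>Since <code>large</code> is <code>small</code> shifted by 999, the stable result for <code>large</code> should equal the result for <code>small</code> plus 999, which is a convenient sanity check.</p>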

<h2 id="非最大情况推导">Case 1: $z_j$ is not the maximum</h2>
<p>When $z_j\neq z_k$, $\frac{\partial{f}}{\partial{z_j}}$ is derived as follows:</p>


<div>$$

\begin{align*}
\frac{\partial{f}}{\partial{z_j}} 
&amp;=\frac{\partial{(\log{\sum_{i=1}^n\exp\hat{z}_i)}}}{\partial z_j} &#43; \frac{\partial z_k}{\partial{z_j}} \\
&amp;= \frac{\partial \log\left(\sum_{i=1}^n\exp\hat{z}_i\right)}{\partial \sum_{i=1}^n\exp\hat{z}_i}\cdot \frac{\partial \sum_{i=1}^n\exp\hat{z}_i}{\partial{z_j}}&#43;0 \\
&amp;=\frac{1}{\sum_{i=1}^n\exp\hat{z}_i}\cdot(\sum_{i\neq j} \frac{\partial\exp{\hat z_i}}{\partial z_j}&#43;\frac{\partial \exp{\hat z_j}}{\partial z_j}) \\ 
&amp;=\frac{1}{\sum_{i=1}^n\exp\hat{z}_i}\cdot(0&#43;\exp{\hat{z}_j}) \\ 
&amp;=\frac{\exp{\hat{z}_j}}{\sum_{i=1}^n\exp\hat{z}_i}
\end{align*}

$$</div>

<h2 id="最大情况推导">Case 2: $z_j$ is the maximum</h2>
<p>When $z_j= z_k$, $\frac{\partial{f}}{\partial{z_j}}$ is derived as follows:</p>


<div>$$

\begin{align*}
\frac{\partial{f}}{\partial{z_j}} 
&amp;=\frac{\partial{(\log{\sum_{i=1}^n\exp\hat{z}_i)}}}{\partial z_j} &#43; \frac{\partial z_k}{\partial{z_j}} \\
&amp;= \frac{\partial \log\left(\sum_{i=1}^n\exp\hat{z}_i\right)}{\partial \sum_{i=1}^n\exp\hat{z}_i}\cdot \frac{\partial \sum_{i=1}^n\exp\hat{z}_i}{\partial{z_j}}&#43;1 \\
&amp;=\frac{1}{\sum_{i=1}^n\exp\hat{z}_i}\cdot [\sum_{z_i \neq z_k}{\frac{\partial \exp{(z_i-z_k)}}{\partial z_j}}&#43;\sum_{z_i=z_k}{\frac{\partial \exp{(z_i-z_k)}}{\partial z_j}}]&#43;1\\
&amp;\text{Note that } z_j=z_k \text{ in the expression above, and the maximum is assumed to be unique}\\
&amp;=\frac{1}{\sum_{i=1}^n\exp\hat{z}_i}\cdot[\sum_{z_i \neq z_k}{-\exp(z_i-z_k)}&#43;0]&#43;1 \\
&amp;= 1-\frac{\sum_{z_i \neq z_k}{\exp(z_i-z_k)}}{\sum_{i=1}^n\exp\hat{z}_i} \\
&amp;=\frac{\sum_{i=1}^n\exp\hat{z}_i-\sum_{z_i \neq z_k}{\exp(z_i-z_k)}}{\sum_{i=1}^n\exp\hat{z}_i} \\
&amp;=\frac{\exp{\hat{z}_j}}{\sum_{i=1}^n\exp\hat{z}_i}
\end{align*}

$$</div>

<h2 id="一般情况">General case</h2>
<p>Whether or not $z_j$ is the maximum, the result is the same:</p>


<div>$$

\frac{\partial{f}}{\partial{z_j}}=\frac{\exp{\hat{z}_j}}{\sum_{i=1}^n\exp\hat{z}_i}

$$</div>

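<p>The discussion so far covers $f\in \mathbb{R}$ and $z\in\mathbb{R}^n$. In practice both $f$ and $z$ are higher-dimensional tensors, and what we need is the gradient of $f$ with respect to $z$, i.e. $\nabla_z f$.</p>
<h1 id="代码实现">Implementation</h1>
<p>Thanks first to <a href="https://github.com/yofufufufu">yofufufufu</a> for the generous help; the implementation mainly follows his explanation <sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup>. Let us simplify the formula further:</p>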


<div>$$

\begin{align*}
\frac{\partial{f}}{\partial{z_j}}
&amp;=\frac{\exp{\hat{z}_j}}{\sum_{i=1}^n\exp\hat{z}_i}\\
&amp;=\exp\left(\hat{z}_j - \log \sum_{i=1}^n\exp\hat{z}_i\right)\\
&amp;=\exp\left(z_j - z_k - \log \sum_{i=1}^n\exp\hat{z}_i\right)\\
&amp;=\exp(z_j - f)
\end{align*}

$$</div>
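<p>The identity $\frac{\partial f}{\partial z_j}=\exp(z_j-f)$ can be checked numerically. The sketch below uses plain Python and central finite differences; the helper names are illustrative and are not part of the course framework. The test vector deliberately contains a tied maximum, since $f$ itself is smooth there.</p>

```python
import math


def logsumexp(z):
    # Numerically stable form: shift by the maximum before exponentiating.
    z_max = max(z)
    return math.log(sum(math.exp(v - z_max) for v in z)) + z_max


def logsumexp_grad(z):
    # Closed form from the derivation: df/dz_j = exp(z_j - f).
    f = logsumexp(z)
    return [math.exp(v - f) for v in z]


def numerical_grad(func, z, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    grad = []
    for j in range(len(z)):
        upper = list(z)
        lower = list(z)
        upper[j] += eps
        lower[j] -= eps
        grad.append((func(upper) - func(lower)) / (2 * eps))
    return grad


z = [0.5, -1.2, 3.0, 3.0, 0.0]
analytic = logsumexp_grad(z)
numeric = numerical_grad(logsumexp, z)
max_error = max(abs(a - n) for a, n in zip(analytic, numeric))
```

<p>The entries of <code>analytic</code> also sum to 1, which reflects that this gradient is the softmax of $z$.</p>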

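<p>Pleasantly, the gradient of LogSumExp can be written in terms of its input and output alone, so the implementation only needs the node's input and output to compute it. In the CMU 10-414 course framework, the node is implemented as follows:</p>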
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">LogSumExp</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">Z</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">max_z</span> <span class="o">=</span> <span class="n">array_api</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">array_api</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">array_api</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">Z</span> <span class="o">-</span> <span class="n">max_z</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">))</span> <span class="o">+</span> <span class="n">array_api</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)))</span>
</span></span><span class="line"><span class="cl">        <span class="n">z</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">shape</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="k">else</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">))]</span>
</span></span><span class="line"><span class="cl">        <span class="n">gradient</span> <span class="o">=</span> <span class="n">exp</span><span class="p">(</span><span class="n">z</span> <span class="o">-</span> <span class="n">node</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">*</span><span class="n">gradient</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="参考资料">References</h1>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p><a href="https://blog.csdn.net/u010043946/article/details/134408424">logsumexp 反向传播推导_logsumexp (lse)-CSDN博客</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p><a href="https://github.com/kcxain/dlsys/issues/4#issuecomment-2242385479">hw2 LogSumExp梯度公式推导 · Issue #4 · kcxain/dlsys · GitHub</a>&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>Connecting to WSL2 remotely over SSH</title>
      <link>https://www.zhouxin.space/notes/using-ssh-to-connect-remotely-to-wsl2/</link>
      <pubDate>Wed, 17 Jul 2024 17:26:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/using-ssh-to-connect-remotely-to-wsl2/</guid>
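      <description>&lt;h1 id=&#34;概述&#34;&gt;Overview&lt;/h1&gt;
&lt;p&gt;WSL2 gives Windows users convenient access to a Linux environment, and Microsoft provides a matching extension for VS Code. However, WSL2 is normally accessed locally, and Microsoft does not seem to offer a direct way to reach it remotely.&lt;/p&gt;
&lt;p&gt;After some experimentation, remote access to WSL2 comes down to the following steps:&lt;/p&gt;</description>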
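      <content:encoded><![CDATA[<h1 id="概述">Overview</h1>
<p>WSL2 gives Windows users convenient access to a Linux environment, and Microsoft provides a matching extension for VS Code. However, WSL2 is normally accessed locally, and Microsoft does not seem to offer a direct way to reach it remotely.</p>
<p>After some experimentation, remote access to WSL2 comes down to the following steps:</p>
<ul>
<li>[Optional] Enable the ssh server on Windows</li>
<li>Enable and configure the ssh service in WSL2</li>
<li>Open the firewall</li>
<li>Change the WSL2 networking mode</li>
</ul>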
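<h1 id="详细步骤">Detailed steps</h1>
<h2 id="非必需启用-windows-中的-ssh-服务器">[Optional] Enable the ssh server on Windows</h2>
<p>While experimenting I found that Windows itself accepts remote ssh connections, which suggests a roundabout route to WSL2: ssh into the Windows host first, then enter WSL2 from that terminal. It works in theory; in practice the steps are:</p>
<ul>
<li>Enable the ssh server<br>
The official documentation <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup> covers enabling the ssh server on Windows in detail. Taking Windows 11 as an example, run the following commands as administrator in PowerShell (use the <strong>version that ships with the system</strong>; PowerShell 7.4.3 failed to run them correctly) to enable the ssh server:</li>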
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-powershell" data-lang="powershell"><span class="line"><span class="cl"><span class="c"># Install the OpenSSH client</span>
</span></span><span class="line"><span class="cl"><span class="nb">Add-WindowsCapability</span> <span class="n">-Online</span> <span class="n">-Name</span> <span class="n">OpenSSH</span><span class="p">.</span><span class="n">Client</span><span class="p">~~~~</span><span class="mf">0.0</span><span class="p">.</span><span class="py">1</span><span class="p">.</span><span class="py">0</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c"># Install the OpenSSH server</span>
</span></span><span class="line"><span class="cl"><span class="nb">Add-WindowsCapability</span> <span class="n">-Online</span> <span class="n">-Name</span> <span class="n">OpenSSH</span><span class="p">.</span><span class="n">Server</span><span class="p">~~~~</span><span class="mf">0.0</span><span class="p">.</span><span class="py">1</span><span class="p">.</span><span class="py">0</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c"># Start the sshd service and set it to start automatically</span>
</span></span><span class="line"><span class="cl"><span class="nb">Start-Service</span> <span class="n">sshd</span>
</span></span><span class="line"><span class="cl"><span class="nb">Set-Service</span> <span class="n">-Name</span> <span class="n">sshd</span> <span class="n">-StartupType</span> <span class="s1">&#39;Automatic&#39;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c"># Confirm the firewall rule was configured automatically; create it if it is missing</span>
</span></span><span class="line"><span class="cl"><span class="k">if</span> <span class="p">(!(</span><span class="nb">Get-NetFirewallRule</span> <span class="n">-Name</span> <span class="s2">&#34;OpenSSH-Server-In-TCP&#34;</span> <span class="n">-ErrorAction</span> <span class="n">SilentlyContinue</span> <span class="p">|</span> <span class="nb">Select-Object</span> <span class="n">Name</span><span class="p">,</span> <span class="n">Enabled</span><span class="p">))</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="nb">Write-Output</span> <span class="s2">&#34;Firewall Rule &#39;OpenSSH-Server-In-TCP&#39; does not exist, creating it...&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="nb">New-NetFirewallRule</span> <span class="n">-Name</span> <span class="s1">&#39;OpenSSH-Server-In-TCP&#39;</span> <span class="n">-DisplayName</span> <span class="s1">&#39;OpenSSH Server (sshd)&#39;</span> <span class="n">-Enabled</span> <span class="n">True</span> <span class="n">-Direction</span> <span class="n">Inbound</span> <span class="n">-Protocol</span> <span class="n">TCP</span> <span class="n">-Action</span> <span class="n">Allow</span> <span class="n">-LocalPort</span> <span class="mf">22</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="nb">Write-Output</span> <span class="s2">&#34;Firewall rule &#39;OpenSSH-Server-In-TCP&#39; has been created and exists.&#34;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
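</div><p>After the commands finish, run <code>ssh &lt;username&gt;@127.0.0.1</code> to check that a Windows shell can be reached over ssh. Note that Windows OpenSSH offers only the <code>password</code> and <code>publickey</code> authentication methods, i.e. signing in through Microsoft account verification is not supported.</p>
<ul>
<li>Change the default shell to PowerShell<br>
On Windows the default shell for ssh sessions is cmd, which can be confirmed with <code>echo %COMSPEC%</code>. The default shell is determined by the registry value <code>HKEY_LOCAL_MACHINE\SOFTWARE\OpenSSH\DefaultShell</code>, and the following commands change it to PowerShell:</li>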
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-powershell" data-lang="powershell"><span class="line"><span class="cl"><span class="nv">$pwshPath</span> <span class="p">=</span> <span class="p">(</span><span class="nb">Get-Command</span> <span class="n">powershell</span><span class="p">.</span><span class="n">exe</span><span class="p">).</span><span class="py">Source</span>
</span></span><span class="line"><span class="cl"><span class="nv">$pwshPathQuoted</span> <span class="p">=</span> <span class="s1">&#39;&#34;&#39;</span> <span class="p">+</span> <span class="nv">$pwshPath</span> <span class="p">+</span> <span class="s1">&#39;&#34;&#39;</span>
</span></span><span class="line"><span class="cl"><span class="n">sudo</span> <span class="nb">Set-ItemProperty</span> <span class="n">-Verbose</span> <span class="n">-Path</span> <span class="s2">&#34;HKLM:\SOFTWARE\OpenSSH&#34;</span> <span class="n">-Name</span> <span class="s2">&#34;DefaultShell&#34;</span> <span class="n">-Value</span> <span class="nv">$pwshPathQuoted</span> <span class="n">-Force</span>
</span></span></code></pre></td></tr></table>
</div>
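</div><p>Note: setting the default shell to PowerShell 7 (pwsh.exe) produced an insufficient-permission error for me, which I attribute to pwsh.exe being installed under <code>C:\Program Files</code> by default, a path that requires administrator privileges.</p>
<ul>
<li>ssh security settings<br>
The default configuration file of the Windows ssh server is <code>%programdata%\ssh\sshd_config</code>; see the official documentation <sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup><sup id="fnref:3"><a href="#fn:3" class="footnote-ref" role="doc-noteref">3</a></sup> for the meaning of each field. I recommend changing the default port and restricting who may connect, either with <code>AllowUsers</code> (a list of users) or with <code>AllowGroups</code> (a list of user groups). Append the following to the configuration file:</li>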
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-text" data-lang="text"><span class="line"><span class="cl">Port xxxxx # change the default port
</span></span><span class="line"><span class="cl">AllowGroups &#34;sshUsers&#34; # allow only the specified group
</span></span></code></pre></td></tr></table>
</div>
</div><p>The configuration above allows connections only from members of the sshUsers group, so next we create that group and add the relevant users to it:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-powershell" data-lang="powershell"><span class="line"><span class="cl"><span class="nb">Restart-Service</span> <span class="n">sshd</span> <span class="c"># configuration changes take effect only after the service restarts</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">net</span> <span class="n">localgroup</span> <span class="n">sshUsers</span> <span class="p">/</span><span class="n">add</span> <span class="c"># create the sshUsers group</span>
</span></span><span class="line"><span class="cl"><span class="n">net</span> <span class="n">localgroup</span> <span class="n">sshUsers</span> <span class="p">&lt;</span><span class="n">username</span><span class="p">&gt;</span> <span class="p">/</span><span class="n">add</span> <span class="c"># add the user to that group</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>In addition, open the new ssh port in the firewall:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-powershell" data-lang="powershell"><span class="line"><span class="cl"><span class="nb">New-NetFirewallRule</span> <span class="n">-DisplayName</span> <span class="s1">&#39;&#34;Allow SSH on Port xxxxx&#34;&#39;</span> <span class="n">-Direction</span> <span class="n">Inbound</span> <span class="n">-Protocol</span> <span class="n">TCP</span> <span class="n">-LocalPort</span> <span class="n">xxxxx</span> <span class="n">-Action</span> <span class="n">Allow</span>
</span></span></code></pre></td></tr></table>
</div>
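</div><h2 id="启用并配置-wsl2-中的-ssh-服务">Enable and configure the ssh service in WSL2</h2>
<ul>
<li>Install or reinstall the OpenSSH server<br>
Whether or not an OpenSSH server is already installed in WSL2, I recommend removing it and installing it again with the following commands:</li>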
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-bash" data-lang="bash"><span class="line"><span class="cl"><span class="c1"># Remove the preinstalled sshd, then install it again</span>
</span></span><span class="line"><span class="cl">sudo apt-get remove openssh-server
</span></span><span class="line"><span class="cl">sudo apt-get install openssh-server
</span></span></code></pre></td></tr></table>
</div>
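</div><ul>
<li>ssh security settings<br>
The default configuration file of the ssh server in WSL2 is <code>/etc/ssh/sshd_config</code>; see the official documentation <sup id="fnref1:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup> for the meaning of each field. I recommend changing the default port and logging in with key authentication, i.e. changing the following settings in the file:</li>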
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-fallback" data-lang="fallback"><span class="line"><span class="cl">Port xxxxx # change the default port
</span></span><span class="line"><span class="cl">PasswordAuthentication no # disable password authentication
</span></span><span class="line"><span class="cl">PubkeyAuthentication yes # allow public key authentication
</span></span><span class="line"><span class="cl">AuthenticationMethods publickey # accept public key authentication only
</span></span></code></pre></td></tr></table>
</div>
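</div><p>Then write the public keys of the host and of any other device that needs to connect to WSL2 into <code>~/.ssh/authorized_keys</code>.</p>
<p>After editing <code>sshd_config</code>, restart the service for the change to take effect, for example with <code>sudo service ssh restart</code> (on Debian/Ubuntu the service is named <code>ssh</code>; some other distributions name it <code>sshd</code>). Once the public key is in place, run <code>ssh &lt;username&gt;@127.0.0.1 -p xxxxx</code> on the Windows host to check that WSL2 can be reached.</p>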
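<p>If the ssh test fails, it helps to first rule out a closed port. The following is a minimal sketch using only the Python standard library; the function name <code>port_is_open</code> is illustrative. It only checks that something accepts TCP connections on the port, not that it is sshd or that authentication works.</p>

```python
import socket


def port_is_open(host, port, timeout=2.0):
    """Return True if a TCP connection to (host, port) succeeds."""
    try:
        # create_connection raises OSError on refusal, timeout, or no route.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False
```

<p>For example, <code>port_is_open("127.0.0.1", 2222)</code> would probe the local machine, where 2222 stands in for whatever port was set in <code>sshd_config</code>.</p>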
<h2 id="开放防火墙">Open the firewall</h2>
<p>After changing the port, open it in the host firewall by running the following command as administrator in PowerShell on the host:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-powershell" data-lang="powershell"><span class="line"><span class="cl"><span class="nb">New-NetFirewallRule</span> <span class="n">-DisplayName</span> <span class="s1">&#39;&#34;Allow SSH on Port xxxxx&#34;&#39;</span> <span class="n">-Direction</span> <span class="n">Inbound</span> <span class="n">-Protocol</span> <span class="n">TCP</span> <span class="n">-LocalPort</span> <span class="n">xxxxx</span> <span class="n">-Action</span> <span class="n">Allow</span>
</span></span></code></pre></td></tr></table>
</div>
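</div><h2 id="修改-wsl2-网络模式">Change the WSL2 networking mode</h2>
<p>The default networking mode of WSL2 is NAT<sup id="fnref:4"><a href="#fn:4" class="footnote-ref" role="doc-noteref">4</a></sup>. In this mode:</p>
<ul>
<li>Windows can reach network applications in WSL2 through localhost</li>
<li>WSL2 has to look up the host IP to reach applications on Windows</li>
<li>LAN devices need port forwarding on the host to reach applications in WSL2</li>
</ul>
<p>On hosts running Windows 11 22H2 or later, WSL2 supports mirrored networking mode. In this mode the Windows host can still reach WSL2 network applications through localhost, and LAN devices can reach them directly through the host IP.</p>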
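<p>The WSL2 configuration file is <code>%UserProfile%/.wslconfig</code>. These keys were introduced under <code>[experimental]</code>, while more recent WSL documentation lists them under <code>[wsl2]</code>, so check the documentation for your WSL version if the settings do not take effect. Set the file contents to:</p>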
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-fallback" data-lang="fallback"><span class="line"><span class="cl">[experimental]
</span></span><span class="line"><span class="cl">networkingMode=mirrored
</span></span><span class="line"><span class="cl">dnsTunneling=true
</span></span><span class="line"><span class="cl">firewall=true
</span></span><span class="line"><span class="cl">autoProxy=true
</span></span></code></pre></td></tr></table>
</div>
</div><p>The configuration above also enables automatic proxy, the firewall, and DNS tunneling. After saving the file, restart WSL to apply it:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-powershell" data-lang="powershell"><span class="line"><span class="cl"><span class="n">wsl</span> <span class="p">-</span><span class="n">-shutdown</span>
</span></span><span class="line"><span class="cl"><span class="n">wsl</span>
</span></span></code></pre></td></tr></table>
</div>
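</div><p>At this point, WSL2 on the host can be reached over ssh from within the LAN. For access from outside the LAN, ZeroTier can be used to build a virtual network across sites; see <a href="https://www.zhouxin.space/notes/setup-zerotier-moon-server/">搭建ZeroTier MOON服务器 | 周鑫的个人博客</a>.</p>
<h1 id="参考文档">References</h1>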
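<p>If mirrored mode does not seem to be active, one quick check is whether <code>.wslconfig</code> actually contains the expected key. The sketch below parses the file contents with Python's standard <code>configparser</code>. The function name <code>networking_mode</code> is illustrative, and looking in both <code>[wsl2]</code> and <code>[experimental]</code> is an assumption made here to cover either section placement, not documented WSL behavior.</p>

```python
import configparser


def networking_mode(config_text):
    """Return the networkingMode value found in .wslconfig text, or None."""
    parser = configparser.ConfigParser(interpolation=None)
    parser.optionxform = str  # keep key case, e.g. networkingMode
    parser.read_string(config_text)
    # Assumption: accept the key in either section.
    for section in ("wsl2", "experimental"):
        if parser.has_option(section, "networkingMode"):
            return parser.get(section, "networkingMode")
    return None


sample = """
[experimental]
networkingMode=mirrored
dnsTunneling=true
firewall=true
autoProxy=true
"""
mode = networking_mode(sample)
```

<p>To inspect a real file, read its contents from the path under the user profile and pass the text to <code>networking_mode</code>.</p>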
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p><a href="https://learn.microsoft.com/zh-cn/windows-server/administration/openssh/openssh_install_firstuse?tabs=powershell">适用于 Windows 的 OpenSSH 入门 | Microsoft Learn</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p><a href="https://linux.die.net/man/5/sshd_config">sshd_config(5): OpenSSH SSH daemon config file - Linux man page</a>&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a>&#160;<a href="#fnref1:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:3">
<p><a href="https://learn.microsoft.com/zh-cn/windows-server/administration/openssh/openssh_server_configuration#windows-configurations-in-sshd_config">适用于 Windows 的 OpenSSH 服务器配置 | Microsoft Learn</a>&#160;<a href="#fnref:3" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:4">
<p><a href="https://learn.microsoft.com/zh-cn/windows/wsl/networking#default-networking-mode-nat">使用 WSL 访问网络应用程序 | Microsoft Learn</a>&#160;<a href="#fnref:4" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>Adding a side table of contents and a reading progress indicator to PaperMod</title>
      <link>https://www.zhouxin.space/logs/introduce-side-toc-and-reading-percentage-to-papermod/</link>
      <pubDate>Mon, 08 Jul 2024 20:04:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/logs/introduce-side-toc-and-reading-percentage-to-papermod/</guid>
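      <description>&lt;h1 id=&#34;概述&#34;&gt;Overview&lt;/h1&gt;
&lt;p&gt;By default PaperMod shows the table of contents before the article body, so it cannot help with locating or jumping to a specific section while reading. A side table of contents solves this. A reading-progress percentage also helps readers keep track of where they are, and makes the page feel a little more lively.&lt;/p&gt;</description>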
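      <content:encoded><![CDATA[<h1 id="概述">Overview</h1>
<p>By default PaperMod shows the table of contents before the article body, so it cannot help with locating or jumping to a specific section while reading. A side table of contents solves this. A reading-progress percentage also helps readers keep track of where they are, and makes the page feel a little more lively.</p>
<p>The implementation borrows mainly from <a href="https://www.sulvblog.cn/">Sulv&rsquo;s Blog</a>. The side-TOC method described in its post <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup> does not handle long tables of contents well, because the TOC does not scroll automatically to the section being read; this post improves on that. The reading-percentage display comes from the <a href="https://github.com/xyming108/sulv-hugo-papermod">source code</a> of that blog.</p>
<p>The result looks like this:</p>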
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202407082117402.png?x-oss-process=image/quality,q_90/format,webp"></p>
<h1 id="步骤">Steps</h1>
<h2 id="侧边目录">Side table of contents</h2>
<p>In PaperMod, the HTML for the table of contents is defined in <code>layouts/partials/toc.html</code>. To change it, create <code>&lt;your_hugo_site&gt;/layouts/partials/toc.html</code> to override the theme file, and paste in the following code:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">  1
</span><span class="lnt">  2
</span><span class="lnt">  3
</span><span class="lnt">  4
</span><span class="lnt">  5
</span><span class="lnt">  6
</span><span class="lnt">  7
</span><span class="lnt">  8
</span><span class="lnt">  9
</span><span class="lnt"> 10
</span><span class="lnt"> 11
</span><span class="lnt"> 12
</span><span class="lnt"> 13
</span><span class="lnt"> 14
</span><span class="lnt"> 15
</span><span class="lnt"> 16
</span><span class="lnt"> 17
</span><span class="lnt"> 18
</span><span class="lnt"> 19
</span><span class="lnt"> 20
</span><span class="lnt"> 21
</span><span class="lnt"> 22
</span><span class="lnt"> 23
</span><span class="lnt"> 24
</span><span class="lnt"> 25
</span><span class="lnt"> 26
</span><span class="lnt"> 27
</span><span class="lnt"> 28
</span><span class="lnt"> 29
</span><span class="lnt"> 30
</span><span class="lnt"> 31
</span><span class="lnt"> 32
</span><span class="lnt"> 33
</span><span class="lnt"> 34
</span><span class="lnt"> 35
</span><span class="lnt"> 36
</span><span class="lnt"> 37
</span><span class="lnt"> 38
</span><span class="lnt"> 39
</span><span class="lnt"> 40
</span><span class="lnt"> 41
</span><span class="lnt"> 42
</span><span class="lnt"> 43
</span><span class="lnt"> 44
</span><span class="lnt"> 45
</span><span class="lnt"> 46
</span><span class="lnt"> 47
</span><span class="lnt"> 48
</span><span class="lnt"> 49
</span><span class="lnt"> 50
</span><span class="lnt"> 51
</span><span class="lnt"> 52
</span><span class="lnt"> 53
</span><span class="lnt"> 54
</span><span class="lnt"> 55
</span><span class="lnt"> 56
</span><span class="lnt"> 57
</span><span class="lnt"> 58
</span><span class="lnt"> 59
</span><span class="lnt"> 60
</span><span class="lnt"> 61
</span><span class="lnt"> 62
</span><span class="lnt"> 63
</span><span class="lnt"> 64
</span><span class="lnt"> 65
</span><span class="lnt"> 66
</span><span class="lnt"> 67
</span><span class="lnt"> 68
</span><span class="lnt"> 69
</span><span class="lnt"> 70
</span><span class="lnt"> 71
</span><span class="lnt"> 72
</span><span class="lnt"> 73
</span><span class="lnt"> 74
</span><span class="lnt"> 75
</span><span class="lnt"> 76
</span><span class="lnt"> 77
</span><span class="lnt"> 78
</span><span class="lnt"> 79
</span><span class="lnt"> 80
</span><span class="lnt"> 81
</span><span class="lnt"> 82
</span><span class="lnt"> 83
</span><span class="lnt"> 84
</span><span class="lnt"> 85
</span><span class="lnt"> 86
</span><span class="lnt"> 87
</span><span class="lnt"> 88
</span><span class="lnt"> 89
</span><span class="lnt"> 90
</span><span class="lnt"> 91
</span><span class="lnt"> 92
</span><span class="lnt"> 93
</span><span class="lnt"> 94
</span><span class="lnt"> 95
</span><span class="lnt"> 96
</span><span class="lnt"> 97
</span><span class="lnt"> 98
</span><span class="lnt"> 99
</span><span class="lnt">100
</span><span class="lnt">101
</span><span class="lnt">102
</span><span class="lnt">103
</span><span class="lnt">104
</span><span class="lnt">105
</span><span class="lnt">106
</span><span class="lnt">107
</span><span class="lnt">108
</span><span class="lnt">109
</span><span class="lnt">110
</span><span class="lnt">111
</span><span class="lnt">112
</span><span class="lnt">113
</span><span class="lnt">114
</span><span class="lnt">115
</span><span class="lnt">116
</span><span class="lnt">117
</span><span class="lnt">118
</span><span class="lnt">119
</span><span class="lnt">120
</span><span class="lnt">121
</span><span class="lnt">122
</span><span class="lnt">123
</span><span class="lnt">124
</span><span class="lnt">125
</span><span class="lnt">126
</span><span class="lnt">127
</span><span class="lnt">128
</span><span class="lnt">129
</span><span class="lnt">130
</span><span class="lnt">131
</span><span class="lnt">132
</span><span class="lnt">133
</span><span class="lnt">134
</span><span class="lnt">135
</span><span class="lnt">136
</span><span class="lnt">137
</span><span class="lnt">138
</span><span class="lnt">139
</span><span class="lnt">140
</span><span class="lnt">141
</span><span class="lnt">142
</span><span class="lnt">143
</span><span class="lnt">144
</span><span class="lnt">145
</span><span class="lnt">146
</span><span class="lnt">147
</span><span class="lnt">148
</span><span class="lnt">149
</span><span class="lnt">150
</span><span class="lnt">151
</span><span class="lnt">152
</span><span class="lnt">153
</span><span class="lnt">154
</span><span class="lnt">155
</span><span class="lnt">156
</span><span class="lnt">157
</span><span class="lnt">158
</span><span class="lnt">159
</span><span class="lnt">160
</span><span class="lnt">161
</span><span class="lnt">162
</span><span class="lnt">163
</span><span class="lnt">164
</span><span class="lnt">165
</span><span class="lnt">166
</span><span class="lnt">167
</span><span class="lnt">168
</span><span class="lnt">169
</span><span class="lnt">170
</span><span class="lnt">171
</span><span class="lnt">172
</span><span class="lnt">173
</span><span class="lnt">174
</span><span class="lnt">175
</span><span class="lnt">176
</span><span class="lnt">177
</span><span class="lnt">178
</span><span class="lnt">179
</span><span class="lnt">180
</span><span class="lnt">181
</span><span class="lnt">182
</span><span class="lnt">183
</span><span class="lnt">184
</span><span class="lnt">185
</span><span class="lnt">186
</span><span class="lnt">187
</span><span class="lnt">188
</span><span class="lnt">189
</span><span class="lnt">190
</span><span class="lnt">191
</span><span class="lnt">192
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-html" data-lang="html"><span class="line"><span class="cl">{{- $headers := findRE &#34;<span class="p">&lt;</span><span class="nt">h</span><span class="err">[</span><span class="na">1-6</span><span class="err">].*?</span><span class="p">&gt;</span>(.|\n])+?<span class="err">&lt;</span>/h[1-6]&gt;&#34; .Content -}}
</span></span><span class="line"><span class="cl">{{- $has_headers := ge (len $headers) 1 -}}
</span></span><span class="line"><span class="cl">{{- if $has_headers -}}
</span></span><span class="line"><span class="cl"><span class="p">&lt;</span><span class="nt">aside</span> <span class="na">id</span><span class="o">=</span><span class="s">&#34;toc-container&#34;</span> <span class="na">class</span><span class="o">=</span><span class="s">&#34;toc-container wide&#34;</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">    <span class="p">&lt;</span><span class="nt">div</span> <span class="na">class</span><span class="o">=</span><span class="s">&#34;toc&#34;</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">        <span class="p">&lt;</span><span class="nt">details</span> <span class="err">{{</span><span class="na">if</span> <span class="err">(.</span><span class="na">Param</span> <span class="err">&#34;</span><span class="na">TocOpen</span><span class="err">&#34;)</span> <span class="err">}}</span> <span class="na">open</span><span class="err">{{</span> <span class="na">end</span> <span class="err">}}</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">            <span class="p">&lt;</span><span class="nt">summary</span> <span class="na">accesskey</span><span class="o">=</span><span class="s">&#34;c&#34;</span> <span class="na">title</span><span class="o">=</span><span class="s">&#34;(Alt + C)&#34;</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                <span class="p">&lt;</span><span class="nt">span</span> <span class="na">class</span><span class="o">=</span><span class="s">&#34;details&#34;</span><span class="p">&gt;</span>{{- i18n &#34;toc&#34; | default &#34;Table of Contents&#34; }}<span class="p">&lt;/</span><span class="nt">span</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">            <span class="p">&lt;/</span><span class="nt">summary</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">            <span class="p">&lt;</span><span class="nt">div</span> <span class="na">class</span><span class="o">=</span><span class="s">&#34;inner&#34;</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                {{- $largest := 6 -}}
</span></span><span class="line"><span class="cl">                {{- range $headers -}}
</span></span><span class="line"><span class="cl">                {{- $headerLevel := index (findRE &#34;[1-6]&#34; . 1) 0 -}}
</span></span><span class="line"><span class="cl">                {{- $headerLevel := len (seq $headerLevel) -}}
</span></span><span class="line"><span class="cl">                {{- if lt $headerLevel $largest -}}
</span></span><span class="line"><span class="cl">                {{- $largest = $headerLevel -}}
</span></span><span class="line"><span class="cl">                {{- end -}}
</span></span><span class="line"><span class="cl">                {{- end -}}
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">                {{- $firstHeaderLevel := len (seq (index (findRE &#34;[1-6]&#34; (index $headers 0) 1) 0)) -}}
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">                {{- $.Scratch.Set &#34;bareul&#34; slice -}}
</span></span><span class="line"><span class="cl">                <span class="p">&lt;</span><span class="nt">ul</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                    {{- range seq (sub $firstHeaderLevel $largest) -}}
</span></span><span class="line"><span class="cl">                    <span class="p">&lt;</span><span class="nt">ul</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                        {{- $.Scratch.Add &#34;bareul&#34; (sub (add $largest .) 1) -}}
</span></span><span class="line"><span class="cl">                        {{- end -}}
</span></span><span class="line"><span class="cl">                        {{- range $i, $header := $headers -}}
</span></span><span class="line"><span class="cl">                        {{- $headerLevel := index (findRE &#34;[1-6]&#34; . 1) 0 -}}
</span></span><span class="line"><span class="cl">                        {{- $headerLevel := len (seq $headerLevel) -}}
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">                        {{/* get id=&#34;xyz&#34; */}}
</span></span><span class="line"><span class="cl">                        {{- $id := index (findRE &#34;(id=\&#34;(.*?)\&#34;)&#34; $header 9) 0 }}
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">                        {{- /* strip id=&#34;&#34; to leave xyz, no way to get regex capturing groups in hugo */ -}}
</span></span><span class="line"><span class="cl">                        {{- $cleanedID := replace (replace $id &#34;id=\&#34;&#34; &#34;&#34;) &#34;\&#34;&#34; &#34;&#34; }}
</span></span><span class="line"><span class="cl">                        {{- $header := replaceRE &#34;<span class="p">&lt;</span><span class="nt">h</span><span class="err">[</span><span class="na">1-6</span><span class="err">].*?</span><span class="p">&gt;</span>((.|\n])+?)<span class="err">&lt;</span>/h[1-6]&gt;&#34; &#34;$1&#34; $header -}}
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">                        {{- if ne $i 0 -}}
</span></span><span class="line"><span class="cl">                        {{- $prevHeaderLevel := index (findRE &#34;[1-6]&#34; (index $headers (sub $i 1)) 1) 0 -}}
</span></span><span class="line"><span class="cl">                        {{- $prevHeaderLevel := len (seq $prevHeaderLevel) -}}
</span></span><span class="line"><span class="cl">                        {{- if gt $headerLevel $prevHeaderLevel -}}
</span></span><span class="line"><span class="cl">                        {{- range seq $prevHeaderLevel (sub $headerLevel 1) -}}
</span></span><span class="line"><span class="cl">                        <span class="p">&lt;</span><span class="nt">ul</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                            {{/* the first should not be recorded */}}
</span></span><span class="line"><span class="cl">                            {{- if ne $prevHeaderLevel . -}}
</span></span><span class="line"><span class="cl">                            {{- $.Scratch.Add &#34;bareul&#34; . -}}
</span></span><span class="line"><span class="cl">                            {{- end -}}
</span></span><span class="line"><span class="cl">                            {{- end -}}
</span></span><span class="line"><span class="cl">                            {{- else -}}
</span></span><span class="line"><span class="cl">                            <span class="p">&lt;/</span><span class="nt">li</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                            {{- if lt $headerLevel $prevHeaderLevel -}}
</span></span><span class="line"><span class="cl">                            {{- range seq (sub $prevHeaderLevel 1) -1 $headerLevel -}}
</span></span><span class="line"><span class="cl">                            {{- if in ($.Scratch.Get &#34;bareul&#34;) . -}}
</span></span><span class="line"><span class="cl">                        <span class="p">&lt;/</span><span class="nt">ul</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                        {{/* manually do pop item */}}
</span></span><span class="line"><span class="cl">                        {{- $tmp := $.Scratch.Get &#34;bareul&#34; -}}
</span></span><span class="line"><span class="cl">                        {{- $.Scratch.Delete &#34;bareul&#34; -}}
</span></span><span class="line"><span class="cl">                        {{- $.Scratch.Set &#34;bareul&#34; slice}}
</span></span><span class="line"><span class="cl">                        {{- range seq (sub (len $tmp) 1) -}}
</span></span><span class="line"><span class="cl">                        {{- $.Scratch.Add &#34;bareul&#34; (index $tmp (sub . 1)) -}}
</span></span><span class="line"><span class="cl">                        {{- end -}}
</span></span><span class="line"><span class="cl">                        {{- else -}}
</span></span><span class="line"><span class="cl">                    <span class="p">&lt;/</span><span class="nt">ul</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                    <span class="p">&lt;/</span><span class="nt">li</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                    {{- end -}}
</span></span><span class="line"><span class="cl">                    {{- end -}}
</span></span><span class="line"><span class="cl">                    {{- end -}}
</span></span><span class="line"><span class="cl">                    {{- end }}
</span></span><span class="line"><span class="cl">                    <span class="p">&lt;</span><span class="nt">li</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                        <span class="p">&lt;</span><span class="nt">a</span> <span class="na">href</span><span class="o">=</span><span class="s">&#34;#{{- $cleanedID -}}&#34;</span> <span class="na">aria-label</span><span class="o">=</span><span class="s">&#34;{{- $header | plainify -}}&#34;</span><span class="p">&gt;</span>{{- $header | safeHTML -}}<span class="p">&lt;/</span><span class="nt">a</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                        {{- else }}
</span></span><span class="line"><span class="cl">                    <span class="p">&lt;</span><span class="nt">li</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                        <span class="p">&lt;</span><span class="nt">a</span> <span class="na">href</span><span class="o">=</span><span class="s">&#34;#{{- $cleanedID -}}&#34;</span> <span class="na">aria-label</span><span class="o">=</span><span class="s">&#34;{{- $header | plainify -}}&#34;</span><span class="p">&gt;</span>{{- $header | safeHTML -}}<span class="p">&lt;/</span><span class="nt">a</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                        {{- end -}}
</span></span><span class="line"><span class="cl">                        {{- end -}}
</span></span><span class="line"><span class="cl">                        <span class="c">&lt;!-- {{- $firstHeaderLevel := len (seq (index (findRE &#34;[1-6]&#34; (index $headers 0) 1) 0)) -}} --&gt;</span>
</span></span><span class="line"><span class="cl">                        {{- $firstHeaderLevel := $largest }}
</span></span><span class="line"><span class="cl">                        {{- $lastHeaderLevel := len (seq (index (findRE &#34;[1-6]&#34; (index $headers (sub (len $headers) 1)) 1) 0)) }}
</span></span><span class="line"><span class="cl">                    <span class="p">&lt;/</span><span class="nt">li</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                    {{- range seq (sub $lastHeaderLevel $firstHeaderLevel) -}}
</span></span><span class="line"><span class="cl">                    {{- if in ($.Scratch.Get &#34;bareul&#34;) (add . $firstHeaderLevel) }}
</span></span><span class="line"><span class="cl">                <span class="p">&lt;/</span><span class="nt">ul</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                {{- else }}
</span></span><span class="line"><span class="cl">                <span class="p">&lt;/</span><span class="nt">ul</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                <span class="p">&lt;/</span><span class="nt">li</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">                {{- end -}}
</span></span><span class="line"><span class="cl">                {{- end }}
</span></span><span class="line"><span class="cl">                <span class="p">&lt;/</span><span class="nt">ul</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">            <span class="p">&lt;/</span><span class="nt">div</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">        <span class="p">&lt;/</span><span class="nt">details</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">    <span class="p">&lt;/</span><span class="nt">div</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="p">&lt;/</span><span class="nt">aside</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="p">&lt;</span><span class="nt">script</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">    <span class="kd">let</span> <span class="nx">activeElement</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kd">let</span> <span class="nx">elements</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="nb">document</span><span class="p">.</span><span class="nx">addEventListener</span><span class="p">(</span><span class="s1">&#39;DOMContentLoaded&#39;</span><span class="p">,</span> <span class="kd">function</span> <span class="p">(</span><span class="nx">event</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nx">checkTocPosition</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">        <span class="nx">elements</span> <span class="o">=</span> <span class="nb">document</span><span class="p">.</span><span class="nx">querySelectorAll</span><span class="p">(</span><span class="s1">&#39;h1[id],h2[id],h3[id],h4[id],h5[id],h6[id]&#39;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="nx">elements</span><span class="p">.</span><span class="nx">length</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="c1">// Make the first header active
</span></span></span><span class="line"><span class="cl">            <span class="nx">activeElement</span> <span class="o">=</span> <span class="nx">elements</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">            <span class="kr">const</span> <span class="nx">id</span> <span class="o">=</span> <span class="nb">encodeURI</span><span class="p">(</span><span class="nx">activeElement</span><span class="p">.</span><span class="nx">getAttribute</span><span class="p">(</span><span class="s1">&#39;id&#39;</span><span class="p">)).</span><span class="nx">toLowerCase</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">            <span class="nb">document</span><span class="p">.</span><span class="nx">querySelector</span><span class="p">(</span><span class="sb">`.inner ul li a[href=&#34;#</span><span class="si">${</span><span class="nx">id</span><span class="si">}</span><span class="sb">&#34;]`</span><span class="p">)</span><span class="o">?.</span><span class="nx">classList</span><span class="p">.</span><span class="nx">add</span><span class="p">(</span><span class="s1">&#39;active&#39;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">        <span class="c1">// Add event listener for the &#34;back to top&#34; link
</span></span></span><span class="line"><span class="cl">        <span class="kr">const</span> <span class="nx">topLink</span> <span class="o">=</span> <span class="nb">document</span><span class="p">.</span><span class="nx">getElementById</span><span class="p">(</span><span class="s1">&#39;top-link&#39;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="nx">topLink</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="nx">topLink</span><span class="p">.</span><span class="nx">addEventListener</span><span class="p">(</span><span class="s1">&#39;click&#39;</span><span class="p">,</span> <span class="p">(</span><span class="nx">event</span><span class="p">)</span> <span class="p">=&gt;</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="c1">// Prevent the default action
</span></span></span><span class="line"><span class="cl">                <span class="nx">event</span><span class="p">.</span><span class="nx">preventDefault</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">                <span class="c1">// Smooth scroll to the top
</span></span></span><span class="line"><span class="cl">                <span class="nb">window</span><span class="p">.</span><span class="nx">scrollTo</span><span class="p">({</span> <span class="nx">top</span><span class="o">:</span> <span class="mi">0</span><span class="p">,</span> <span class="nx">behavior</span><span class="o">:</span> <span class="s1">&#39;smooth&#39;</span> <span class="p">});</span>
</span></span><span class="line"><span class="cl">            <span class="p">});</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span> <span class="kc">false</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="nb">window</span><span class="p">.</span><span class="nx">addEventListener</span><span class="p">(</span><span class="s1">&#39;resize&#39;</span><span class="p">,</span> <span class="kd">function</span><span class="p">(</span><span class="nx">event</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nx">checkTocPosition</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span> <span class="kc">false</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="nb">window</span><span class="p">.</span><span class="nx">addEventListener</span><span class="p">(</span><span class="s1">&#39;scroll&#39;</span><span class="p">,</span> <span class="p">()</span> <span class="p">=&gt;</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="c1">// Get the current scroll position
</span></span></span><span class="line"><span class="cl">        <span class="kr">const</span> <span class="nx">scrollPosition</span> <span class="o">=</span> <span class="nb">window</span><span class="p">.</span><span class="nx">pageYOffset</span> <span class="o">||</span> <span class="nb">document</span><span class="p">.</span><span class="nx">documentElement</span><span class="p">.</span><span class="nx">scrollTop</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">        <span class="c1">// Check if the scroll position is at the top of the page
</span></span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="nx">scrollPosition</span> <span class="o">===</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">        <span class="c1">// Ensure elements is a valid NodeList
</span></span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="nx">elements</span> <span class="o">&amp;&amp;</span> <span class="nx">elements</span><span class="p">.</span><span class="nx">length</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="c1">// Check if there is an object in the top half of the screen or keep the last item active
</span></span></span><span class="line"><span class="cl">            <span class="nx">activeElement</span> <span class="o">=</span> <span class="nb">Array</span><span class="p">.</span><span class="nx">from</span><span class="p">(</span><span class="nx">elements</span><span class="p">).</span><span class="nx">find</span><span class="p">((</span><span class="nx">element</span><span class="p">)</span> <span class="p">=&gt;</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="k">if</span> <span class="p">((</span><span class="nx">getOffsetTop</span><span class="p">(</span><span class="nx">element</span><span class="p">)</span> <span class="o">-</span> <span class="nx">scrollPosition</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> 
</span></span><span class="line"><span class="cl">                    <span class="p">(</span><span class="nx">getOffsetTop</span><span class="p">(</span><span class="nx">element</span><span class="p">)</span> <span class="o">-</span> <span class="nx">scrollPosition</span><span class="p">)</span> <span class="o">&lt;</span> <span class="nb">window</span><span class="p">.</span><span class="nx">innerHeight</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                    <span class="k">return</span> <span class="nx">element</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">                <span class="p">}</span>
</span></span><span class="line"><span class="cl">            <span class="p">})</span> <span class="o">||</span> <span class="nx">activeElement</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">            <span class="nx">elements</span><span class="p">.</span><span class="nx">forEach</span><span class="p">(</span><span class="nx">element</span> <span class="p">=&gt;</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="kr">const</span> <span class="nx">id</span> <span class="o">=</span> <span class="nb">encodeURI</span><span class="p">(</span><span class="nx">element</span><span class="p">.</span><span class="nx">getAttribute</span><span class="p">(</span><span class="s1">&#39;id&#39;</span><span class="p">)).</span><span class="nx">toLowerCase</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">                <span class="kr">const</span> <span class="nx">tocLink</span> <span class="o">=</span> <span class="nb">document</span><span class="p">.</span><span class="nx">querySelector</span><span class="p">(</span><span class="sb">`.inner ul li a[href=&#34;#</span><span class="si">${</span><span class="nx">id</span><span class="si">}</span><span class="sb">&#34;]`</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">                <span class="k">if</span> <span class="p">(</span><span class="nx">tocLink</span> <span class="o">&amp;&amp;</span> <span class="nx">element</span> <span class="o">===</span> <span class="nx">activeElement</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">                    <span class="nx">tocLink</span><span class="p">.</span><span class="nx">classList</span><span class="p">.</span><span class="nx">add</span><span class="p">(</span><span class="s1">&#39;active&#39;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">                    <span class="c1">// Ensure the active element is in view within the .inner container
</span></span></span><span class="line"><span class="cl">                    <span class="kr">const</span> <span class="nx">tocContainer</span> <span class="o">=</span> <span class="nb">document</span><span class="p">.</span><span class="nx">querySelector</span><span class="p">(</span><span class="s1">&#39;.toc .inner&#39;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">                    <span class="kr">const</span> <span class="nx">linkOffsetTop</span> <span class="o">=</span> <span class="nx">tocLink</span><span class="p">.</span><span class="nx">offsetTop</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">                    <span class="kr">const</span> <span class="nx">containerHeight</span> <span class="o">=</span> <span class="nx">tocContainer</span><span class="p">.</span><span class="nx">clientHeight</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">                    <span class="kr">const</span> <span class="nx">linkHeight</span> <span class="o">=</span> <span class="nx">tocLink</span><span class="p">.</span><span class="nx">clientHeight</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">                    <span class="c1">// Calculate the scroll position to center the active link
</span></span></span><span class="line"><span class="cl">                    <span class="kr">const</span> <span class="nx">scrollPosition</span> <span class="o">=</span> <span class="nx">linkOffsetTop</span> <span class="o">-</span> <span class="p">(</span><span class="nx">containerHeight</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="nx">linkHeight</span> <span class="o">/</span> <span class="mi">2</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">                    <span class="nx">tocContainer</span><span class="p">.</span><span class="nx">scrollTo</span><span class="p">({</span> <span class="nx">top</span><span class="o">:</span> <span class="nx">scrollPosition</span><span class="p">,</span> <span class="nx">behavior</span><span class="o">:</span> <span class="s1">&#39;smooth&#39;</span> <span class="p">});</span>
</span></span><span class="line"><span class="cl">                <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                    <span class="nx">tocLink</span><span class="o">?.</span><span class="nx">classList</span><span class="p">.</span><span class="nx">remove</span><span class="p">(</span><span class="s1">&#39;active&#39;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">                <span class="p">}</span>
</span></span><span class="line"><span class="cl">            <span class="p">});</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span> <span class="kc">false</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="kr">const</span> <span class="nx">main</span> <span class="o">=</span> <span class="nb">parseInt</span><span class="p">(</span><span class="nx">getComputedStyle</span><span class="p">(</span><span class="nb">document</span><span class="p">.</span><span class="nx">body</span><span class="p">).</span><span class="nx">getPropertyValue</span><span class="p">(</span><span class="s1">&#39;--article-width&#39;</span><span class="p">),</span> <span class="mi">10</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="kr">const</span> <span class="nx">toc</span> <span class="o">=</span> <span class="nb">parseInt</span><span class="p">(</span><span class="nx">getComputedStyle</span><span class="p">(</span><span class="nb">document</span><span class="p">.</span><span class="nx">body</span><span class="p">).</span><span class="nx">getPropertyValue</span><span class="p">(</span><span class="s1">&#39;--toc-width&#39;</span><span class="p">),</span> <span class="mi">10</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="kr">const</span> <span class="nx">gap</span> <span class="o">=</span> <span class="nb">parseInt</span><span class="p">(</span><span class="nx">getComputedStyle</span><span class="p">(</span><span class="nb">document</span><span class="p">.</span><span class="nx">body</span><span class="p">).</span><span class="nx">getPropertyValue</span><span class="p">(</span><span class="s1">&#39;--gap&#39;</span><span class="p">),</span> <span class="mi">10</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="kd">function</span> <span class="nx">checkTocPosition</span><span class="p">()</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kr">const</span> <span class="nx">width</span> <span class="o">=</span> <span class="nb">document</span><span class="p">.</span><span class="nx">body</span><span class="p">.</span><span class="nx">scrollWidth</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="nx">width</span> <span class="o">-</span> <span class="nx">main</span> <span class="o">-</span> <span class="p">(</span><span class="nx">toc</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="p">(</span><span class="nx">gap</span> <span class="o">*</span> <span class="mi">4</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="nb">document</span><span class="p">.</span><span class="nx">getElementById</span><span class="p">(</span><span class="s2">&#34;toc-container&#34;</span><span class="p">).</span><span class="nx">classList</span><span class="p">.</span><span class="nx">add</span><span class="p">(</span><span class="s2">&#34;wide&#34;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="nb">document</span><span class="p">.</span><span class="nx">getElementById</span><span class="p">(</span><span class="s2">&#34;toc-container&#34;</span><span class="p">).</span><span class="nx">classList</span><span class="p">.</span><span class="nx">remove</span><span class="p">(</span><span class="s2">&#34;wide&#34;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="kd">function</span> <span class="nx">getOffsetTop</span><span class="p">(</span><span class="nx">element</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span><span class="o">!</span><span class="nx">element</span><span class="p">.</span><span class="nx">getClientRects</span><span class="p">().</span><span class="nx">length</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="kd">let</span> <span class="nx">rect</span> <span class="o">=</span> <span class="nx">element</span><span class="p">.</span><span class="nx">getBoundingClientRect</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="kd">let</span> <span class="nx">win</span> <span class="o">=</span> <span class="nx">element</span><span class="p">.</span><span class="nx">ownerDocument</span><span class="p">.</span><span class="nx">defaultView</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="nx">rect</span><span class="p">.</span><span class="nx">top</span> <span class="o">+</span> <span class="nx">win</span><span class="p">.</span><span class="nx">pageYOffset</span><span class="p">;</span>   
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl"><span class="p">&lt;/</span><span class="nt">script</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">{{- end }}
</span></span></code></pre></td></tr></table>
</div>
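</div><p>The second half of this file is the JS code. It scrolls the table of contents as you read and bolds the heading of the section currently in view.</p>
<p>Next, add the CSS: create the file <code>&lt;your_hugo_site&gt;/assets/css/extended/toc.css</code> and copy the following into it:</p>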
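<p>The centering arithmetic in that script can be checked in isolation. The following is a minimal sketch, not part of the theme: <code>centeredScrollTop</code> is a hypothetical helper name and the numbers are made-up examples. It only reproduces the expression assigned to <code>scrollPosition</code> above.</p>

```javascript
// Hypothetical helper (not in the theme): the same expression that the
// script above assigns to `scrollPosition`, written as a pure function.
function centeredScrollTop(linkOffsetTop, containerHeight, linkHeight) {
  // Shift up by half the container so the link's top sits at the container's
  // midpoint, then shift down by half the link so the link's center does.
  return linkOffsetTop - containerHeight / 2 + linkHeight / 2;
}

// Example values (assumed): a 20px-high link located 500px from the top of
// the TOC content, inside a container that is 400px tall.
console.log(centeredScrollTop(500, 400, 20)); // 310
```

<p>Scrolling the container to 310px puts its midpoint at 510px, which is also the midpoint of the link (500 + 20 / 2).</p>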
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
</td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-css" data-lang="css"><span class="line"><span class="cl"><span class="p">:</span><span class="nd">root</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="nv">--nav-width</span><span class="p">:</span> <span class="mi">1380</span><span class="kt">px</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="nv">--article-width</span><span class="p">:</span> <span class="mi">650</span><span class="kt">px</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="nv">--toc-width</span><span class="p">:</span> <span class="mi">300</span><span class="kt">px</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">toc</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">margin</span><span class="p">:</span> <span class="mi">0</span> <span class="mi">2</span><span class="kt">px</span> <span class="mi">40</span><span class="kt">px</span> <span class="mi">2</span><span class="kt">px</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">border</span><span class="p">:</span> <span class="mi">1</span><span class="kt">px</span> <span class="kc">solid</span> <span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">border</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">background</span><span class="p">:</span> <span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">entry</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">border-radius</span><span class="p">:</span> <span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">radius</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">padding</span><span class="p">:</span> <span class="mf">0.4</span><span class="kt">em</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">toc-container</span><span class="p">.</span><span class="nc">wide</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">position</span><span class="p">:</span> <span class="kc">absolute</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">height</span><span class="p">:</span> <span class="mi">100</span><span class="kt">%</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">border-right</span><span class="p">:</span> <span class="mi">1</span><span class="kt">px</span> <span class="kc">solid</span> <span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">border</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">left</span><span class="p">:</span> <span class="nb">calc</span><span class="p">((</span><span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">toc</span><span class="o">-</span><span class="n">width</span><span class="p">)</span> <span class="o">+</span> <span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">gap</span><span class="p">))</span> <span class="o">*</span> <span class="mi">-1</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">top</span><span class="p">:</span> <span class="nb">calc</span><span class="p">(</span><span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">gap</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">width</span><span class="p">:</span> <span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">toc</span><span class="o">-</span><span class="n">width</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">wide</span> <span class="p">.</span><span class="nc">toc</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">position</span><span class="p">:</span> <span class="kc">sticky</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">top</span><span class="p">:</span> <span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">gap</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">border</span><span class="p">:</span> <span class="kc">unset</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">background</span><span class="p">:</span> <span class="kc">unset</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">border-radius</span><span class="p">:</span> <span class="kc">unset</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">width</span><span class="p">:</span> <span class="mi">100</span><span class="kt">%</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">margin</span><span class="p">:</span> <span class="mi">0</span> <span class="mi">2</span><span class="kt">px</span> <span class="mi">40</span><span class="kt">px</span> <span class="mi">2</span><span class="kt">px</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">toc</span> <span class="nt">details</span> <span class="nt">summary</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">cursor</span><span class="p">:</span> <span class="kc">zoom-in</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">margin-inline-start</span><span class="p">:</span> <span class="mi">20</span><span class="kt">px</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">padding</span><span class="p">:</span> <span class="mi">12</span><span class="kt">px</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">toc</span> <span class="nt">details</span><span class="o">[</span><span class="nt">open</span><span class="o">]</span> <span class="nt">summary</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">font-weight</span><span class="p">:</span> <span class="mi">500</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">toc-container</span><span class="p">.</span><span class="nc">wide</span> <span class="p">.</span><span class="nc">toc</span> <span class="p">.</span><span class="nc">inner</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">margin</span><span class="p">:</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">active</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">font-size</span><span class="p">:</span> <span class="mi">110</span><span class="kt">%</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">font-weight</span><span class="p">:</span> <span class="mi">600</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">toc</span> <span class="nt">ul</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">list-style-type</span><span class="p">:</span> <span class="kc">circle</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">toc</span> <span class="p">.</span><span class="nc">inner</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">margin</span><span class="p">:</span> <span class="mi">0</span> <span class="mi">0</span> <span class="mi">0</span> <span class="mi">20</span><span class="kt">px</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">padding</span><span class="p">:</span> <span class="mi">0</span><span class="kt">px</span> <span class="mi">15</span><span class="kt">px</span> <span class="mi">15</span><span class="kt">px</span> <span class="mi">20</span><span class="kt">px</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">font-size</span><span class="p">:</span> <span class="mi">16</span><span class="kt">px</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c">/* maximum displayed height of the TOC */</span>
</span></span><span class="line"><span class="cl">    <span class="k">max-height</span><span class="p">:</span> <span class="mi">83</span><span class="kt">vh</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">overflow-y</span><span class="p">:</span> <span class="kc">auto</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">toc</span> <span class="p">.</span><span class="nc">inner</span><span class="p">::</span><span class="nd">-webkit-scrollbar-thumb</span> <span class="p">{</span>  <span class="c">/* scrollbar thumb */</span>
</span></span><span class="line"><span class="cl">    <span class="k">background</span><span class="p">:</span> <span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">border</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">border</span><span class="p">:</span> <span class="mi">7</span><span class="kt">px</span> <span class="kc">solid</span> <span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">theme</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">border-radius</span><span class="p">:</span> <span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">radius</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">toc</span> <span class="nt">li</span> <span class="nt">ul</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">margin-inline-start</span><span class="p">:</span> <span class="nb">calc</span><span class="p">(</span><span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">gap</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.5</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">list-style-type</span><span class="p">:</span> <span class="kc">none</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">toc</span> <span class="nt">li</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">list-style</span><span class="p">:</span> <span class="kc">none</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">font-size</span><span class="p">:</span> <span class="mf">0.95</span><span class="kt">rem</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">padding-bottom</span><span class="p">:</span> <span class="mi">5</span><span class="kt">px</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">toc</span> <span class="nt">li</span> <span class="nt">a</span><span class="p">:</span><span class="nd">hover</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">color</span><span class="p">:</span> <span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">secondary</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>At this point, the table of contents should display correctly at the side of the page 🎉🎉.</p>
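<p>Whether the TOC actually moves to the side depends on the width test in <code>checkTocPosition</code>. Below is a minimal sketch of that test as a pure function. It assumes the 650px article width and 300px TOC width from <code>toc.css</code> above, plus a <code>--gap</code> of 24px, which is an assumption (check the value your theme defines). <code>fitsBesideArticle</code> is a hypothetical name.</p>

```javascript
// Hypothetical helper (not in the theme): the condition from
// checkTocPosition, with the CSS variable values passed in as numbers.
function fitsBesideArticle(bodyWidth, articleWidth, tocWidth, gap) {
  // Same inequality as the original script: the page must be wider than
  // the article plus two TOC widths plus four gaps.
  return bodyWidth - articleWidth - tocWidth * 2 - gap * 4 > 0;
}

// Assumed values: --article-width 650, --toc-width 300, --gap 24.
// The threshold is 650 + 2 * 300 + 4 * 24 = 1346px.
console.log(fitsBesideArticle(1920, 650, 300, 24)); // true
console.log(fitsBesideArticle(1280, 650, 300, 24)); // false
```

<p>If the TOC stays above the article on a wide screen, comparing <code>document.body.scrollWidth</code> against this threshold is a reasonable first check.</p>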
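<h2 id="阅读百分比">Reading percentage</h2>
<p>The core idea behind the reading percentage is to recompute the current reading progress from the scroll position every time a scroll event fires. Here the number is displayed on the TOP button, which is defined in <code>footer.html</code>. So create <code>&lt;your_hugo_site&gt;/layouts/partials/footer.html</code>, copy the contents of the theme's corresponding <code>footer.html</code> into it, and then modify the code for the TOP button. The original code is:</p>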
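<p>Before touching the template, the progress calculation itself can be sketched as a pure function. This is an illustration under stated assumptions, not the code used in this post: <code>readPercent</code> is a hypothetical name, and the guard for pages that do not scroll and the clamping to 0..1 are additions that the script in this post does not have.</p>

```javascript
// Hypothetical helper (not in the theme): reading progress as an integer
// percentage, from the three values the scroll handler reads off the page.
function readPercent(scrollTop, scrollHeight, clientHeight) {
  // Distance the page can scroll in total.
  const scrollable = scrollHeight - clientHeight;
  // Added guard (assumption): a page that fits in the viewport has no progress.
  if (scrollable <= 0) return 0;
  // Added clamp (assumption): keep the ratio within 0..1 during overscroll.
  const ratio = Math.min(Math.max(scrollTop / scrollable, 0), 1);
  return Math.round(ratio * 100);
}

// Example values (assumed): a 3000px page in a 600px viewport.
console.log(readPercent(600, 3000, 600));  // 25
console.log(readPercent(2400, 3000, 600)); // 100
```

<p>With that arithmetic in mind, the theme's original TOP button code is the following:</p>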
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
</td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-html" data-lang="html"><span class="line"><span class="cl">{{- if (not site.Params.disableScrollToTop) }}
</span></span><span class="line"><span class="cl"><span class="p">&lt;</span><span class="nt">a</span> <span class="na">href</span><span class="o">=</span><span class="s">&#34;#top&#34;</span> <span class="na">aria-label</span><span class="o">=</span><span class="s">&#34;go to top&#34;</span> <span class="na">title</span><span class="o">=</span><span class="s">&#34;Go to Top (Alt + G)&#34;</span> <span class="na">class</span><span class="o">=</span><span class="s">&#34;top-link&#34;</span> <span class="na">id</span><span class="o">=</span><span class="s">&#34;top-link&#34;</span> <span class="na">accesskey</span><span class="o">=</span><span class="s">&#34;g&#34;</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">    <span class="p">&lt;</span><span class="nt">svg</span> <span class="na">xmlns</span><span class="o">=</span><span class="s">&#34;http://www.w3.org/2000/svg&#34;</span> <span class="na">viewBox</span><span class="o">=</span><span class="s">&#34;0 0 12 6&#34;</span> <span class="na">fill</span><span class="o">=</span><span class="s">&#34;currentColor&#34;</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">        <span class="p">&lt;</span><span class="nt">path</span> <span class="na">d</span><span class="o">=</span><span class="s">&#34;M12 6H0l6-6z&#34;</span> <span class="p">/&gt;</span>
</span></span><span class="line"><span class="cl">    <span class="p">&lt;/</span><span class="nt">svg</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="p">&lt;/</span><span class="nt">a</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">{{- end }}
</span></span></code></pre></td></tr></table>
</div>
</div><p>We add a <code>span</code> to display the progress, along with the JS code that updates it, so the code becomes:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
</td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-html" data-lang="html"><span class="line"><span class="cl">{{- if (not .Site.Params.disableScrollToTop) }}
</span></span><span class="line"><span class="cl"><span class="p">&lt;</span><span class="nt">a</span> <span class="na">href</span><span class="o">=</span><span class="s">&#34;#top&#34;</span> <span class="na">aria-label</span><span class="o">=</span><span class="s">&#34;go to top&#34;</span> <span class="na">title</span><span class="o">=</span><span class="s">&#34;Go to Top (Alt + G)&#34;</span> <span class="na">class</span><span class="o">=</span><span class="s">&#34;top-link&#34;</span> <span class="na">id</span><span class="o">=</span><span class="s">&#34;top-link&#34;</span> <span class="na">accesskey</span><span class="o">=</span><span class="s">&#34;g&#34;</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">    <span class="p">&lt;</span><span class="nt">span</span> <span class="na">class</span><span class="o">=</span><span class="s">&#34;topInner&#34;</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">        <span class="p">&lt;</span><span class="nt">svg</span> <span class="na">class</span><span class="o">=</span><span class="s">&#34;topSvg&#34;</span> <span class="na">xmlns</span><span class="o">=</span><span class="s">&#34;http://www.w3.org/2000/svg&#34;</span> <span class="na">viewBox</span><span class="o">=</span><span class="s">&#34;0 0 12 6&#34;</span> <span class="na">fill</span><span class="o">=</span><span class="s">&#34;currentColor&#34;</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">            <span class="p">&lt;</span><span class="nt">path</span> <span class="na">d</span><span class="o">=</span><span class="s">&#34;M12 6H0l6-6z&#34;</span><span class="p">/&gt;</span>
</span></span><span class="line"><span class="cl">        <span class="p">&lt;/</span><span class="nt">svg</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">        <span class="p">&lt;</span><span class="nt">span</span> <span class="na">id</span><span class="o">=</span><span class="s">&#34;read_progress&#34;</span><span class="p">&gt;&lt;/</span><span class="nt">span</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">    <span class="p">&lt;/</span><span class="nt">span</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="p">&lt;/</span><span class="nt">a</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">&lt;</span><span class="nt">script</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">    <span class="nb">document</span><span class="p">.</span><span class="nx">addEventListener</span><span class="p">(</span><span class="s1">&#39;scroll&#39;</span><span class="p">,</span> <span class="kd">function</span> <span class="p">(</span><span class="nx">e</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kr">const</span> <span class="nx">readProgress</span> <span class="o">=</span> <span class="nb">document</span><span class="p">.</span><span class="nx">getElementById</span><span class="p">(</span><span class="s2">&#34;read_progress&#34;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="kr">const</span> <span class="nx">scrollHeight</span> <span class="o">=</span> <span class="nb">document</span><span class="p">.</span><span class="nx">documentElement</span><span class="p">.</span><span class="nx">scrollHeight</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="kr">const</span> <span class="nx">clientHeight</span> <span class="o">=</span> <span class="nb">document</span><span class="p">.</span><span class="nx">documentElement</span><span class="p">.</span><span class="nx">clientHeight</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="kr">const</span> <span class="nx">scrollTop</span> <span class="o">=</span> <span class="nb">document</span><span class="p">.</span><span class="nx">documentElement</span><span class="p">.</span><span class="nx">scrollTop</span> <span class="o">||</span> <span class="nb">document</span><span class="p">.</span><span class="nx">body</span><span class="p">.</span><span class="nx">scrollTop</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="nx">readProgress</span><span class="p">.</span><span class="nx">innerText</span> <span class="o">=</span> <span class="p">((</span><span class="nx">scrollTop</span> <span class="o">/</span> <span class="p">(</span><span class="nx">scrollHeight</span> <span class="o">-</span> <span class="nx">clientHeight</span><span class="p">)).</span><span class="nx">toFixed</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="mi">100</span><span class="p">).</span><span class="nx">toFixed</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="p">})</span>
</span></span><span class="line"><span class="cl"><span class="p">&lt;/</span><span class="nt">script</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">{{- end }}
</span></span></code></pre></td></tr></table>
</div>
</div><p>Next, add the corresponding CSS: create the file <code>&lt;your_hugo_site&gt;/assets/css/extended/top.css</code> and copy the following into it:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-css" data-lang="css"><span class="line"><span class="cl"><span class="c">/*top*/</span>
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">topInner</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">display</span><span class="p">:</span> <span class="k">grid</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">align-items</span><span class="p">:</span> <span class="kc">baseline</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">justify-items</span><span class="p">:</span> <span class="kc">center</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">margin</span><span class="p">:</span> <span class="mi">7</span><span class="kt">px</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">font-weight</span><span class="p">:</span> <span class="mi">900</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">topSvg</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">width</span><span class="p">:</span> <span class="mi">20</span><span class="kt">px</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">top-link</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">padding</span><span class="p">:</span> <span class="kc">unset</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c">/*back to top*/</span>
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">top-link</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">background</span><span class="p">:</span> <span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">entry</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="kp">-webkit-</span><span class="k">transition</span><span class="p">:</span> <span class="k">box-shadow</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">,</span> <span class="k">transform</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kp">-moz-</span><span class="k">transition</span><span class="p">:</span> <span class="k">box-shadow</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">,</span> <span class="k">transform</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kp">-o-</span><span class="k">transition</span><span class="p">:</span> <span class="k">box-shadow</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">,</span> <span class="k">transform</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">transition</span><span class="p">:</span> <span class="k">box-shadow</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">,</span> <span class="k">transform</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">box-shadow</span><span class="p">:</span> <span class="mi">0</span><span class="kt">px</span> <span class="mi">2</span><span class="kt">px</span> <span class="mi">4</span><span class="kt">px</span> <span class="nb">rgb</span><span class="p">(</span><span class="mi">5</span> <span class="mi">10</span> <span class="mi">15</span> <span class="o">/</span> <span class="mi">5</span><span class="kt">%</span><span class="p">),</span> <span class="mi">0</span><span class="kt">px</span> <span class="mi">7</span><span class="kt">px</span> <span class="mi">13</span><span class="kt">px</span> <span class="mi">-3</span><span class="kt">px</span> <span class="nb">rgb</span><span class="p">(</span><span class="mi">5</span> <span class="mi">10</span> <span class="mi">15</span> <span class="o">/</span> <span class="mi">30</span><span class="kt">%</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">dark</span> <span class="p">.</span><span class="nc">top-link</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">background</span><span class="p">:</span> <span class="nf">var</span><span class="p">(</span><span class="o">--</span><span class="n">entry</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="kp">-webkit-</span><span class="k">transition</span><span class="p">:</span> <span class="k">box-shadow</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">,</span> <span class="k">transform</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kp">-moz-</span><span class="k">transition</span><span class="p">:</span> <span class="k">box-shadow</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">,</span> <span class="k">transform</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kp">-o-</span><span class="k">transition</span><span class="p">:</span> <span class="k">box-shadow</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">,</span> <span class="k">transform</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">transition</span><span class="p">:</span> <span class="k">box-shadow</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">,</span> <span class="k">transform</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">box-shadow</span><span class="p">:</span> <span class="mi">0</span><span class="kt">px</span> <span class="mi">2</span><span class="kt">px</span> <span class="mi">4</span><span class="kt">px</span> <span class="nb">rgb</span><span class="p">(</span><span class="mi">5</span> <span class="mi">10</span> <span class="mi">15</span> <span class="o">/</span> <span class="mi">5</span><span class="kt">%</span><span class="p">),</span> <span class="mi">0</span><span class="kt">px</span> <span class="mi">7</span><span class="kt">px</span> <span class="mi">13</span><span class="kt">px</span> <span class="mi">-3</span><span class="kt">px</span> <span class="nb">rgb</span><span class="p">(</span><span class="mi">5</span> <span class="mi">10</span> <span class="mi">15</span> <span class="o">/</span> <span class="mi">30</span><span class="kt">%</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">top-link</span><span class="p">:</span><span class="nd">hover</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">color</span><span class="p">:</span> <span class="nb">rgb</span><span class="p">(</span><span class="mi">108</span><span class="p">,</span> <span class="mi">108</span><span class="p">,</span> <span class="mi">108</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c">/*-webkit-transform: scale(1.1);*/</span>
</span></span><span class="line"><span class="cl">    <span class="c">/*-moz-transform: scale(1.1);*/</span>
</span></span><span class="line"><span class="cl">    <span class="c">/*-ms-transform: scale(1.1);*/</span>
</span></span><span class="line"><span class="cl">    <span class="c">/*-o-transform: scale(1.1);*/</span>
</span></span><span class="line"><span class="cl">    <span class="c">/*transform: scale(1.1);*/</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">transition</span><span class="p">:</span> <span class="k">box-shadow</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">,</span> <span class="k">transform</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">box-shadow</span><span class="p">:</span> <span class="mi">0</span><span class="kt">px</span> <span class="mi">4</span><span class="kt">px</span> <span class="mi">8</span><span class="kt">px</span> <span class="nb">rgb</span><span class="p">(</span><span class="mi">5</span> <span class="mi">10</span> <span class="mi">15</span> <span class="o">/</span> <span class="mi">5</span><span class="kt">%</span><span class="p">),</span> <span class="mi">0</span><span class="kt">px</span> <span class="mi">7</span><span class="kt">px</span> <span class="mi">13</span><span class="kt">px</span> <span class="mi">-3</span><span class="kt">px</span> <span class="nb">rgb</span><span class="p">(</span><span class="mi">5</span> <span class="mi">10</span> <span class="mi">15</span> <span class="o">/</span> <span class="mi">30</span><span class="kt">%</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">.</span><span class="nc">dark</span> <span class="p">.</span><span class="nc">top-link</span><span class="p">:</span><span class="nd">hover</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">color</span><span class="p">:</span> <span class="nb">rgba</span><span class="p">(</span><span class="mi">180</span><span class="p">,</span> <span class="mi">181</span><span class="p">,</span> <span class="mi">182</span><span class="p">,</span> <span class="mf">.8</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c">/*-webkit-transform: scale(1.1);*/</span>
</span></span><span class="line"><span class="cl">    <span class="c">/*-moz-transform: scale(1.1);*/</span>
</span></span><span class="line"><span class="cl">    <span class="c">/*-ms-transform: scale(1.1);*/</span>
</span></span><span class="line"><span class="cl">    <span class="c">/*-o-transform: scale(1.1);*/</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c">/*transform: scale(1.1);*/</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">transition</span><span class="p">:</span> <span class="k">box-shadow</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">,</span> <span class="k">transform</span> <span class="mf">0.4</span><span class="kt">s</span> <span class="kc">ease</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">box-shadow</span><span class="p">:</span> <span class="mi">0</span><span class="kt">px</span> <span class="mi">4</span><span class="kt">px</span> <span class="mi">8</span><span class="kt">px</span> <span class="nb">rgb</span><span class="p">(</span><span class="mi">5</span> <span class="mi">10</span> <span class="mi">15</span> <span class="o">/</span> <span class="mi">5</span><span class="kt">%</span><span class="p">),</span> <span class="mi">0</span><span class="kt">px</span> <span class="mi">7</span><span class="kt">px</span> <span class="mi">13</span><span class="kt">px</span> <span class="mi">-3</span><span class="kt">px</span> <span class="nb">rgb</span><span class="p">(</span><span class="mi">5</span> <span class="mi">10</span> <span class="mi">15</span> <span class="o">/</span> <span class="mi">30</span><span class="kt">%</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>At this point, the reading progress should display correctly on the TOP button 🎉🎉.</p>
<h1 id="参考文档">References</h1>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p><a href="https://www.sulvblog.cn/posts/blog/hugo_toc_side/">Hugo博客目录放在侧边 | PaperMod主题 | Sulv&rsquo;s Blog</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>CMU 10-414 Assignments Lab Notes</title>
      <link>https://www.zhouxin.space/notes/notes-on-cmu-10-414-assignments/</link>
      <pubDate>Thu, 06 Jun 2024 13:28:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-cmu-10-414-assignments/</guid>
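      <description>&lt;h1 id=&#34;前言&#34;&gt;Preface&lt;/h1&gt;
&lt;p&gt;This post records the process of completing the assignments for CMU 10-414/714 Deep Learning Systems, together with my notes. There are 6 homeworks (hw) in total, which build a deep learning framework from scratch step by step and then use it to implement network models common in DL, including CNN, RNN, and Transformer.&lt;/p&gt;</description>
      <content:encoded><![CDATA[<h1 id="前言">Preface</h1>
<p>This post records the process of completing the assignments for CMU 10-414/714 Deep Learning Systems, together with my notes. There are 6 homeworks (hw) in total, which build a deep learning framework from scratch step by step and then use it to implement network models common in DL, including CNN, RNN, and Transformer.</p>
<p>The lab environment is Ubuntu 24 on WSL2.</p>
<p>Because the official autograder no longer accepts registrations from students who are not enrolled in the course, this code is only guaranteed to pass the existing test cases.</p>
<h1 id="资源存档">Resource archive</h1>
<p>The source code comes from the official site: <a href="https://dlsyscourse.org/assignments/">Assignments</a></p>
<p>All code is uploaded to <a href="https://gitee.com/littleherozzzx/cmu10-414-assignments">cmu10-414-assignments: cmu10-414-assignments</a>; if the official site takes the packages down, the original code can be recovered by rolling back with git.</p>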
<h1 id="hw0">hw0</h1>
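<p>The first homework requires implementing 7 functions. The first one is trivial and only serves to get familiar with the grading system, so these notes start from the second function.</p>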
<h2 id="parse_mnist">parse_mnist</h2>
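<p>The signature is <code>parse_mnist(image_filename, label_filename)</code>; the function reads the MNIST handwritten-digit dataset. The <a href="http://yann.lecun.com/exdb/mnist/">official site</a> describes the file format in detail; scroll down to the section FILE FORMATS FOR THE MNIST DATABASE.</p>
<p>The dataset is split into a training set and a test set, each consisting of digit images and labels. In a label file, the first 8 bytes hold the magic number and the number of items, followed by one byte per sample. In an image file, the first 16 bytes hold non-image data, followed by the pixels in row-major order at one byte per pixel; each image has 28×28 pixels.</p>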
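<p>As a hedged illustration of the header layout described above: per the MNIST format page, the 16 non-image bytes of an image file are four big-endian 32-bit integers (magic number 2051, image count, rows, columns). The sketch below builds a tiny synthetic buffer in memory (not a real MNIST file; the variable names are illustrative assumptions) and parses it with <code>struct</code>.</p>

```python
import struct
import numpy as np

# Hypothetical in-memory stand-in for an uncompressed image file:
# 2 images of 28x28 zero-valued pixels (not real MNIST data).
header = struct.pack(">IIII", 2051, 2, 28, 28)  # magic, count, rows, cols
payload = bytes(2 * 28 * 28)                    # one byte per pixel
raw = header + payload

# The 16-byte header is four big-endian unsigned 32-bit integers.
magic, num_images, rows, cols = struct.unpack(">IIII", raw[:16])

# Everything after the header is pixel data in row-major order.
pixels = np.frombuffer(raw, dtype=np.uint8, offset=16)
images = pixels.reshape(num_images, rows * cols)
print(magic, num_images, rows, cols, images.shape)
```

<p>The same pattern with the format string <code>&gt;II</code> would read the 8-byte label header (magic number and item count).</p>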
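<p>In the implementation, the gzip library reads the data files as bytes. Note that the dataset must be normalized, i.e. each pixel's grayscale value is divided by 255. The full implementation is:</p>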
<div class="highlight"><div class="chroma">
<table class="lntable"><tr>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">parse_mnist</span><span class="p">(</span><span class="n">image_filename</span><span class="p">,</span> <span class="n">label_filename</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">image_file_handle</span> <span class="o">=</span> <span class="n">gzip</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">image_filename</span><span class="p">,</span> <span class="s1">&#39;rb&#39;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">label_file_handle</span> <span class="o">=</span> <span class="n">gzip</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">label_filename</span><span class="p">,</span> <span class="s1">&#39;rb&#39;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">image_file_handle</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">16</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">label_file_handle</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">image_data</span> <span class="o">=</span> <span class="n">image_file_handle</span><span class="o">.</span><span class="n">read</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">label_data</span> <span class="o">=</span> <span class="n">label_file_handle</span><span class="o">.</span><span class="n">read</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">image_file_handle</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">label_file_handle</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">frombuffer</span><span class="p">(</span><span class="n">image_data</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">28</span><span class="o">*</span><span class="mi">28</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">X</span> <span class="o">=</span> <span class="n">X</span> <span class="o">/</span> <span class="mf">255.0</span>
</span></span><span class="line"><span class="cl">    <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">frombuffer</span><span class="p">(</span><span class="n">label_data</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="softmax_loss">softmax_loss</h2>
<p>The signature is <code>softmax_loss(Z, y)</code>. Note that it computes the softmax loss, i.e. the cross-entropy loss, rather than applying softmax normalization.</p>
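<p>For a single sample with logits $z$ and true label $y$, this loss is $\log\sum_i \exp(z_i) - z_y$, averaged over the batch. As a side note, a minimal sketch of a numerically safer variant is shown below; <code>softmax_loss_stable</code> is a hypothetical name, not part of the assignment. It subtracts each row's maximum before exponentiating, which leaves the loss unchanged but avoids overflow for large logits.</p>

```python
import numpy as np

def softmax_loss_stable(Z, y):
    """Mean cross-entropy from logits Z (batch, classes) and integer labels y."""
    # Shifting every row by its maximum does not change the loss value.
    shifted = Z - Z.max(axis=1, keepdims=True)
    log_sum_exp = np.log(np.exp(shifted).sum(axis=1))
    rows = np.arange(Z.shape[0])
    return np.mean(log_sum_exp - shifted[rows, y])

# Toy logits; the second row would overflow np.exp without the shift.
Z = np.array([[1.0, 2.0, 3.0], [1000.0, 0.0, 0.0]])
y = np.array([2, 0])
print(softmax_loss_stable(Z, y))
```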
<p>Following the formula takes two lines of code, so no further explanation is needed:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">softmax_loss</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">rows</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">Z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="o">-</span><span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">Z</span><span class="p">[</span><span class="n">rows</span><span class="p">,</span> <span class="n">y</span><span class="p">]</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">Z</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)))</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="softmax_regression_epoch">softmax_regression_epoch</h2>
<p>The signature is <code>softmax_regression_epoch(X, y, theta, lr = 0.1, batch=100)</code>; it implements one epoch of training for softmax regression.</p>
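<p>For reference, the step performed on each minibatch can be written as follows, where $X_b$ is the minibatch, $I_y$ its one-hot label matrix, $\alpha$ the learning rate <code>lr</code>, $B$ the batch size <code>batch</code>, and $\mathrm{normalize}$ divides each row by its sum (this notation is assumed here, chosen to match the code in this post):</p>
<p>$$Z = \mathrm{normalize}\big(\exp(X_b \theta)\big), \qquad \theta \leftarrow \theta - \frac{\alpha}{B}\, X_b^{\top} (Z - I_y)$$</p>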
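<p>First compute the total number of batches and loop that many times. In each iteration, take the corresponding samples from X and y, then compute according to the formula. One small trick converts the labels to one-hot encoding: <code>E_batch = np.eye(theta.shape[1])[y_batch]</code>; the rest is straightforward:</p>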
<div class="highlight"><div class="chroma">
<table class="lntable"><tr>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">softmax_regression_epoch</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">theta</span><span class="p">,</span> <span class="n">lr</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">batch</span><span class="o">=</span><span class="mi">100</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">total_batches</span> <span class="o">=</span> <span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">batch</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">batch</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">total_batches</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">X_batch</span> <span class="o">=</span> <span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">batch</span><span class="p">:(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">batch</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">y_batch</span> <span class="o">=</span> <span class="n">y</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">batch</span><span class="p">:(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">batch</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">E_batch</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">theta</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])[</span><span class="n">y_batch</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">logits</span> <span class="o">=</span> <span class="n">X_batch</span> <span class="o">@</span> <span class="n">theta</span>
</span></span><span class="line"><span class="cl">        <span class="n">Z_batch</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">Z_batch</span> <span class="o">/=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">Z_batch</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">gradients</span> <span class="o">=</span> <span class="n">X_batch</span><span class="o">.</span><span class="n">T</span> <span class="o">@</span> <span class="p">(</span><span class="n">Z_batch</span> <span class="o">-</span> <span class="n">E_batch</span><span class="p">)</span> <span class="o">/</span> <span class="n">batch</span>
</span></span><span class="line"><span class="cl">        <span class="n">theta</span> <span class="o">-=</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">gradients</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="nn_epoch">nn_epoch</h2>
<p>The signature is <code>nn_epoch(X, y, W1, W2, lr = 0.1, batch=100)</code>; it implements one epoch of training for a two-layer perceptron.</p>
<p>Just follow the formulas; two points need attention:</p>
<ul>
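<li>The ReLU activation can be implemented with a max function: <code>Z1_batch = np.maximum(X_batch @ W1, 0)</code></li>
<li>The division by batch_size should be done early, while computing G2; if it is deferred to the final step that updates $\theta$, precision errors appear and the test cases fail.</li>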
</ul>
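<p>The two points above can be tried on toy data. The sketch below is only an illustration under assumed shapes and random values (the array <code>G</code> and the names <code>early</code> and <code>late</code> are hypothetical): it applies ReLU with <code>np.maximum</code>, builds the 0/1 mask that acts as its derivative, and compares dividing before versus after the matrix product. The two orders are mathematically equal but are rounded differently in float32, which is presumably the source of the precision issue mentioned above.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
X_batch = rng.standard_normal((6, 3)).astype(np.float32)  # toy inputs
W1 = rng.standard_normal((3, 5)).astype(np.float32)       # toy weights

# ReLU as an element-wise maximum with 0.
Z1_batch = np.maximum(X_batch @ W1, 0)
# Its derivative is a 0/1 mask marking the positive entries.
relu_mask = (Z1_batch > 0).astype(np.float32)

# A stand-in for a gradient term such as G2 (random values here).
G = rng.standard_normal((6, 5)).astype(np.float32)
m = X_batch.shape[0]

early = X_batch.T @ (G / m)   # divide before the matrix product
late = (X_batch.T @ G) / m    # divide after the matrix product

print(np.allclose(early, late, atol=1e-5))
print(np.abs(early - late).max())
```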
<p>The full code is:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">nn_epoch</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">W1</span><span class="p">,</span> <span class="n">W2</span><span class="p">,</span> <span class="n">lr</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">batch</span><span class="o">=</span><span class="mi">100</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">total_batches</span> <span class="o">=</span> <span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">batch</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">batch</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">total_batches</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">X_batch</span> <span class="o">=</span> <span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">batch</span><span class="p">:(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">batch</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">y_batch</span> <span class="o">=</span> <span class="n">y</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">batch</span><span class="p">:(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">batch</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">E_batch</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">W2</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])[</span><span class="n">y_batch</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">Z1_batch</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">X_batch</span> <span class="o">@</span> <span class="n">W1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">G2_batch</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">Z1_batch</span> <span class="o">@</span> <span class="n">W2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">G2_batch</span> <span class="o">/=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">G2_batch</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">G2_batch</span> <span class="o">-=</span> <span class="n">E_batch</span>
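</span></span><span class="line"><span class="cl">        <span class="n">G2_batch</span> <span class="o">/=</span> <span class="n">X_batch</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>  <span class="c1"># actual batch size; the last batch may be smaller</span>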
</span></span><span class="line"><span class="cl">        <span class="n">G1_batch</span> <span class="o">=</span> <span class="p">(</span><span class="n">Z1_batch</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">G2_batch</span> <span class="o">@</span> <span class="n">W2</span><span class="o">.</span><span class="n">T</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">gradients_W1</span> <span class="o">=</span> <span class="n">X_batch</span><span class="o">.</span><span class="n">T</span> <span class="o">@</span> <span class="n">G1_batch</span>
</span></span><span class="line"><span class="cl">        <span class="n">gradients_W2</span> <span class="o">=</span> <span class="n">Z1_batch</span><span class="o">.</span><span class="n">T</span> <span class="o">@</span> <span class="n">G2_batch</span>
</span></span><span class="line"><span class="cl">        <span class="n">W1</span> <span class="o">-=</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">gradients_W1</span>
</span></span><span class="line"><span class="cl">        <span class="n">W2</span> <span class="o">-=</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">gradients_W2</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="softmax_regression_epoch_cpp">softmax_regression_epoch_cpp</h2>
<p>The function signature is <code>void softmax_regression_epoch_cpp(const float *X, const unsigned char *y, float *theta, size_t m, size_t n, size_t k, float lr, size_t batch)</code>. It is a C++ implementation of softmax regression.</p>
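<p>Python slicing silently clamps out-of-range indices, but the C++ version has to handle full batches and the final, possibly incomplete batch explicitly. Computing the logits takes three nested loops that emulate matrix multiplication. The C++ version does not need to materialize the $E_y$ matrix, and the gradient does not need a matrix operation: two nested loops check whether class_idx is the correct label: <code>softmax[sample_idx * k + class_idx] -= (y[start_idx + sample_idx] == class_idx);</code>.</p>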
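<p>To see why skipping $E_y$ is safe, here is a small NumPy sketch that compares the matrix form with the loop form. The probabilities and labels are made up for illustration, and <code>with_one_hot</code> and <code>loop_form</code> are hypothetical names.</p>

```python
import numpy as np

rng = np.random.default_rng(0)               # arbitrary seed
batch, k = 4, 3                              # illustrative sizes
probs = rng.random((batch, k))
probs /= probs.sum(axis=1, keepdims=True)    # rows sum to 1, like a softmax output
y = np.array([2, 0, 1, 2])                   # made-up labels

# Matrix form: subtract the one-hot matrix E_y.
with_one_hot = probs - np.eye(k)[y]

# Loop form, mirroring the C++ line: subtract 1 only where class_idx is the label.
loop_form = probs.copy()
for sample_idx in range(batch):
    for class_idx in range(k):
        loop_form[sample_idx, class_idx] -= float(y[sample_idx] == class_idx)

print(np.array_equal(with_one_hot, loop_form))
```

<p>The comparison result is converted to 0 or 1 before the subtraction, which is exactly what the C++ expression does with its boolean.</p>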
<p>The complete code:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span><span class="lnt">50
</span><span class="lnt">51
</span><span class="lnt">52
</span><span class="lnt">53
</span><span class="lnt">54
</span><span class="lnt">55
</span><span class="lnt">56
</span><span class="lnt">57
</span><span class="lnt">58
</span><span class="lnt">59
</span><span class="lnt">60
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">softmax_regression_epoch_cpp</span><span class="p">(</span><span class="k">const</span> <span class="kt">float</span> <span class="o">*</span><span class="n">X</span><span class="p">,</span> <span class="k">const</span> <span class="kt">unsigned</span> <span class="kt">char</span> <span class="o">*</span><span class="n">y</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                  <span class="kt">float</span> <span class="o">*</span><span class="n">theta</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">m</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">n</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">k</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                  <span class="kt">float</span> <span class="n">lr</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">batch</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">size_t</span> <span class="n">total_batches</span> <span class="o">=</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="n">batch</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">batch</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">total_batches</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">size_t</span> <span class="n">start_idx</span> <span class="o">=</span> <span class="n">i</span> <span class="o">*</span> <span class="n">batch</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="n">size_t</span> <span class="n">end_idx</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">min</span><span class="p">(</span><span class="n">start_idx</span> <span class="o">+</span> <span class="n">batch</span><span class="p">,</span> <span class="n">m</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="n">size_t</span> <span class="n">current_batch_size</span> <span class="o">=</span> <span class="n">end_idx</span> <span class="o">-</span> <span class="n">start_idx</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1">// Allocate memory for logits and softmax
</span></span></span><span class="line"><span class="cl">        <span class="kt">float</span><span class="o">*</span> <span class="n">logits</span> <span class="o">=</span> <span class="k">new</span> <span class="kt">float</span><span class="p">[</span><span class="n">current_batch_size</span> <span class="o">*</span> <span class="n">k</span><span class="p">]();</span>
</span></span><span class="line"><span class="cl">        <span class="kt">float</span><span class="o">*</span> <span class="n">softmax</span> <span class="o">=</span> <span class="k">new</span> <span class="kt">float</span><span class="p">[</span><span class="n">current_batch_size</span> <span class="o">*</span> <span class="n">k</span><span class="p">]();</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1">// Compute logits
</span></span></span><span class="line"><span class="cl">        <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">sample_idx</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">sample_idx</span> <span class="o">&lt;</span> <span class="n">current_batch_size</span><span class="p">;</span> <span class="n">sample_idx</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">class_idx</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">class_idx</span> <span class="o">&lt;</span> <span class="n">k</span><span class="p">;</span> <span class="n">class_idx</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">feature_idx</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">feature_idx</span> <span class="o">&lt;</span> <span class="n">n</span><span class="p">;</span> <span class="n">feature_idx</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                    <span class="n">logits</span><span class="p">[</span><span class="n">sample_idx</span> <span class="o">*</span> <span class="n">k</span> <span class="o">+</span> <span class="n">class_idx</span><span class="p">]</span> <span class="o">+=</span> <span class="n">X</span><span class="p">[(</span><span class="n">start_idx</span> <span class="o">+</span> <span class="n">sample_idx</span><span class="p">)</span> <span class="o">*</span> <span class="n">n</span> <span class="o">+</span> <span class="n">feature_idx</span><span class="p">]</span> <span class="o">*</span> <span class="n">theta</span><span class="p">[</span><span class="n">feature_idx</span> <span class="o">*</span> <span class="n">k</span> <span class="o">+</span> <span class="n">class_idx</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">                <span class="p">}</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1">// Compute softmax
</span></span></span><span class="line"><span class="cl">        <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">sample_idx</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">sample_idx</span> <span class="o">&lt;</span> <span class="n">current_batch_size</span><span class="p">;</span> <span class="n">sample_idx</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="kt">float</span> <span class="n">max_logit</span> <span class="o">=</span> <span class="o">*</span><span class="n">std</span><span class="o">::</span><span class="n">max_element</span><span class="p">(</span><span class="n">logits</span> <span class="o">+</span> <span class="n">sample_idx</span> <span class="o">*</span> <span class="n">k</span><span class="p">,</span> <span class="n">logits</span> <span class="o">+</span> <span class="p">(</span><span class="n">sample_idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">k</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">            <span class="kt">float</span> <span class="n">sum</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">class_idx</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">class_idx</span> <span class="o">&lt;</span> <span class="n">k</span><span class="p">;</span> <span class="n">class_idx</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="n">softmax</span><span class="p">[</span><span class="n">sample_idx</span> <span class="o">*</span> <span class="n">k</span> <span class="o">+</span> <span class="n">class_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">exp</span><span class="p">(</span><span class="n">logits</span><span class="p">[</span><span class="n">sample_idx</span> <span class="o">*</span> <span class="n">k</span> <span class="o">+</span> <span class="n">class_idx</span><span class="p">]</span> <span class="o">-</span> <span class="n">max_logit</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">                <span class="n">sum</span> <span class="o">+=</span> <span class="n">softmax</span><span class="p">[</span><span class="n">sample_idx</span> <span class="o">*</span> <span class="n">k</span> <span class="o">+</span> <span class="n">class_idx</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">class_idx</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">class_idx</span> <span class="o">&lt;</span> <span class="n">k</span><span class="p">;</span> <span class="n">class_idx</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="n">softmax</span><span class="p">[</span><span class="n">sample_idx</span> <span class="o">*</span> <span class="n">k</span> <span class="o">+</span> <span class="n">class_idx</span><span class="p">]</span> <span class="o">/=</span> <span class="n">sum</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1">// Compute gradient
</span></span></span><span class="line"><span class="cl">        <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">sample_idx</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">sample_idx</span> <span class="o">&lt;</span> <span class="n">current_batch_size</span><span class="p">;</span> <span class="n">sample_idx</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">class_idx</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">class_idx</span> <span class="o">&lt;</span> <span class="n">k</span><span class="p">;</span> <span class="n">class_idx</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="n">softmax</span><span class="p">[</span><span class="n">sample_idx</span> <span class="o">*</span> <span class="n">k</span> <span class="o">+</span> <span class="n">class_idx</span><span class="p">]</span> <span class="o">-=</span> <span class="p">(</span><span class="n">y</span><span class="p">[</span><span class="n">start_idx</span> <span class="o">+</span> <span class="n">sample_idx</span><span class="p">]</span> <span class="o">==</span> <span class="n">class_idx</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1">// Update theta
</span></span></span><span class="line"><span class="cl">        <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">feature_idx</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">feature_idx</span> <span class="o">&lt;</span> <span class="n">n</span><span class="p">;</span> <span class="n">feature_idx</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">class_idx</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">class_idx</span> <span class="o">&lt;</span> <span class="n">k</span><span class="p">;</span> <span class="n">class_idx</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="kt">float</span> <span class="n">gradient</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">                <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">sample_idx</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">sample_idx</span> <span class="o">&lt;</span> <span class="n">current_batch_size</span><span class="p">;</span> <span class="n">sample_idx</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                    <span class="n">gradient</span> <span class="o">+=</span> <span class="n">X</span><span class="p">[(</span><span class="n">start_idx</span> <span class="o">+</span> <span class="n">sample_idx</span><span class="p">)</span> <span class="o">*</span> <span class="n">n</span> <span class="o">+</span> <span class="n">feature_idx</span><span class="p">]</span> <span class="o">*</span> <span class="n">softmax</span><span class="p">[</span><span class="n">sample_idx</span> <span class="o">*</span> <span class="n">k</span> <span class="o">+</span> <span class="n">class_idx</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">                <span class="p">}</span>
</span></span><span class="line"><span class="cl">                <span class="n">theta</span><span class="p">[</span><span class="n">feature_idx</span> <span class="o">*</span> <span class="n">k</span> <span class="o">+</span> <span class="n">class_idx</span><span class="p">]</span> <span class="o">-=</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">gradient</span> <span class="o">/</span> <span class="n">current_batch_size</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1">// Free allocated memory
</span></span></span><span class="line"><span class="cl">        <span class="k">delete</span><span class="p">[]</span> <span class="n">logits</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="k">delete</span><span class="p">[]</span> <span class="n">softmax</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
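</div><h2 id="hw0-小结">hw0 summary</h2>
<p>hw0 is meant to be finished before Lecture 2. Beginners will probably be bewildered by the pile of formulas, but the homework as a whole is fairly easy: following the formulas step by step is enough to complete it (apart from the odd precision error in the two-layer perceptron). Its main purpose is to get familiar with NumPy and basic DL models.</p>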
<h1 id="hw1">hw1</h1>
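<p>The first homework has six questions: forward computation, backward gradients, topological sort, reverse-mode automatic differentiation, softmax loss, and the SGD algorithm for a two-layer perceptron.</p>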
<h2 id="implementing-forward--backward-computation">Implementing forward &amp; backward computation</h2>
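<p>The first two questions are discussed together. Question 1 implements some common operators with the NumPy API. Question 2 implements the gradients of those operators in terms of the operators from Question 1.</p>
<p>Note that the notebook stresses that Question 1 operates on <code>NDArray</code> while Question 2 operates on <code>Tensor</code>. The former models the low-level implementation of the operators. The latter computes gradients by calling those operators. In other words, it wraps the gradient computation as a composition of operators, so taking a gradient becomes an ordinary operation and the gradient of a gradient can be derived automatically. See Lecture 4 for a detailed explanation.</p>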
<ul>
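<li>PowerScalar<br>
This operator raises a tensor to a power element-wise. The exponent is a non-learnable parameter fixed when the operator is instantiated, so its partial derivative does not need to be considered. This one is simple; just apply the power rule:</li>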
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">PowerScalar</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;Op raise a tensor to an (integer) power.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scalar</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span> <span class="o">=</span> <span class="n">scalar</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="n">NDArray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">NDArray</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">power</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">a</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span> <span class="o">*</span> <span class="p">(</span><span class="n">power_scalar</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="o">*</span> <span class="n">out_grad</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>EWiseDiv<br>
This operator divides two tensors element-wise. The gradient is straightforward: take the partial derivatives of $a/b$ with respect to $a$ and $b$:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">EWiseDiv</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;Op to element-wise divide two nodes.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">true_divide</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span><span class="o">/</span><span class="n">b</span> <span class="p">,</span> <span class="o">-</span><span class="n">a</span><span class="o">/</span><span class="n">b</span><span class="o">/</span><span class="n">b</span><span class="o">*</span><span class="n">out_grad</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>DivScalar<br>
This operator divides the whole tensor by a scalar. As with <code>PowerScalar</code>, no gradient is needed for the scalar:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">DivScalar</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scalar</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span> <span class="o">=</span> <span class="n">scalar</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">true_divide</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span><span class="o">/</span><span class="bp">self</span><span class="o">.</span><span class="n">scalar</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
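<li>MatMul<br>
This operator performs matrix multiplication, and it is the first genuinely challenging task in the course so far. Following the method given in the course, the gradients are given by these two expressions:</li>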
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">adjoint1</span> <span class="o">=</span> <span class="n">out_grad</span> <span class="o">@</span> <span class="n">transpose</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">adjoint2</span> <span class="o">=</span> <span class="n">transpose</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">@</span> <span class="n">out_grad</span>
</span></span></code></pre></td></tr></table>
</div>
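</div><p>However, that is only the theoretical derivation. In practice there are two problems: 1) the operands may be higher-dimensional tensors rather than 2-D matrices, for example two tensors of shape (2, 2, 3, 4) and (2, 2, 4, 5); 2) the multiplication may involve broadcasting, and it is not obvious how the gradient should be handled in that case.</p>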
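<p>NumPy handles most of the first problem for us. Two tensors can be multiplied as long as their <strong>last two dimensions</strong> are valid for 2-D matrix multiplication and their <strong>remaining dimensions</strong> (also called batch dimensions) are equal, or one of a pair of batch dimensions is 1 (in which case it is broadcast).</p>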
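<p>A quick NumPy sketch of these cases. The array names are mine and chosen only for illustration; the shapes are the ones from the example above.</p>

```python
import numpy as np

# Batch dimensions are equal: (2, 2) matches (2, 2)
a = np.ones((2, 2, 3, 4))
b = np.ones((2, 2, 4, 5))
same_batch = (a @ b).shape

# A batch dimension of size 1 is stretched: (1, 2) broadcasts against (2, 2)
c = np.ones((1, 2, 3, 4))
stretched = (c @ b).shape

# Missing batch dimensions are prepended: (3, 4) is treated like (1, 1, 3, 4)
d = np.ones((3, 4))
prepended = (d @ b).shape

print(same_batch, stretched, prepended)
```

<p>In all three cases only the last two dimensions take part in the matrix product; the batch dimensions just have to be broadcast-compatible.</p>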
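<p>There is no free lunch: automatic broadcasting is convenient, but it also causes the second problem. The adjoint (that is, the partial derivative) must have the same shape as the corresponding input, yet the gradient computed from the formulas has the broadcast shape, so a reduce operation is needed.</p>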
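<p>Here is my informal, non-rigorous derivation of the reduce step. Suppose the matrix $A_{m\times n}$ becomes $A^\prime_{p\times m\times n}$ after broadcasting; $A^\prime$ is what actually takes part in the computation. First, pretend that $A^\prime$ replaces $A$ in the computational graph. When some input $x_1$ of the node $A^\prime @B$ (written $f(x_1, \dots)$) needs its gradient, we have to compute the product of the tensor $\partial f/ \partial x_1$ and the partial derivative obtained with $A^\prime$. Now restore $A$. Every entry of $A$ is reused in each of the $p$ copies, so the gradient at node $f(x_1, \dots)$ must add up the partial derivatives along the $p$ dimension. In other words, the gradient with respect to $A^\prime_{p\times m\times n}$ is summed along $p$ to give a result of shape $m\times n$.</p>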
<p>That description is not very clear. In short, every dimension introduced by broadcasting has to be summed out.</p>
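<p>To make the reduce step concrete, here is a small NumPy sketch under my own assumptions: <code>a</code> has no batch dimensions, <code>b</code> has batch dimensions (2, 2), and the loss is simply the sum of all entries of <code>a @ b</code>. The variable names are illustrative and not part of the assignment code.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 4))          # no batch dimensions
b = rng.standard_normal((2, 2, 4, 5))    # batch dimensions (2, 2)
out_grad = np.ones((2, 2, 3, 5))         # gradient of sum(a @ b) w.r.t. a @ b

# The 2-D formula yields a gradient with the broadcast shape (2, 2, 3, 4)
grad_a_full = out_grad @ np.swapaxes(b, -1, -2)

# Sum out the leading axes that broadcasting prepended to `a`
extra = grad_a_full.ndim - a.ndim
grad_a = grad_a_full.sum(axis=tuple(range(extra)))

# Finite-difference check on a single entry of `a`
eps = 1e-6
a_shift = a.copy()
a_shift[1, 2] += eps
numeric = ((a_shift @ b).sum() - (a @ b).sum()) / eps

print(grad_a.shape, abs(numeric - grad_a[1, 2]) < 1e-4)
```

<p>After the sum, the gradient has the same shape as <code>a</code>, and the finite-difference estimate agrees with the summed entry.</p>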
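<p>NumPy broadcasting only ever prepends new dimensions, so it is enough to count how many dimensions must be summed out and take that many leading axes. Note that this covers prepended dimensions only: a batch dimension of size 1 that is stretched in place is not reduced by the code below.</p>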
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">MatMul</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">a</span><span class="nd">@b</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span>
</span></span><span class="line"><span class="cl">        <span class="n">adjoint1</span> <span class="o">=</span> <span class="n">out_grad</span> <span class="o">@</span> <span class="n">transpose</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">adjoint2</span> <span class="o">=</span> <span class="n">transpose</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">@</span> <span class="n">out_grad</span>
</span></span><span class="line"><span class="cl">        <span class="n">adjoint1</span> <span class="o">=</span> <span class="n">summation</span><span class="p">(</span><span class="n">adjoint1</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">adjoint1</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">))))</span>
</span></span><span class="line"><span class="cl">        <span class="n">adjoint2</span> <span class="o">=</span> <span class="n">summation</span><span class="p">(</span><span class="n">adjoint2</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">adjoint2</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="p">))))</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">adjoint1</span><span class="p">,</span> <span class="n">adjoint2</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
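<li>Summation<br>
This operator sums a tensor over the specified dimensions. Let the tensor $X$ to be summed have shape $s_1\times s_2\times \dots \times s_n$. The shape after summation is obtained by removing the dimensions listed in $axes$, which can be written formally as:</li>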
</ul>


<div>$$

\text{SUM}(X_{s_1\times s_2\times ... \times s_n}, axes) = [\sum_{s_i \in axes} X]_{\{s_j | j\notin axes \}}

$$</div>

<p>Suppose the input has shape $3\times 2\times 4 \times 5$ and we sum over dimensions 0 and 2; the output then has shape $2\times 5$. The backward pass first reshapes <code>out_grad</code> to $1\times 2 \times 1\times 5$ and then broadcasts it to the input shape.</p>
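<p>The same example in plain NumPy, as a rough sketch (the names <code>keep_shape</code> and <code>grad_x</code> are mine):</p>

```python
import numpy as np

x = np.arange(3 * 2 * 4 * 5, dtype=np.float64).reshape(3, 2, 4, 5)
axes = (0, 2)
out = x.sum(axis=axes)               # shape (2, 5)
out_grad = np.ones_like(out)         # pretend the upstream gradient is all ones

# Put back the summed axes as size-1 axes: (1, 2, 1, 5)
keep_shape = list(x.shape)
for ax in axes:
    keep_shape[ax] = 1

# Broadcast to the input shape: every summed entry receives the same gradient
grad_x = np.broadcast_to(out_grad.reshape(keep_shape), x.shape)

print(out.shape, tuple(keep_shape), grad_x.shape)
```

<p>Each input entry contributes to exactly one output entry with coefficient 1, which is why the gradient is just a copy of the corresponding <code>out_grad</code> entry.</p>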
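<p>A note to revisit later: I have not fully understood this part yet and do not know how to express summation formally and differentiate it. Still, the following code happened to pass the tests:</p>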
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Summation</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">a</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">axes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">axes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">axes</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">shape</span><span class="p">)))</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">axes</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">shape</span><span class="p">[</span><span class="n">_</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">broadcast_to</span><span class="p">(</span><span class="n">reshape</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="n">shape</span><span class="p">),</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
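<li>BroadcastTo<br>
This operator broadcasts a tensor to the specified shape. Broadcasting means copying the data along dimensions that are missing or have size 1 so that the result matches the target shape.</li>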
</ul>
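<p>As a rough NumPy sketch of the backward pass: the gradient has to be summed over every axis along which the data was copied, both the prepended axes and the size-1 axes that were stretched. The helper name <code>broadcast_grad</code> is hypothetical and only used for this illustration.</p>

```python
import numpy as np

def broadcast_grad(out_grad, input_shape):
    """Reduce out_grad (which has the broadcast shape) back to input_shape."""
    n_new = out_grad.ndim - len(input_shape)   # number of prepended axes
    # Prepended axes, plus input axes of size 1 (shifted by n_new in out_grad)
    axes = tuple(range(n_new)) + tuple(
        i + n_new for i, dim in enumerate(input_shape) if dim == 1
    )
    return out_grad.sum(axis=axes).reshape(input_shape)

x = np.array([[1.0, 2.0, 3.0]])        # shape (1, 3)
y = np.broadcast_to(x, (2, 4, 3))      # each entry of x is copied 2 * 4 = 8 times
grad_x = broadcast_grad(np.ones(y.shape), x.shape)

print(grad_x)
```

<p>Since each entry of <code>x</code> appears 8 times in <code>y</code>, an all-ones upstream gradient gives 8 for every entry of <code>grad_x</code>.</p>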
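<p>For an analysis of the forward and gradient computation of broadcasting, see the MatMul operator above, which discusses broadcasting in detail. The implementation of this operator is:</p>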
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">BroadcastTo</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">input_shape</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">n_new</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">out_grad</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># Sum the prepended axes and every stretched size-1 axis in a single call, so axis indices never shift</span>
</span></span><span class="line"><span class="cl">        <span class="n">axes</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_new</span><span class="p">))</span> <span class="o">+</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="n">n_new</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">dim</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">dim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">ret</span> <span class="o">=</span> <span class="n">summation</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="n">axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">reshape</span><span class="p">(</span><span class="n">ret</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
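<li>Reshape<br>
This operator reshapes a tensor to the specified shape. The backward pass reshapes the gradient back to the shape of the input tensor. The implementation is quite simple:</li>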
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Reshape</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">reshape</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>Negate<br>
This operator negates every element of the tensor. The backward pass negates the incoming gradient as well. The implementation is:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Negate</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">negative</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">negate</span><span class="p">(</span><span class="n">out_grad</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
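<li>Transpose<br>
This operator swaps two specified axes, defaulting to the last two axes when none are given. Note that this behavior differs from <code>np.transpose</code>; the API to call is <code>np.swapaxes</code>. The backward pass swaps the same two axes again:</li>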
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Transpose</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">transpose</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="topological-sort">Topological sort</h2>
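<p>This question asks for an implementation of topological sort. The material involved is all from data structures: topological ordering of a graph, post-order traversal, and DFS.</p>
<p>The problem statement explicitly requires a post-order traversal to obtain the topological order of the computational graph. Put simply: if the current node has unvisited children (its inputs), visit those first; otherwise visit the current node. Visiting a node means marking it as visited and appending it to the topological order.</p>
<p>Combined with DFS, the code for computing the topological order is:</p>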
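<p>Before the actual solution, here is a toy version of the rule on a four-node graph. The <code>Node</code> class and the function name <code>post_order</code> are hypothetical stand-ins, not the classes used in the assignment.</p>

```python
class Node:
    """Hypothetical stand-in for a graph value: a name plus its input nodes."""

    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)


def post_order(node, visited, order):
    # Visit all unvisited inputs first, then the node itself
    for inp in node.inputs:
        if id(inp) not in visited:
            post_order(inp, visited, order)
    visited.add(id(node))
    order.append(node)


# v3 = v1 * v2, v4 = v3 + v1  (v1 is used twice)
v1, v2 = Node("v1"), Node("v2")
v3 = Node("v3", [v1, v2])
v4 = Node("v4", [v3, v1])

order = []
post_order(v4, set(), order)
names = [n.name for n in order]
print(names)
```

<p>Every node appears after all of its inputs, and <code>v1</code> appears only once even though two nodes consume it.</p>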
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">find_topo_sort</span><span class="p">(</span><span class="n">node_list</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Value</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Value</span><span class="p">]:</span>
</span></span><span class="line"><span class="cl">    <span class="n">visited</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">topo_order</span> <span class="o">=</span> <span class="p">[]</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">node_list</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="ow">not</span> <span class="n">visited</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">node</span><span class="p">,</span> <span class="kc">False</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="n">topo_sort_dfs</span><span class="p">(</span><span class="n">node</span><span class="p">,</span> <span class="n">visited</span><span class="p">,</span> <span class="n">topo_order</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">topo_order</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">topo_sort_dfs</span><span class="p">(</span><span class="n">node</span><span class="p">,</span> <span class="n">visited</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">topo_order</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">sons</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">son</span> <span class="ow">in</span> <span class="n">sons</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="ow">not</span> <span class="n">visited</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">son</span><span class="p">,</span> <span class="kc">False</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="n">topo_sort_dfs</span><span class="p">(</span><span class="n">son</span><span class="p">,</span> <span class="n">visited</span><span class="p">,</span> <span class="n">topo_order</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">visited</span><span class="p">[</span><span class="n">node</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
</span></span><span class="line"><span class="cl">    <span class="n">topo_order</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">node</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="implementing-reverse-mode-differentiation">Implementing reverse mode differentiation</h2>
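<p>Finally, it is time to assemble the automatic differentiation algorithm. The core task is to turn the reverse mode AD algorithm introduced in the lecture into code:<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406152001945.webp?x-oss-process=image/quality,q_90/format,webp"><br>
A few points to watch out for:</p>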
<ul>
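<li>The last part of <code>autograd.py</code> provides a helper function <code>sum_node_list(node_list)</code> that sums a list of nodes without creating redundant nodes. It corresponds to the step in the pseudocode that sums the partial adjoints into $\overline{v_i}$;</li>
<li>Only nodes that have inputs need a gradient computation. The initial input nodes cannot propagate a gradient any further, so this has to be checked;</li>
<li><del>The return type of <code>node.op.gradient</code> is <code>Tuple | Tensor</code>, so the two cases must be handled separately.</del> The helper <code>node.op.gradient_as_tuple</code> guarantees that the return type is a tuple.</li>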
</ul>
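<p>To see the bookkeeping in isolation, here is a toy scalar version of reverse mode AD. It is a sketch under my own assumptions: the <code>Var</code> class, <code>mul</code>, <code>add</code> and <code>backward</code> are hypothetical and much simpler than the assignment's tensors, but the pattern of collecting partial adjoints per node and summing them is the same.</p>

```python
class Var:
    """Hypothetical scalar node: value, input nodes, and local partial derivatives."""

    def __init__(self, value, inputs=(), partials=()):
        self.value = value
        self.inputs = list(inputs)
        self.partials = list(partials)   # d(self) / d(input_i)
        self.grad = None


def mul(a, b):
    return Var(a.value * b.value, [a, b], [b.value, a.value])


def add(a, b):
    return Var(a.value + b.value, [a, b], [1.0, 1.0])


def topo_sort(node, visited, order):
    for inp in node.inputs:
        if id(inp) not in visited:
            topo_sort(inp, visited, order)
    visited.add(id(node))
    order.append(node)


def backward(output):
    order = []
    topo_sort(output, set(), order)
    # Each node collects the partial adjoints sent to it by its consumers
    partial_adjoints = {id(output): [1.0]}
    for node in reversed(order):
        node.grad = sum(partial_adjoints[id(node)])
        for inp, local in zip(node.inputs, node.partials):
            partial_adjoints.setdefault(id(inp), []).append(node.grad * local)


# y = x1 * x2 + x1, so dy/dx1 = x2 + 1 and dy/dx2 = x1
x1, x2 = Var(2.0), Var(5.0)
y = add(mul(x1, x2), x1)
backward(y)
print(x1.grad, x2.grad)
```

<p>Because <code>x1</code> feeds two nodes, it receives two partial adjoints (1 from the addition and 5 from the multiplication), which are summed when <code>x1</code> is processed.</p>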
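<p>Before writing the code, it is best to review the theory once more. While debugging, drawing the computational graph by hand works wonders. The concrete implementation of reverse mode AD is:</p>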
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">compute_gradient_of_variables</span><span class="p">(</span><span class="n">output_tensor</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># Setup (assumed to match the assignment skeleton): seed the output adjoint and reverse the topological order</span>
</span></span><span class="line"><span class="cl">    <span class="n">node_to_output_grads_list</span> <span class="o">=</span> <span class="p">{</span><span class="n">output_tensor</span><span class="p">:</span> <span class="p">[</span><span class="n">out_grad</span><span class="p">]}</span>
</span></span><span class="line"><span class="cl">    <span class="n">reverse_topo_order</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">reversed</span><span class="p">(</span><span class="n">find_topo_sort</span><span class="p">([</span><span class="n">output_tensor</span><span class="p">])))</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">reverse_topo_order</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">node</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">sum_node_list</span><span class="p">(</span><span class="n">node_to_output_grads_list</span><span class="p">[</span><span class="n">node</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">gradient</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">node</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">node</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">gradient</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">son_node</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                    <span class="n">node_to_output_grads_list</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="n">son_node</span><span class="p">,</span> <span class="p">[])</span>
</span></span><span class="line"><span class="cl">                    <span class="n">node_to_output_grads_list</span><span class="p">[</span><span class="n">son_node</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">gradient</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">            <span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">                <span class="n">node_to_output_grads_list</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">[])</span>
</span></span><span class="line"><span class="cl">                <span class="n">node_to_output_grads_list</span><span class="p">[</span><span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">gradient</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="softmax-loss">Softmax loss</h2>
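<p>This problem first asks for the forward and backward passes of the log and exp operators, and then for the softmax loss, i.e. the cross-entropy loss.</p>
<p>As the notebook states, the <code>y</code> passed in here has already been converted to a one-hot encoding. The implementation follows the formula in the notebook step by step, so there is nothing that needs special explanation.</p>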
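<p>For reference, the quantity being computed is the log-sum-exp of the logits minus the logit of the true class, averaged over the batch. The sketch below is a plain NumPy version of that formula, useful for cross-checking; the name <code>softmax_loss_reference</code> is an assumption made for illustration and is not part of needle.</p>

```python
import numpy as np


def softmax_loss_reference(Z: np.ndarray, y_one_hot: np.ndarray) -> float:
    """Average cross-entropy: log(sum(exp(z))) - z_y, averaged over the batch."""
    log_sum_exp = np.log(np.exp(Z).sum(axis=1))  # shape (batch,)
    true_logit = (Z * y_one_hot).sum(axis=1)     # the one-hot mask picks z_y
    return float((log_sum_exp - true_logit).mean())


# With all-zero logits every class is equally likely,
# so the loss should equal log(num_classes).
Z = np.zeros((4, 3))
y_one_hot = np.eye(3)[np.array([0, 2, 1, 1])]
uniform_loss = softmax_loss_reference(Z, y_one_hot)
```

<p>The needle implementation computes the same quantity with <code>ndl</code> operators so that it stays on the computational graph.</p>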
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">softmax_loss</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">y_one_hot</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">batch_size</span> <span class="o">=</span> <span class="n">Z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="n">lhs</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">ndl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">)))</span>
</span></span><span class="line"><span class="cl">    <span class="n">rhs</span> <span class="o">=</span> <span class="p">(</span><span class="n">Z</span> <span class="o">*</span> <span class="n">y_one_hot</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">))</span>
</span></span><span class="line"><span class="cl">    <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">lhs</span> <span class="o">-</span> <span class="n">rhs</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">loss</span> <span class="o">/</span> <span class="n">batch_size</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="sgd-for-a-two-layer-neural-network">SGD for a two-layer neural network</h2>
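<p>The last problem: using the components built so far, implement a two-layer perceptron and its stochastic gradient descent. Notes on the implementation:</p>
<ul>
<li>The <code>y</code> passed in here holds class labels, so it has to be converted to a one-hot encoding first.</li>
<li>Read the problem statement carefully! The updated values of the two weight matrices should be computed in NumPy and then wrapped back into a Tensor. If the update is computed with Tensor operators directly, every update adds several nodes to the computational graph, and the graph keeps growing from batch to batch. Test cases that run more than 600 batches then take over ten minutes, when a few seconds is all they actually need. If you run into the same problem, read the problem requirements again.</li>
</ul>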
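<p>The second point is easy to underestimate. The toy sketch below uses a hypothetical <code>Node</code> class (an assumption for illustration, not needle's API) to count how much history a weight drags along after 600 updates, once with a graph-tracked subtraction and once with a fresh, detached value.</p>

```python
class Node:
    """Hypothetical stand-in for a graph-tracked tensor (not needle's API)."""

    def __init__(self, value, inputs=()):
        self.value = value
        self.inputs = tuple(inputs)

    def __sub__(self, other):
        # A tracked op remembers both operands as its inputs.
        return Node(self.value - other.value, inputs=(self, other))


def history_size(node):
    """Count the distinct nodes reachable from `node`."""
    seen, stack = set(), [node]
    while stack:
        cur = stack.pop()
        if id(cur) in seen:
            continue
        seen.add(id(cur))
        stack.extend(cur.inputs)
    return len(seen)


def run(steps, detach):
    w = Node(1.0)
    for _ in range(steps):
        grad = Node(0.1)                    # pretend this came from backward()
        if detach:
            w = Node(w.value - grad.value)  # fresh leaf with no inputs
        else:
            w = w - grad                    # keeps the old w alive as an input
    return history_size(w)


tracked = run(600, detach=False)
detached = run(600, detach=True)
```

<p>Here <code>tracked</code> is 1201 nodes while <code>detached</code> is 1. In needle the gradient most likely carries its own graph history as well, so the real growth is worse than this toy suggests. The <code>nn_epoch</code> implementation avoids it by building each new weight from NumPy arrays.</p>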
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">nn_epoch</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">W1</span><span class="p">,</span> <span class="n">W2</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">batch</span><span class="o">=</span><span class="mi">100</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">batch_cnt</span> <span class="o">=</span> <span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">batch</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">batch</span>
</span></span><span class="line"><span class="cl">    <span class="n">num_classes</span> <span class="o">=</span> <span class="n">W2</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="n">one_hot_y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">num_classes</span><span class="p">)[</span><span class="n">y</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">batch_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_cnt</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">start_idx</span> <span class="o">=</span> <span class="n">batch_idx</span> <span class="o">*</span> <span class="n">batch</span>
</span></span><span class="line"><span class="cl">        <span class="n">end_idx</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">(</span><span class="n">batch_idx</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">batch</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">X_batch</span> <span class="o">=</span> <span class="n">X</span><span class="p">[</span><span class="n">start_idx</span><span class="p">:</span><span class="n">end_idx</span><span class="p">,</span> <span class="p">:]</span>
</span></span><span class="line"><span class="cl">        <span class="n">y_batch</span> <span class="o">=</span> <span class="n">one_hot_y</span><span class="p">[</span><span class="n">start_idx</span><span class="p">:</span><span class="n">end_idx</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">X_tensor</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">X_batch</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">y_tensor</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">y_batch</span><span class="p">)</span> 
</span></span><span class="line"><span class="cl">        <span class="n">first_logits</span> <span class="o">=</span> <span class="n">X_tensor</span> <span class="o">@</span> <span class="n">W1</span> <span class="c1"># type: ndl.Tensor</span>
</span></span><span class="line"><span class="cl">        <span class="n">first_output</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">first_logits</span><span class="p">)</span> <span class="c1"># type: ndl.Tensor</span>
</span></span><span class="line"><span class="cl">        <span class="n">second_logits</span> <span class="o">=</span> <span class="n">first_output</span> <span class="o">@</span> <span class="n">W2</span> <span class="c1"># type: ndl.Tensor</span>
</span></span><span class="line"><span class="cl">        <span class="n">loss_err</span> <span class="o">=</span> <span class="n">softmax_loss</span><span class="p">(</span><span class="n">second_logits</span><span class="p">,</span> <span class="n">y_tensor</span><span class="p">)</span> <span class="c1"># type: ndl.Tensor</span>
</span></span><span class="line"><span class="cl">        <span class="n">loss_err</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="n">new_W1</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">W1</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="o">-</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">W1</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</span></span><span class="line"><span class="cl">        <span class="n">new_W2</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">W2</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="o">-</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">W2</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</span></span><span class="line"><span class="cl">        <span class="n">W1</span><span class="p">,</span> <span class="n">W2</span> <span class="o">=</span> <span class="n">new_W1</span><span class="p">,</span> <span class="n">new_W2</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">W1</span><span class="p">,</span> <span class="n">W2</span>
</span></span></code></pre></td></tr></table>
</div>
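</div><h2 id="hw-1-小结">hw 1 summary</h2>
<p>The difficulty clearly goes up with this homework. Since I am not very familiar with NumPy operations, I looked up quite a lot of material and other people's implementations along the way. Thanks to <a href="https://www.zhihu.com/people/xiao-xiong-34-11">@# xx要努力</a> for the article <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup>; much of my code follows that implementation.</p>
<p>When debugging the final two-layer perceptron, I had implemented the weight update with Tensor operators, and the tests ran for over ten minutes. Only afterwards did I notice that the problem statement explicitly requires NumPy operations. A hard lesson learned: read the problem carefully next time.</p>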
<h1 id="hw2">hw2</h1>
<h2 id="q1-weight-initialization">Q1: Weight Initialization</h2>
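<p>Q1 implements several methods for generating initial parameter values. With the helper functions in <code>init_basic.py</code>, it comes down to following the formulas given in the notebook, which is fairly simple. Remember to pass <code>kwargs</code> through to the helper functions, since it carries information such as <code>dtype</code> and <code>device</code>.</p>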
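<p>As a sanity check on those formulas, the NumPy-only sketch below samples a Xavier-uniform and a Kaiming-normal matrix and measures their spread. It uses NumPy's own random generator instead of the <code>rand</code>/<code>randn</code> helpers, and the layer sizes are arbitrary values chosen for illustration.</p>

```python
import math

import numpy as np

rng = np.random.default_rng(0)
fan_in, fan_out = 256, 128  # arbitrary sizes, chosen only for this check

# Xavier uniform: U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out))
a = 1.0 * math.sqrt(6 / (fan_in + fan_out))
w_xavier = rng.uniform(-a, a, size=(fan_in, fan_out))

# Kaiming normal (ReLU): N(0, std^2) with std = sqrt(2) / sqrt(fan_in)
std = math.sqrt(2) / math.sqrt(fan_in)
w_kaiming = rng.normal(0.0, std, size=(fan_in, fan_out))

# U(-a, a) has variance a^2 / 3 = 2 / (fan_in + fan_out),
# the same variance that Xavier normal targets.
xavier_var = w_xavier.var()
kaiming_std = w_kaiming.std()
```

<p>The measured <code>xavier_var</code> should sit close to <code>2 / (fan_in + fan_out)</code> and <code>kaiming_std</code> close to <code>std</code>, which shows that the uniform and normal variants of each scheme aim for the same variance. The needle versions follow.</p>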
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">xavier_uniform</span><span class="p">(</span><span class="n">fan_in</span><span class="p">,</span> <span class="n">fan_out</span><span class="p">,</span> <span class="n">gain</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="n">a</span> <span class="o">=</span> <span class="n">gain</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">6</span> <span class="o">/</span> <span class="p">(</span><span class="n">fan_in</span> <span class="o">+</span> <span class="n">fan_out</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">rand</span><span class="p">(</span><span class="n">fan_in</span><span class="p">,</span> <span class="n">fan_out</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">a</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">a</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">xavier_normal</span><span class="p">(</span><span class="n">fan_in</span><span class="p">,</span> <span class="n">fan_out</span><span class="p">,</span> <span class="n">gain</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="n">std</span> <span class="o">=</span> <span class="n">gain</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span> <span class="o">/</span> <span class="p">(</span><span class="n">fan_in</span> <span class="o">+</span> <span class="n">fan_out</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">randn</span><span class="p">(</span><span class="n">fan_in</span><span class="p">,</span> <span class="n">fan_out</span><span class="p">,</span> <span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="n">std</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">kaiming_uniform</span><span class="p">(</span><span class="n">fan_in</span><span class="p">,</span> <span class="n">fan_out</span><span class="p">,</span> <span class="n">nonlinearity</span><span class="o">=</span><span class="s2">&#34;relu&#34;</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">assert</span> <span class="n">nonlinearity</span> <span class="o">==</span> <span class="s2">&#34;relu&#34;</span><span class="p">,</span> <span class="s2">&#34;Only relu supported currently&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="n">nonlinearity</span> <span class="o">==</span> <span class="s2">&#34;relu&#34;</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">gain</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="n">bound</span> <span class="o">=</span> <span class="n">gain</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">3</span> <span class="o">/</span> <span class="n">fan_in</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">rand</span><span class="p">(</span><span class="n">fan_in</span><span class="p">,</span> <span class="n">fan_out</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">kaiming_normal</span><span class="p">(</span><span class="n">fan_in</span><span class="p">,</span> <span class="n">fan_out</span><span class="p">,</span> <span class="n">nonlinearity</span><span class="o">=</span><span class="s2">&#34;relu&#34;</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">assert</span> <span class="n">nonlinearity</span> <span class="o">==</span> <span class="s2">&#34;relu&#34;</span><span class="p">,</span> <span class="s2">&#34;Only relu supported currently&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="n">nonlinearity</span> <span class="o">==</span> <span class="s2">&#34;relu&#34;</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">gain</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="n">std</span> <span class="o">=</span> <span class="n">gain</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">fan_in</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">randn</span><span class="p">(</span><span class="n">fan_in</span><span class="p">,</span> <span class="n">fan_out</span><span class="p">,</span> <span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="n">std</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="q2-nn_basic">Q2: nn_basic</h2>
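<p>In Q2 we implement several of the most basic Module components. While debugging, I ran into two strange problems:</p>
<ul>
<li>All inputs and parameters are <code>float32</code>, yet one output is <code>float64</code>, which fails the test case.</li>
<li>During backpropagation, one node receives an <code>out_grad</code> whose shape is larger than the shape of that node's input, although in theory the two should match.</li>
</ul>
<p>After a long debugging session, the first problem turned out to come from the implementation of <code>DivScalar</code>, i.e. division by a scalar: when the input is a single real number rather than a matrix, the result of <code>numpy</code> division defaults to <code>float64</code>. The fix is to call <code>np.true_divide</code> explicitly and pass the keyword argument <code>dtype='float32'</code> to pin the type of the return value.</p>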
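<p>A minimal sketch of that fix for the <code>float64</code> problem, outside of needle: passing <code>dtype</code> to the ufunc pins the output type for both array and scalar inputs. What NumPy returns without the keyword depends on its version and promotion rules, so the sketch only checks the pinned result.</p>

```python
import numpy as np

matrix = np.ones((2, 3), dtype="float32")
single = np.float32(1.0)  # a lone real number rather than a matrix
scalar = 2.0

# dtype= fixes the output type regardless of how the inputs
# would otherwise be promoted.
out_matrix = np.true_divide(matrix, scalar, dtype="float32")
out_single = np.true_divide(single, scalar, dtype="float32")
```

<p>Both <code>out_matrix</code> and <code>out_single</code> come back as <code>float32</code>.</p>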
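<p>The second problem arises because many <code>numpy</code> operations broadcast automatically, but that broadcast is invisible to our <code>needle</code> library and cannot be added to the computational graph, which leads to the shape mismatch during backpropagation. The fix is to <strong>modify the implementation of the basic operators from Q1</strong> so that they check whether the shapes match before computing. The modified <code>ops_mathematic.py</code> reads:</p>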
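<p>Before the listing, a minimal NumPy sketch of the broadcasting failure and of the kind of guard described. The function <code>checked_add</code> is a hypothetical helper written for this illustration; it is not the code from <code>ops_mathematic.py</code>.</p>

```python
import numpy as np

a = np.ones((4, 3), dtype="float32")  # e.g. activations
b = np.ones((1, 3), dtype="float32")  # e.g. a bias never explicitly broadcast

out = a + b                   # NumPy silently broadcasts b to (4, 3)
out_grad = np.ones_like(out)  # upstream gradient has the output's shape

# For addition the local gradient is the identity, so a naive backward
# hands out_grad to b unchanged and the shape no longer matches b.
naive_grad_b = out_grad


def checked_add(x, y):
    """Hypothetical guard: refuse implicit broadcasting."""
    if x.shape != y.shape:
        raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}")
    return x + y


try:
    checked_add(a, b)
    raised = False
except ValueError:
    raised = True

# With the broadcast made explicit, its backward can sum over the
# broadcast axis and recover a gradient of b's shape.
explicit = checked_add(a, np.broadcast_to(b, a.shape))
grad_b = out_grad.sum(axis=0, keepdims=True)
```

<p>With the guard in place the caller is forced to insert an explicit broadcast operator, and that operator's backward pass is what reduces the gradient back to the input's shape.</p>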
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="s2">&#34;&#34;&#34;Operator implementations.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">numbers</span> <span class="kn">import</span> <span class="n">Number</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">..autograd</span> <span class="kn">import</span> <span class="n">NDArray</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">..autograd</span> <span class="kn">import</span> <span class="n">Op</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">Value</span><span class="p">,</span> <span class="n">TensorOp</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">..autograd</span> <span class="kn">import</span> <span class="n">TensorTuple</span><span class="p">,</span> <span class="n">TensorTupleOp</span>
</span></span><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">numpy</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># NOTE: we will import numpy as the array_api</span>
</span></span><span class="line"><span class="cl"><span class="c1"># as the backend for our computations, this line will change in later homeworks</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">array_api</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">EWiseAdd</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="n">NDArray</span><span class="p">,</span> <span class="n">b</span><span class="p">:</span> <span class="n">NDArray</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="s2">&#34;The shape of lhs </span><span class="si">{}</span><span class="s2"> and rhs </span><span class="si">{}</span><span class="s2"> should be the same&#34;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">a</span> <span class="o">+</span> <span class="n">b</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">node</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">out_grad</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">EWiseAdd</span><span class="p">()(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">AddScalar</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scalar</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span> <span class="o">=</span> <span class="n">scalar</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="n">NDArray</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">a</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">node</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">add_scalar</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">scalar</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">AddScalar</span><span class="p">(</span><span class="n">scalar</span><span class="p">)(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">EWiseMul</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="n">NDArray</span><span class="p">,</span> <span class="n">b</span><span class="p">:</span> <span class="n">NDArray</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="s2">&#34;The shape of two tensors should be the same&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">a</span> <span class="o">*</span> <span class="n">b</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">node</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">lhs</span><span class="p">,</span> <span class="n">rhs</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span> <span class="o">*</span> <span class="n">rhs</span><span class="p">,</span> <span class="n">out_grad</span> <span class="o">*</span> <span class="n">lhs</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">multiply</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">EWiseMul</span><span class="p">()(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">MulScalar</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scalar</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span> <span class="o">=</span> <span class="n">scalar</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="n">NDArray</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">a</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">node</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="p">(</span><span class="n">out_grad</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span><span class="p">,)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">mul_scalar</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">scalar</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">MulScalar</span><span class="p">(</span><span class="n">scalar</span><span class="p">)(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">PowerScalar</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;Op to raise a tensor to an (integer) power.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scalar</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span> <span class="o">=</span> <span class="n">scalar</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="n">NDArray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">NDArray</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">power</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">a</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span> <span class="o">*</span> <span class="n">power_scalar</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">out_grad</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">power_scalar</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">scalar</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">PowerScalar</span><span class="p">(</span><span class="n">scalar</span><span class="p">)(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">EWisePow</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;Op to element-wise raise a tensor to a power.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="n">NDArray</span><span class="p">,</span> <span class="n">b</span><span class="p">:</span> <span class="n">NDArray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">NDArray</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="s2">&#34;The shape of two tensors should be the same&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">a</span><span class="o">**</span><span class="n">b</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">or</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">Tensor</span>
</span></span><span class="line"><span class="cl">        <span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&#34;Both inputs must be Tensor instances.&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">grad_a</span> <span class="o">=</span> <span class="n">out_grad</span> <span class="o">*</span> <span class="n">b</span> <span class="o">*</span> <span class="p">(</span><span class="n">a</span> <span class="o">**</span> <span class="p">(</span><span class="n">b</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">grad_b</span> <span class="o">=</span> <span class="n">out_grad</span> <span class="o">*</span> <span class="p">(</span><span class="n">a</span><span class="o">**</span><span class="n">b</span><span class="p">)</span> <span class="o">*</span> <span class="n">array_api</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">grad_a</span><span class="p">,</span> <span class="n">grad_b</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">power</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">EWisePow</span><span class="p">()(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">EWiseDiv</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;Op to element-wise divide two nodes.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="s2">&#34;The shape of two tensors should be the same&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">true_divide</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span> <span class="o">/</span> <span class="n">b</span><span class="p">,</span> <span class="o">-</span><span class="n">a</span> <span class="o">/</span> <span class="n">b</span> <span class="o">/</span> <span class="n">b</span> <span class="o">*</span> <span class="n">out_grad</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">divide</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">EWiseDiv</span><span class="p">()(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">DivScalar</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scalar</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span> <span class="o">=</span> <span class="n">scalar</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">true_divide</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">scalar</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">divide_scalar</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">scalar</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">DivScalar</span><span class="p">(</span><span class="n">scalar</span><span class="p">)(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Transpose</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span>  <span class="c1"># default: swap the last two axes</span>
</span></span><span class="line"><span class="cl">        <span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">)</span>  <span class="c1"># axes must be a pair (i, j) of axes to swap</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">transpose</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">transpose</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">Transpose</span><span class="p">(</span><span class="n">axes</span><span class="p">)(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Reshape</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">expect_size</span> <span class="o">=</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">expect_size</span> <span class="o">*=</span> <span class="n">i</span>
</span></span><span class="line"><span class="cl">        <span class="n">real_size</span> <span class="o">=</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">real_size</span> <span class="o">*=</span> <span class="n">i</span>
</span></span><span class="line"><span class="cl">        <span class="k">assert</span> <span class="n">expect_size</span> <span class="o">==</span> <span class="n">real_size</span> <span class="p">,</span> <span class="s2">&#34;The reshape size is not compatible&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">reshape</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">reshape</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">Reshape</span><span class="p">(</span><span class="n">shape</span><span class="p">)(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">BroadcastTo</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> \
</span></span><span class="line"><span class="cl">            <span class="s2">&#34;The target shape&#39;s dimension count </span><span class="si">{}</span><span class="s2"> should be greater than &#34;</span> \
</span></span><span class="line"><span class="cl">            <span class="s2">&#34;or equal to the input shape&#39;s dimension count </span><span class="si">{}</span><span class="s2">&#34;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">)):</span>
</span></span><span class="line"><span class="cl">            <span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span> <span class="o">-</span> <span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span> <span class="o">-</span> <span class="n">i</span><span class="p">]</span> <span class="ow">or</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span> <span class="o">-</span> <span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> \
</span></span><span class="line"><span class="cl">                <span class="s2">&#34;The input shape </span><span class="si">{}</span><span class="s2"> is not compatible with the target shape </span><span class="si">{}</span><span class="s2">&#34;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">input_shape</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">ret</span> <span class="o">=</span> <span class="n">summation</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">out_grad</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_shape</span><span class="p">))))</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)):</span>
</span></span><span class="line"><span class="cl">            <span class="k">if</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span> <span class="o">-</span> <span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span> <span class="o">-</span> <span class="n">i</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">                <span class="n">ret</span> <span class="o">=</span> <span class="n">summation</span><span class="p">(</span><span class="n">ret</span><span class="p">,</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">i</span><span class="p">,))</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">reshape</span><span class="p">(</span><span class="n">ret</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">broadcast_to</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">BroadcastTo</span><span class="p">(</span><span class="n">shape</span><span class="p">)(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Summation</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">a</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">axes</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">,)</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span>  <span class="c1"># accept a single int axis as well</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">axes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">axes</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">shape</span><span class="p">)))</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">axis</span> <span class="ow">in</span> <span class="n">axes</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">shape</span><span class="p">[</span><span class="n">axis</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">broadcast_to</span><span class="p">(</span><span class="n">reshape</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="n">shape</span><span class="p">),</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">summation</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">Summation</span><span class="p">(</span><span class="n">axes</span><span class="p">)(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">MatMul</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">a</span><span class="nd">@b</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span>
</span></span><span class="line"><span class="cl">        <span class="n">adjoint1</span> <span class="o">=</span> <span class="n">out_grad</span> <span class="o">@</span> <span class="n">transpose</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">adjoint2</span> <span class="o">=</span> <span class="n">transpose</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">@</span> <span class="n">out_grad</span>
</span></span><span class="line"><span class="cl">        <span class="n">adjoint1</span> <span class="o">=</span> <span class="n">summation</span><span class="p">(</span><span class="n">adjoint1</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">adjoint1</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">))))</span>
</span></span><span class="line"><span class="cl">        <span class="n">adjoint2</span> <span class="o">=</span> <span class="n">summation</span><span class="p">(</span><span class="n">adjoint2</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">adjoint2</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="p">))))</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">adjoint1</span><span class="p">,</span> <span class="n">adjoint2</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">MatMul</span><span class="p">()(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Negate</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">negative</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">negate</span><span class="p">(</span><span class="n">out_grad</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">negate</span><span class="p">(</span><span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">Negate</span><span class="p">()(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Log</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span> <span class="o">/</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">log</span><span class="p">(</span><span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">Log</span><span class="p">()(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Exp</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span> <span class="o">*</span> <span class="n">exp</span><span class="p">(</span><span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">exp</span><span class="p">(</span><span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">Exp</span><span class="p">()(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">ReLU</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">relu_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">cached_data</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span> <span class="o">*</span> <span class="n">relu_mask</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">relu</span><span class="p">(</span><span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">ReLU</span><span class="p">()(</span><span class="n">a</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>With everything in place, we can now start on Q2.</p>
<ul>
<li>Linear<br>
First we need to implement a linear layer, whose formula is:</li>
</ul>


<div>$$

Y = XW &#43; B

$$</div>
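<p>As a quick shape check of the formula above, here is a minimal NumPy sketch (not the needle implementation). The sizes and variable names are hypothetical and chosen only for illustration; the bias is assumed to be stored as a <code>(1, out_features)</code> row vector and is broadcast explicitly, mirroring the explicit <code>broadcast_to</code> call in the module code.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch_size, in_features, out_features = 4, 3, 2

rng = np.random.default_rng(0)
X = rng.standard_normal((batch_size, in_features))
W = rng.standard_normal((in_features, out_features))
b = rng.standard_normal((1, out_features))  # bias stored as a row vector

# NumPy would broadcast the bias implicitly; np.broadcast_to makes the
# step explicit so the shapes of both operands are visible.
Y = X @ W + np.broadcast_to(b, (batch_size, out_features))

print(Y.shape)
</code></pre></div>
<p>The result has one row per sample and one column per output feature, which is the shape the bias has to be broadcast to.</p>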

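<p>Note that both <code>weight</code> and <code>bias</code> must be of type <code>Parameter</code>; if they are defined as plain <code>Tensor</code> objects, the optimizer implemented later will fail its test cases. The code for this module is:</p>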
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Linear</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">,</span> <span class="n">out_features</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">in_features</span> <span class="o">=</span> <span class="n">in_features</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">out_features</span> <span class="o">=</span> <span class="n">out_features</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">kaiming_uniform</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="n">out_features</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="kc">None</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">bias</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">kaiming_uniform</span><span class="p">(</span><span class="n">out_features</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">transpose</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">y</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="c1"># reshape into a local so the Parameter itself is never overwritten</span>
</span></span><span class="line"><span class="cl">            <span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_features</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">            <span class="n">y</span> <span class="o">+=</span> <span class="n">bias</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">y</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>
<p>ReLU<br>
This module is simple: just call <code>ops.relu</code>.</p>
</li>
<li>
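<p>Sequential<br>
This module wraps several modules into one: it passes the input through the wrapped modules in order and returns the final output. Its implementation is:</p>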
</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Sequential</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">modules</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">modules</span> <span class="o">=</span> <span class="n">modules</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">y</span> <span class="o">=</span> <span class="n">x</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">module</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">modules</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">y</span> <span class="o">=</span> <span class="n">module</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">y</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>LogSumExp<br>
Here we implement the numerically stable version of the LogSumExp operator. The documentation gives the formula directly; here is the derivation:</li>
</ul>


<div>$$

\begin{align*}
\log \sum_i \exp(z_i)
&amp;= \log \sum_i \exp(z_i - \max z &#43; \max z)\\
&amp;=\log \sum_i[\exp(z_i - \max z) \cdot \exp(\max z)] \\
&amp;= \log [\sum_i \exp(z_i -\max z)\cdot\exp(\max z)] \\
&amp;=\log \sum_i \exp(z_i -\max z) &#43; \max z
\end{align*}

$$</div>

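<p>This identity transformation avoids the numerical overflow that the $\exp$ operation could otherwise cause, because every exponent $z_i - \max z$ is at most 0.</p>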
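<p>To make the overflow concrete, here is a small NumPy sketch comparing the direct formula with the max-shifted one. The function names <code>logsumexp_naive</code> and <code>logsumexp_stable</code> are hypothetical and are not part of needle; the input values are chosen only to trigger overflow in double precision.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

def logsumexp_naive(z):
    # Direct translation of log(sum(exp(z))); exp overflows for large z.
    return np.log(np.sum(np.exp(z)))

def logsumexp_stable(z):
    # Subtract the maximum first so every exponent is at most 0.
    m = np.max(z)
    return np.log(np.sum(np.exp(z - m))) + m

z = np.array([1000.0, 1000.0, 1000.0])
with np.errstate(over="ignore"):  # silence the expected overflow warning
    naive = logsumexp_naive(z)
stable = logsumexp_stable(z)

print(naive, stable)
</code></pre></div>
<p>The naive version returns infinity because <code>exp(1000)</code> cannot be represented, while the shifted version stays finite and equals <code>1000 + log(3)</code>.</p>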
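<p>Clearly, the gradient of the numerically stable version is the same as that of the original formula. Differentiating directly, or following the article <a href="/notes/gradient-of-log-sum-exp/">LogSumExp梯度推导</a> (a derivation of the LogSumExp gradient), gives the gradient formula:</p>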


<div>$$

\begin{align*}
\frac{\partial{f}}{\partial{z_j}}
&amp;=\frac{\exp{z_j}}{\sum_{i=1}^n\exp z_i}\\
&amp;=\exp(z_j - \log \sum_{i=1}^n\exp z_i)\\
&amp;=\exp(z_j - f)
\end{align*}

$$</div>
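<p>The closed form above can be sanity-checked numerically. The following NumPy sketch (independent of needle, with a hypothetical helper <code>logsumexp</code> and arbitrary test values) compares <code>exp(z_j - f)</code> against central finite differences.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

def logsumexp(z):
    # Numerically stable LogSumExp over a 1-D array.
    m = np.max(z)
    return np.log(np.sum(np.exp(z - m))) + m

z = np.array([0.5, -1.0, 2.0])
f = logsumexp(z)

# Analytic gradient from the derivation: df/dz_j = exp(z_j - f).
analytic = np.exp(z - f)

# Central finite differences as an independent check.
h = 1e-6
numeric = np.zeros_like(z)
for j in range(z.size):
    step = np.zeros_like(z)
    step[j] = h
    numeric[j] = (logsumexp(z + step) - logsumexp(z - step)) / (2 * h)

print(np.max(np.abs(analytic - numeric)))
</code></pre></div>
<p>The analytic gradient is just the softmax of <code>z</code>, so its entries also sum to 1.</p>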

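<p>Pleasantly, the gradient of LogSumExp can be expressed in terms of its input and output alone, so in code we only need the node's input and output to compute the gradient:</p>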
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">LogSumExp</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">Z</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">max_z</span> <span class="o">=</span> <span class="n">array_api</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">max_z</span> <span class="o">=</span> <span class="n">max_z</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">array_api</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">array_api</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">Z</span> <span class="o">-</span> <span class="n">max_z</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">))</span> <span class="o">+</span> <span class="n">array_api</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">max_z</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)))</span>
</span></span><span class="line"><span class="cl">        <span class="n">z</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">shape</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="k">else</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">))]</span>
</span></span><span class="line"><span class="cl">        <span class="n">gradient</span> <span class="o">=</span> <span class="n">exp</span><span class="p">(</span><span class="n">z</span> <span class="o">-</span> <span class="n">node</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">*</span><span class="n">gradient</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>SoftmaxLoss<br>
This module computes the softmax loss. The implementation can call the numerically stable LogSumExp implemented above. Its formula is:</li>
</ul>


<div>$$

\begin{align*}
\ell_\text{softmax}(z,y) = \log \sum_{i=1}^k \exp z_i - z_y
\end{align*}

$$</div>
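<p>As a rough illustration of the formula, here is a NumPy sketch of the batch-averaged softmax loss. It is not the needle module: the function name <code>softmax_loss</code>, the use of <code>np.eye</code> for one-hot encoding, and the toy inputs are all assumptions made for this example.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

def softmax_loss(logits, y):
    # logits: (batch_size, k) scores; y: (batch_size,) integer class labels.
    batch_size, k = logits.shape
    one_hot = np.eye(k)[y]                          # (batch_size, k)
    true_logits = np.sum(logits * one_hot, axis=1)  # z_y for each sample
    m = np.max(logits, axis=1, keepdims=True)
    lse = np.log(np.sum(np.exp(logits - m), axis=1)) + m[:, 0]
    # Average the per-sample losses over the batch.
    return np.sum(lse - true_logits) / batch_size

logits = np.zeros((2, 3))
y = np.array([0, 2])
loss = softmax_loss(logits, y)

print(loss)
</code></pre></div>
<p>With all-zero logits every class is equally likely, so the loss reduces to <code>log(k)</code>, here <code>log(3)</code>, whatever the labels are.</p>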

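<p>The code skeleton already provides a helper function that converts labels to one-hot encoding. Also remember that the loss should be the mean over the batch, so don't forget to average.</p>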
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">SoftmaxLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">batch_size</span><span class="p">,</span> <span class="n">label_size</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">one_hot_y</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">label_size</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">true_logits</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">summation</span><span class="p">(</span><span class="n">logits</span> <span class="o">*</span> <span class="n">one_hot_y</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,))</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">logsumexp</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">))</span> <span class="o">-</span> <span class="n">true_logits</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">/</span><span class="n">batch_size</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
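<li>LayerNorm1d<br>
This is the first reasonably challenging module. It involves many reshape and broadcast operations, so you must know the shape of every variable precisely. Note that the input shape can be assumed to be <code>(batch_size, feature_size)</code>. The formula is:</li>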
</ul>


<div>$$

\begin{align*}
y = w \circ \frac{x_i - \textbf{E}[x]}{((\textbf{Var}[x]&#43;\epsilon)^{1/2})} &#43; b
\end{align*}

$$</div>
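<p>Before working through the needle version, it may help to see the same computation in plain NumPy, where <code>keepdims=True</code> handles the reshaping. This is only a sketch under the assumption that the input is <code>(batch_size, feature_size)</code> and that the variance is the biased one (divided by <code>feature_size</code>); the function name <code>layer_norm</code> and the sample data are hypothetical.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

def layer_norm(x, w, b, eps=1e-5):
    # x: (batch_size, feature_size); w, b: (1, feature_size).
    mean = x.mean(axis=1, keepdims=True)                 # (batch_size, 1)
    var = ((x - mean) ** 2).mean(axis=1, keepdims=True)  # biased variance
    return w * (x - mean) / np.sqrt(var + eps) + b

x = np.array([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]])
w = np.ones((1, 3))
b = np.zeros((1, 3))
out = layer_norm(x, w, b)

print(out.mean(axis=1), out.var(axis=1))
</code></pre></div>
<p>With unit weight and zero bias, each row of the output has mean close to 0 and variance close to 1; the small gap from 1 comes from <code>eps</code>.</p>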

<p>Just follow the formula directly, but pay attention to the shapes of the intermediate variables:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">LayerNorm1d</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">dim</span> <span class="o">=</span> <span class="n">dim</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">batch_size</span><span class="p">,</span> <span class="n">feature_size</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">mean</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">))</span> <span class="o">/</span> <span class="n">feature_size</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">var</span> <span class="o">=</span> <span class="p">(((</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">))</span> <span class="o">/</span> <span class="n">feature_size</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">std_x</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">ops</span><span class="o">.</span><span class="n">power_scalar</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">weight</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">std_x</span> <span class="o">*</span> <span class="n">weight</span> <span class="o">+</span> <span class="n">bias</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>Flatten<br>
This module keeps the first dimension as the batch size and flattens all remaining dimensions into one. It can be implemented with <code>ops.reshape</code>:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Flatten</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">2</span>
</span></span><span class="line"><span class="cl">        <span class="n">elem_cnt</span> <span class="o">=</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">)):</span>
</span></span><span class="line"><span class="cl">            <span class="n">elem_cnt</span> <span class="o">*=</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">X</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">elem_cnt</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>BatchNorm1d<br>
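LayerNorm standardizes within each sample (across its features), whereas BatchNorm standardizes within each feature (across the samples of a batch). As a result, every sample in a batch influences the output for the other samples, which is undesirable at inference time. The running mean and variance are therefore accumulated during training and used in place of the batch statistics at inference. <code>nn.Module</code> has a <code>training</code> field that indicates whether the module is in training mode.</li>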
</ul>
<p>As with LayerNorm, the implementation relies heavily on reshape and broadcast operations, so keep an eye on the shapes of intermediate values.</p>
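<p>To make the difference in normalization axes concrete, here is a minimal pure-Python sketch that is independent of the needle API used in this post. The helpers <code>normalize</code>, <code>layer_norm</code>, and <code>batch_norm_train</code> are hypothetical names introduced only for illustration, and the sketch leaves out the learnable weight and bias as well as the running statistics.</p>

```python
# Illustrative sketch only (not the needle API): contrast the axis
# that LayerNorm and BatchNorm compute their statistics over.
from statistics import fmean, pvariance


def normalize(values, eps=1e-5):
    """Standardize a list of floats using the biased (divide-by-N) variance."""
    mean = fmean(values)
    var = pvariance(values, mu=mean)
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


def layer_norm(x):
    # One mean/variance per sample: normalize every row on its own.
    return [normalize(row) for row in x]


def batch_norm_train(x):
    # One mean/variance per feature: normalize every column on its own.
    cols = [normalize(list(col)) for col in zip(*x)]
    return [list(row) for row in zip(*cols)]


x = [[1.0, 2.0, 3.0],
     [4.0, 6.0, 8.0]]
ln = layer_norm(x)        # each row has roughly zero mean
bn = batch_norm_train(x)  # each column has roughly zero mean
```

<p>In <code>ln</code> every row sums to roughly zero, while in <code>bn</code> every column does, so changing one sample changes the BatchNorm output of the others. The needle implementation of BatchNorm1d follows.</p>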
<div class="highlight"><div class="chroma">
<table class="lntable"><tr>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">BatchNorm1d</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">dim</span> <span class="o">=</span> <span class="n">dim</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentum</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">running_mean</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">running_var</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dim</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dim</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dim</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dim</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">batch_size</span><span class="p">,</span> <span class="n">feature_size</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">            <span class="n">mean</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="p">))</span> <span class="o">/</span> <span class="n">batch_size</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">feature_size</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">            <span class="n">var</span> <span class="o">=</span> <span class="p">(((</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="p">))</span> <span class="o">/</span> <span class="n">batch_size</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">feature_size</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">running_mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_mean</span> <span class="o">*</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">+</span> <span class="n">mean</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">running_mean</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">running_var</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_var</span> <span class="o">*</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">+</span> <span class="n">var</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">running_var</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">mean</span> <span class="o">=</span> <span class="n">mean</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">var</span> <span class="o">=</span> <span class="n">var</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">std_x</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">ops</span><span class="o">.</span><span class="n">power_scalar</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">weight</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="n">std_x</span> <span class="o">*</span> <span class="n">weight</span> <span class="o">+</span> <span class="n">bias</span>
</span></span><span class="line"><span class="cl">        <span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">std_x</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_mean</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span> <span class="o">/</span> <span class="n">ops</span><span class="o">.</span><span class="n">power_scalar</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">running_var</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="n">std_x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>Dropout<br>
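Dropout randomly zeroes each input element with probability p and scales the surviving elements by 1/(1 - p), so that the expected input to the next layer is unchanged. The code skeleton provides <code>init.randb</code> for generating a boolean mask whose entries are drawn from a Bernoulli distribution. The implementation:</li>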
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Dropout</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">p</span> <span class="o">=</span> <span class="n">p</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="n">x</span>
</span></span><span class="line"><span class="cl">        <span class="n">mask</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">randb</span><span class="p">(</span><span class="o">*</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">mask</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>Residual<br>
The residual module returns the sum of its input and the output of the module it wraps, so the implementation is simple:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Residual</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fn</span><span class="p">:</span> <span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">fn</span> <span class="o">=</span> <span class="n">fn</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">fn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="q3-optimizer-implementation">Q3: Optimizer Implementation</h2>
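<p>In this question we implement the optimizer module. An optimizer updates the model's parameters based on the gradients computed by <code>loss.backward()</code>.</p>
<p>Note that the optimizers in this module always apply L2 regularization, also known as weight decay, so the gradient used for the update is <code>param.grad + weight_decay * param</code>.</p>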
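<p>As a small illustration of where the extra term comes from: adding the penalty (weight_decay / 2) * ||param||^2 to the loss contributes weight_decay * param to the gradient. The sketch below uses plain Python lists rather than needle tensors; <code>grad_with_weight_decay</code> and <code>plain_sgd_step</code> are hypothetical helper names, and momentum is deliberately left out.</p>

```python
# Illustrative sketch only: plain lists stand in for needle tensors.

def grad_with_weight_decay(param, grad, weight_decay):
    """Return grad + weight_decay * param, element by element."""
    return [g + weight_decay * p for p, g in zip(param, grad)]


def plain_sgd_step(param, grad, lr, weight_decay):
    """One gradient-descent update using the regularized gradient (no momentum)."""
    reg_grad = grad_with_weight_decay(param, grad, weight_decay)
    return [p - lr * g for p, g in zip(param, reg_grad)]


param = [1.0, -2.0]
grad = [0.5, 0.5]

# The decay term pulls each weight toward zero in proportion to its size.
reg_grad = grad_with_weight_decay(param, grad, weight_decay=0.1)
new_param = plain_sgd_step(param, grad, lr=0.1, weight_decay=0.1)
```

<p>With <code>weight_decay=0.0</code> the regularized gradient equals the raw gradient, which matches the default value of that argument in the optimizers in this post.</p>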
<ul>
<li>SGD<br>
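The first optimizer to implement is stochastic gradient descent. When updating a parameter, go through <code>data</code> to obtain a copy of the tensor that is detached from the computational graph; otherwise the graph grows with every step. The implementation also uses momentum, an exponential moving average of the gradients whose initial value defaults to 0. The implementation:</li>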
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">SGD</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="n">lr</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentum</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">u</span> <span class="o">=</span> <span class="p">{}</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">=</span> <span class="n">weight_decay</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">                <span class="k">if</span> <span class="n">param</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">u</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">                    <span class="bp">self</span><span class="o">.</span><span class="n">u</span><span class="p">[</span><span class="n">param</span><span class="p">]</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">                <span class="bp">self</span><span class="o">.</span><span class="n">u</span><span class="p">[</span><span class="n">param</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">u</span><span class="p">[</span><span class="n">param</span><span class="p">]</span><span class="o">.</span><span class="n">data</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">*</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">                <span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">u</span><span class="p">[</span><span class="n">param</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>Adam<br>
There is not much to say here; just follow the formulas:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Adam</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">params</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">beta1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">beta2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">eps</span><span class="o">=</span><span class="mf">1e-8</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="n">lr</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">beta1</span> <span class="o">=</span> <span class="n">beta1</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">beta2</span> <span class="o">=</span> <span class="n">beta2</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">=</span> <span class="n">weight_decay</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">t</span> <span class="o">=</span> <span class="mi">0</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">m</span> <span class="o">=</span> <span class="p">{}</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">v</span> <span class="o">=</span> <span class="p">{}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">t</span> <span class="o">+=</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">                <span class="k">if</span> <span class="n">param</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">m</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
</span></span><span class="line"><span class="cl">                    <span class="bp">self</span><span class="o">.</span><span class="n">m</span><span class="p">[</span><span class="n">param</span><span class="p">]</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">                <span class="k">if</span> <span class="n">param</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">v</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
</span></span><span class="line"><span class="cl">                    <span class="bp">self</span><span class="o">.</span><span class="n">v</span><span class="p">[</span><span class="n">param</span><span class="p">]</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">                <span class="c1"># Fold L2 weight decay into the gradient</span>
</span></span><span class="line"><span class="cl">                <span class="n">grad</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">*</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span>
</span></span><span class="line"><span class="cl">                <span class="c1"># Update the moments from detached data so no computation graph accumulates across steps</span>
</span></span><span class="line"><span class="cl">                <span class="bp">self</span><span class="o">.</span><span class="n">m</span><span class="p">[</span><span class="n">param</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta1</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">m</span><span class="p">[</span><span class="n">param</span><span class="p">]</span><span class="o">.</span><span class="n">data</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta1</span><span class="p">)</span> <span class="o">*</span> <span class="n">grad</span><span class="o">.</span><span class="n">data</span>
</span></span><span class="line"><span class="cl">                <span class="bp">self</span><span class="o">.</span><span class="n">v</span><span class="p">[</span><span class="n">param</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">v</span><span class="p">[</span><span class="n">param</span><span class="p">]</span><span class="o">.</span><span class="n">data</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta2</span><span class="p">)</span> <span class="o">*</span> <span class="n">grad</span><span class="o">.</span><span class="n">data</span> <span class="o">*</span> <span class="n">grad</span><span class="o">.</span><span class="n">data</span>
</span></span><span class="line"><span class="cl">                <span class="c1"># Bias-corrected first and second moment estimates</span>
</span></span><span class="line"><span class="cl">                <span class="n">u_hat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">m</span><span class="p">[</span><span class="n">param</span><span class="p">]</span><span class="o">.</span><span class="n">data</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta1</span> <span class="o">**</span> <span class="bp">self</span><span class="o">.</span><span class="n">t</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">                <span class="n">v_hat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">v</span><span class="p">[</span><span class="n">param</span><span class="p">]</span><span class="o">.</span><span class="n">data</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta2</span> <span class="o">**</span> <span class="bp">self</span><span class="o">.</span><span class="n">t</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">                <span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">*</span> <span class="n">u_hat</span><span class="o">.</span><span class="n">data</span> <span class="o">/</span> <span class="p">(</span><span class="n">ndl</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">power_scalar</span><span class="p">(</span><span class="n">v_hat</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">data</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="q4-dataloader-implementation">Q4: DataLoader Implementation</h2>
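<p>In this question we implement some data transforms along with the Dataset and DataLoader classes. The Dataset class provides a standard interface for accessing a dataset, while the DataLoader class is an iterator that reads one batch at a time from a dataset.</p>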
<ul>
<li>RandomFlipHorizontal<br>
This transform flips an image horizontally with probability p. Note that the input layout is <code>H*W*C</code>, so it suffices to call <code>np.flip</code> along the W axis (axis 1).</li>
</ul>
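<p>To see why axis 1 is the right one for an <code>H*W*C</code> array, here is a small NumPy sketch on a made-up 2 x 3 x 1 image (the array values are purely illustrative):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

# A tiny 2 x 3 x 1 "image": H=2 rows, W=3 columns, C=1 channel
img = np.arange(6).reshape(2, 3, 1)

# Flipping along axis 1 reverses the column order within each row
flipped = np.flip(img, axis=1)

print(img[:, :, 0])      # rows are [0 1 2] and [3 4 5]
print(flipped[:, :, 0])  # rows are [2 1 0] and [5 4 3]
</code></pre></div>
<p>Each row keeps its position while its columns are mirrored, which is exactly a horizontal flip; the channel axis is untouched.</p>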
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">RandomFlipHorizontal</span><span class="p">(</span><span class="n">Transform</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">p</span> <span class="o">=</span> <span class="n">p</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">        Horizontally flip an image, specified as an H x W x C NDArray.
</span></span></span><span class="line"><span class="cl"><span class="s2">        Args:
</span></span></span><span class="line"><span class="cl"><span class="s2">            img: H x W x C NDArray of an image
</span></span></span><span class="line"><span class="cl"><span class="s2">        Returns:
</span></span></span><span class="line"><span class="cl"><span class="s2">            H x W x C ndarray corresponding to image flipped with probability self.p
</span></span></span><span class="line"><span class="cl"><span class="s2">        Note: use the provided code to provide randomness, for easier testing
</span></span></span><span class="line"><span class="cl"><span class="s2">        &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="n">flip_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">()</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">flip_img</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">img</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
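<li>RandomCrop<br>
This transform randomly crops the original image. The procedure is: first pad <code>padding</code> blank (zero) pixels on the top, bottom, left, and right; then, according to the offsets <code>shift_x</code> and <code>shift_y</code>, cut out of the padded image a window of the same size as the original. In the implementation below, <code>shift_x</code> is applied along axis 0 (rows) and <code>shift_y</code> along axis 1 (columns).</li>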
</ul>
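<p>The pad-then-crop procedure can be sketched on a toy image as follows. The helper name <code>pad_and_crop</code> is a hypothetical one chosen for illustration, and the shifts are passed in explicitly instead of being drawn at random so that the result is easy to check by hand:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

def pad_and_crop(img, padding, shift_x, shift_y):
    """Zero-pad an H x W x C image, then cut an H x W window offset by the shifts."""
    h, w = img.shape[0], img.shape[1]
    padded = np.pad(img, ((padding, padding), (padding, padding), (0, 0)), 'constant')
    top = padding + shift_x    # offset along axis 0 (rows)
    left = padding + shift_y   # offset along axis 1 (columns)
    return padded[top:top + h, left:left + w, :]

img = np.arange(1, 10).reshape(3, 3, 1)
same = pad_and_crop(img, padding=1, shift_x=0, shift_y=0)   # identical to img
moved = pad_and_crop(img, padding=1, shift_x=1, shift_y=0)  # content moves up one row
</code></pre></div>
<p>Because each shift lies in the range from <code>-padding</code> to <code>padding</code>, the window always stays inside the padded image, and the rows or columns that slide in from the border are filled with zeros.</p>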
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">RandomCrop</span><span class="p">(</span><span class="n">Transform</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">3</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">padding</span> <span class="o">=</span> <span class="n">padding</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;&#34;&#34; Zero pad and then randomly crop an image.
</span></span></span><span class="line"><span class="cl"><span class="s2">        Args:
</span></span></span><span class="line"><span class="cl"><span class="s2">             img: H x W x C NDArray of an image
</span></span></span><span class="line"><span class="cl"><span class="s2">        Returns:
</span></span></span><span class="line"><span class="cl"><span class="s2">            H x W x C NDArray of clipped image
</span></span></span><span class="line"><span class="cl"><span class="s2">        Note: generate the image shifted by shift_x, shift_y specified below
</span></span></span><span class="line"><span class="cl"><span class="s2">        &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="n">shift_x</span><span class="p">,</span> <span class="n">shift_y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="n">low</span><span class="o">=-</span><span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">img_size</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">),</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)),</span> <span class="s1">&#39;constant&#39;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">img</span> <span class="o">=</span> <span class="n">img</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">padding</span> <span class="o">+</span> <span class="n">shift_x</span><span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">padding</span> <span class="o">+</span> <span class="n">shift_x</span> <span class="o">+</span> <span class="n">img_size</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding</span> <span class="o">+</span> <span class="n">shift_y</span><span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">padding</span> <span class="o">+</span> <span class="n">shift_y</span> <span class="o">+</span> <span class="n">img_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">:]</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">img</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
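<li>MNISTDataset<br>
Here we implement a Dataset subclass for the MNIST dataset. As a subclass it must implement three methods: <code>__init__</code> initializes the images, labels, and transforms; <code>__len__</code> returns the number of samples in the dataset; <code>__getitem__</code> returns the sample(s) at the given index.</li>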
</ul>
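<p>Points to note: 1) use the previously implemented <code>parse_mnist</code> to parse the MNIST data; 2) the <code>Dataset</code> parent class provides <code>apply_transforms</code> for processing images; 3) <code>__getitem__</code> should preferably accept a list of indices so that a batch can be read in one call; 4) the transforms expect data in <code>H*W*C</code> layout, but <code>__getitem__</code> should return data of shape <code>batch_size*n</code>.</p>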
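<p>Point 4 amounts to a round trip between two layouts. The following is a minimal NumPy sketch, assuming each image is stored as a flat row of 28*28 values; the helper <code>transform_flat_batch</code> is hypothetical and only illustrates the reshaping. It handles each sample separately so that a transform never mixes pixels from different images:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

def transform_flat_batch(rows, transform):
    """Apply an H x W x C transform to each flattened 28*28 image in rows."""
    rows = np.atleast_2d(rows)  # accept a single row or a whole batch
    out = [transform(r.reshape(28, 28, 1)).reshape(-1) for r in rows]
    return np.stack(out)        # shape: batch_size x 784

# Hypothetical data: two random "images" stored as flat rows
X = np.random.rand(2, 784).astype(np.float32)
flip = lambda img: np.flip(img, axis=1)  # horizontal flip on H x W x C input
batch = transform_flat_batch(X, flip)
</code></pre></div>
<p>Whether the index is a single integer or a list, the result has shape <code>batch_size*n</code> with <code>n = 784</code>.</p>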
<p>The implementation is:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">MNISTDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">image_filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">label_filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">transforms</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">transforms</span> <span class="o">=</span> <span class="n">transforms</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">y</span> <span class="o">=</span> <span class="n">parse_mnist</span><span class="p">(</span><span class="n">image_filename</span><span class="p">,</span> <span class="n">label_filename</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">object</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply_transforms</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">[</span><span class="n">index</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">28</span><span class="o">*</span><span class="mi">28</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">y</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>DataLoader<br>
The DataLoader class is an iterator and is fairly simple; the code is largely self-explanatory.</li>
</ul>
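<p>The core of batching is cutting the list of sample indices into chunks. Below is a minimal sketch of that step using <code>np.array_split</code>; the helper name <code>batch_indices</code> and the choice of split points are assumptions made for illustration:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

def batch_indices(n, batch_size, shuffle=False):
    """Split the indices 0..n-1 into consecutive batches of at most batch_size."""
    order = np.random.permutation(n) if shuffle else np.arange(n)
    # Split points at batch_size, 2*batch_size, ...; the last batch may be smaller
    return np.array_split(order, range(batch_size, n, batch_size))

batches = batch_indices(10, batch_size=4)
print([b.tolist() for b in batches])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
</code></pre></div>
<p>The class below keeps such a list of index batches in <code>self.ordering</code>; iterating over the loader then means walking through that list and fetching the corresponding samples from the dataset.</p>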
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">DataLoader</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="sa">r</span><span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">    Data loader. Combines a dataset and a sampler, and provides an iterable over
</span></span></span><span class="line"><span class="cl"><span class="s2">    the given dataset.
</span></span></span><span class="line"><span class="cl"><span class="s2">    Args:
</span></span></span><span class="line"><span class="cl"><span class="s2">        dataset (Dataset): dataset from which to load the data.
</span></span></span><span class="line"><span class="cl"><span class="s2">        batch_size (int, optional): how many samples per batch to load
</span></span></span><span class="line"><span class="cl"><span class="s2">            (default: ``1``).
</span></span></span><span class="line"><span class="cl"><span class="s2">        shuffle (bool, optional): set to ``True`` to have the data reshuffled
</span></span></span><span class="line"><span class="cl"><span class="s2">            at every epoch (default: ``False``).
</span></span></span><span class="line"><span class="cl"><span class="s2">     &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="n">dataset</span><span class="p">:</span> <span class="n">Dataset</span>
</span></span><span class="line"><span class="cl">    <span class="n">batch_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">dataset</span><span class="p">:</span> <span class="n">Dataset</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">batch_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">shuffle</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="p">):</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">shuffle</span> <span class="o">=</span> <span class="n">shuffle</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">shuffle</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">ordering</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array_split</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">dataset</span><span class="p">)),</span> 
</span></span><span class="line"><span class="cl">                                           <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">dataset</span><span class="p">),</span> <span class="n">batch_size</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__iter__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">shuffle</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">ordering</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array_split</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset</span><span class="p">)),</span> 
</span></span><span class="line"><span class="cl">                                           <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">index</span> <span class="o">=</span> <span class="mi">0</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="bp">self</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__next__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">index</span> <span class="o">&gt;=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ordering</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="k">raise</span> <span class="ne">StopIteration</span>
</span></span><span class="line"><span class="cl">        <span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">batch</span> <span class="o">=</span> <span class="p">[</span><span class="n">Tensor</span><span class="o">.</span><span class="n">make_const</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">ordering</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">index</span><span class="p">]]]</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">index</span> <span class="o">+=</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="n">batch</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="q5-mlpresnet-implementation">Q5: MLPResNet Implementation</h2>
<p>At this point, all the basic components of our needle library are in place. In this question, we use them to assemble an MLP ResNet and train it on the MNIST dataset.</p>
<ul>
<li>Residual Block<br>
First, implement a residual block by assembling the building blocks as shown in the figure below:<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202407241649801.png?x-oss-process=image/quality,q_90/format,webp"></li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">ResidualBlock</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">,</span> <span class="n">norm</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm1d</span><span class="p">,</span> <span class="n">drop_prob</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="n">nn</span><span class="o">.</span><span class="n">Residual</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">                <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">                <span class="n">norm</span><span class="p">(</span><span class="n">hidden_dim</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">                <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
</span></span><span class="line"><span class="cl">                <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">drop_prob</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">                <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_dim</span><span class="p">,</span> <span class="n">dim</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">                <span class="n">norm</span><span class="p">(</span><span class="n">dim</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">            <span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="p">),</span>
</span></span><span class="line"><span class="cl">        <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
</span></span><span class="line"><span class="cl">    <span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>MLP ResNet<br>
This is again a matter of stacking building blocks; note that it contains <code>num_blocks</code> Residual Blocks.<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202407241652831.png?x-oss-process=image/quality,q_90/format,webp"></li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">MLPResNet</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="n">dim</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">hidden_dim</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">num_blocks</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">num_classes</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">norm</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm1d</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">drop_prob</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
</span></span><span class="line"><span class="cl"><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">        <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
</span></span><span class="line"><span class="cl">        <span class="o">*</span><span class="p">[</span><span class="n">ResidualBlock</span><span class="p">(</span><span class="n">hidden_dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="o">//</span><span class="mi">2</span><span class="p">,</span> <span class="n">norm</span><span class="p">,</span> <span class="n">drop_prob</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_blocks</span><span class="p">)],</span>
</span></span><span class="line"><span class="cl">        <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_dim</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">    <span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>Epoch<br>
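The <code>epoch</code> function runs one epoch of training or inference and returns the average error rate and the average loss. Its logic is: instantiate the loss function, fetch inputs from the DataLoader, run the model forward, compute the loss, reset the gradients, backpropagate, update the parameters (these last three steps only when an optimizer is given), and compute the error rate.</li>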
</ul>
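<p>A minimal sketch of the averaging step, in plain Python with made-up numbers. It assumes the loss function returns the mean loss over a batch, so each batch mean has to be weighted by its batch size before dividing by the dataset size; otherwise a smaller final batch would be over-weighted. The name <code>weighted_mean</code> and the data are illustrative and not part of needle.</p>

```python
def weighted_mean(batch_means, batch_sizes):
    """Combine per-batch means into a per-sample mean over the whole dataset."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical per-sample losses for 5 samples, split into batches of 2, 2, 1.
per_sample = [1.0, 3.0, 2.0, 4.0, 10.0]
batches = [per_sample[0:2], per_sample[2:4], per_sample[4:5]]
batch_means = [sum(b) / len(b) for b in batches]
batch_sizes = [len(b) for b in batches]

# Weighted by batch size: matches the true per-sample mean.
overall = weighted_mean(batch_means, batch_sizes)
# Unweighted average of batch means: biased toward the small last batch.
naive = sum(batch_means) / len(batch_means)
```

<p>The same reasoning applies to the error rate: errors are counted per sample and divided by the dataset size once at the end.</p>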
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">epoch</span><span class="p">(</span><span class="n">dataloader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">opt</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">SoftmaxLoss</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">error_count</span> <span class="o">=</span> <span class="mi">0</span>
</span></span><span class="line"><span class="cl">    <span class="n">loss</span> <span class="o">=</span> <span class="mi">0</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">opt</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">y_pred</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">batch_loss</span> <span class="o">=</span> <span class="n">loss_func</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">loss</span> <span class="o">+=</span> <span class="n">batch_loss</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">opt</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">opt</span><span class="o">.</span><span class="n">reset_grad</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">            <span class="n">batch_loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">            <span class="n">opt</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">y_pred</span> <span class="o">=</span> <span class="n">y_pred</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">y_pred</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">error_count</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">y_pred</span> <span class="o">!=</span> <span class="n">y</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">error_count</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">dataloader</span><span class="o">.</span><span class="n">dataset</span><span class="p">),</span> <span class="n">loss</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">dataloader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>Train MNIST<br>
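This function trains an MLP ResNet on the MNIST dataset. Its logic is: instantiate the Datasets, instantiate the DataLoaders, instantiate the model, instantiate the optimizer, then iterate over the epochs.</li>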
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">train_mnist</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="n">batch_size</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">optimizer</span><span class="o">=</span><span class="n">ndl</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">hidden_dim</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">data_dir</span><span class="o">=</span><span class="s2">&#34;data&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl"><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">MNISTDataset</span><span class="p">(</span><span class="n">data_dir</span><span class="o">+</span><span class="s2">&#34;/train-images-idx3-ubyte.gz&#34;</span><span class="p">,</span> <span class="n">data_dir</span><span class="o">+</span><span class="s2">&#34;/train-labels-idx1-ubyte.gz&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">test_dataset</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">MNISTDataset</span><span class="p">(</span><span class="n">data_dir</span><span class="o">+</span><span class="s2">&#34;/t10k-images-idx3-ubyte.gz&#34;</span><span class="p">,</span> <span class="n">data_dir</span><span class="o">+</span><span class="s2">&#34;/t10k-labels-idx1-ubyte.gz&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">train_dataloader</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">test_dataloader</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">test_dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">model</span> <span class="o">=</span> <span class="n">MLPResNet</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">opt</span> <span class="o">=</span> <span class="n">optimizer</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">train_error</span><span class="p">,</span> <span class="n">train_loss</span> <span class="o">=</span> <span class="n">epoch</span><span class="p">(</span><span class="n">train_dataloader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">opt</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">test_error</span><span class="p">,</span> <span class="n">test_loss</span> <span class="o">=</span> <span class="n">epoch</span><span class="p">(</span><span class="n">test_dataloader</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># print(f&#34;Epoch {i+1}/{epochs} Train Error: {train_error:.4f} Train Loss: {train_loss:.4f} Test Error: {test_error:.4f} Test Loss: {test_loss:.4f}&#34;)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">train_error</span><span class="p">,</span> <span class="n">train_loss</span><span class="p">,</span> <span class="n">test_error</span><span class="p">,</span> <span class="n">test_loss</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
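</div><h2 id="hw2-小结">hw2 Summary</h2>
<p>That wraps up hw2. After a lot of procrastination, it took me a month to finish. The course's tests are not very strict: while debugging hw2 I found quite a few bugs left over from hw1. When you hit a problem, besides debugging on your own, I also recommend looking at other people's implementations, which makes it faster to pinpoint the issue.</p>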
<h1 id="hw3">hw3</h1>
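<p>In this assignment, we build a simple backend library that implements <code>NDArray</code>. Previously we relied on <code>NumPy</code>; this time we implement the CPU and GPU versions of the backend by hand, without calling existing highly optimized code for matrix multiplication or other operations.</p>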
<h2 id="part-1-python-array-operations">Part 1: Python array operations</h2>
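<p>The first part implements several operations in Python by manipulating the <code>strides</code>, <code>shape</code>, and <code>offset</code> fields. Because these operations never touch the underlying data, implementing them in Python is already efficient enough.</p>
<p>Before implementing anything, skim through <code>ndarray.py</code>, which provides many helper functions that simplify the work.</p>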
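<p>To make the role of these three fields concrete, here is a small pure-Python sketch of how a multi-dimensional index maps to a position in the flat buffer. It assumes strides are counted in elements rather than bytes; the helper <code>flat_index</code> and the sample buffer are illustrative and not taken from <code>ndarray.py</code>.</p>

```python
def flat_index(index, strides, offset=0):
    """Map a multi-dimensional index to a position in the flat buffer."""
    return offset + sum(i * s for i, s in zip(index, strides))


# A 2x3 matrix stored row by row in one flat buffer.
buffer = [10, 11, 12, 20, 21, 22]
shape, strides, offset = (2, 3), (3, 1), 0

element = buffer[flat_index((1, 2), strides, offset)]  # row 1, column 2

# A view of the second row only: same buffer, new shape, strides, and offset.
row_shape, row_strides, row_offset = (3,), (1,), 3
row = [buffer[flat_index((j,), row_strides, row_offset)] for j in range(row_shape[0])]

# A transposed view: swap the strides, no data is copied.
t_strides = (1, 3)
t_element = buffer[flat_index((2, 1), t_strides, offset)]
```

<p>All three lookups read the same buffer; only the metadata changes, which is why these operations are cheap even in Python.</p>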
<ul>
<li>reshape<br>
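reshape simply reinterprets the same contiguous one-dimensional data in memory under a different shape. The code skeleton provides the <code>NDArray.as_strided</code> method, which views an <code>NDArray</code> with a given shape and strides, and the <code>NDArray.compact_strides</code> method, which computes the strides of a compact layout from a shape.</li>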
</ul>
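<p>As a rough illustration of what compact strides look like, the sketch below computes row-major strides from a shape. It is an assumption about what such a helper does (strides in elements, last axis contiguous), not the skeleton's actual code; the name <code>compact_strides_sketch</code> is hypothetical.</p>

```python
def compact_strides_sketch(shape):
    """Strides (in elements) of a densely packed row-major array of this shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


old_shape = (2, 3, 4)
new_shape = (6, 4)
old_strides = compact_strides_sketch(old_shape)  # (12, 4, 1)
new_strides = compact_strides_sketch(new_shape)  # (4, 1)

# Element (1, 2, 3) of the old shape and element (5, 3) of the new shape
# land on the same flat position, so reshaping needs no data movement.
old_pos = sum(i * s for i, s in zip((1, 2, 3), old_strides))
new_pos = sum(i * s for i, s in zip((5, 3), new_strides))
```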
<p>With these helpers, implementing reshape is quite simple:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">reshape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_shape</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">	<span class="k">assert</span> <span class="n">prod</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="n">prod</span><span class="p">(</span><span class="n">new_shape</span><span class="p">),</span> <span class="s2">&#34;Product of shapes must be equal&#34;</span>
</span></span><span class="line"><span class="cl">	<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_compact</span><span class="p">(),</span> <span class="s2">&#34;Matrix must be compact&#34;</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">as_strided</span><span class="p">(</span><span class="n">new_shape</span><span class="p">,</span> <span class="n">NDArray</span><span class="o">.</span><span class="n">compact_strides</span><span class="p">(</span><span class="n">new_shape</span><span class="p">))</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>permute<br>
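permute reorders the axes of an <code>NDArray</code>. For example, if the original axis order is <code>BHWC</code>, permuting with (0,3,1,2) yields the axis order <code>BCHW</code>. If an element has index <code>[i, j, k, l]</code> after the permutation, its index before the permutation is <code>[i, k, l, j]</code>. Suppose the strides before the permutation are <code>m, n, p, q</code>; the old index then gives the flat position <code>im+kn+lp+jq = im+jq+kn+lp</code>, so the strides that go with the new index are <code>m, q, n, p</code>. In other words, reordering the original strides by the same axis sequence gives the strides of the permuted array.</li>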
</ul>
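<p>The stride argument can be checked by brute force. The following pure-Python sketch uses a concrete <code>BHWC</code> shape with compact strides in elements and confirms that every permuted index reaches the same buffer position as the corresponding original index. The helpers <code>flat_index</code> and <code>permute_view</code> and the sizes are illustrative assumptions, not needle code.</p>

```python
from itertools import product


def flat_index(index, strides, offset=0):
    """Position in the flat buffer of the element at a multi-dimensional index."""
    return offset + sum(i * s for i, s in zip(index, strides))


def permute_view(shape, strides, axes):
    """Reorder shape and strides by `axes`; the buffer itself is untouched."""
    return tuple(shape[a] for a in axes), tuple(strides[a] for a in axes)


# BHWC layout with B=2, H=3, W=4, C=5 and compact strides (m, n, p, q).
shape, strides = (2, 3, 4, 5), (60, 20, 5, 1)
new_shape, new_strides = permute_view(shape, strides, (0, 3, 1, 2))  # BCHW

# Index [i, j, k, l] after the permutation must hit the same buffer position
# as index [i, k, l, j] before it.
consistent = all(
    flat_index((i, j, k, l), new_strides) == flat_index((i, k, l, j), strides)
    for i, j, k, l in product(*(range(n) for n in new_shape))
)
```

<p>Here <code>new_strides</code> comes out as <code>(60, 1, 20, 5)</code>, that is <code>m, q, n, p</code>, matching the derivation.</p>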
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">permute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_axes</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">	<span class="n">new_shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">new_axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="n">new_strides</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">strides</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">new_axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">NDArray</span><span class="o">.</span><span class="n">make</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">new_shape</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="n">new_strides</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">handle</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">offset</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_offset</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>broadcast_to<br>
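Broadcasting is easy to understand: elements are replicated along certain dimensions. For example, with <code>(1, 9, 8, 1) -&gt; (9, 9, 8, 2)</code>, the broadcast index <code>(m, n, p, q)</code> corresponds to index <code>(0, n, p, 0)</code> in the original array, so setting the strides of the broadcast dimensions to 0 achieves this effect.</li>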
</ul>
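<p>The zero-stride trick can be sketched in a few lines of pure Python. This is an illustration under the assumption that both shapes have the same rank; <code>broadcast_strides</code> and <code>flat_offset</code> are hypothetical helper names, not functions from the assignment.</p>

```python
def broadcast_strides(shape, strides, new_shape):
    """Strides for viewing `shape` as `new_shape` without copying data."""
    assert len(shape) == len(new_shape), "ranks must match in this sketch"
    out = []
    for old, new, stride in zip(shape, new_shape, strides):
        assert old == new or old == 1, "only size-1 axes can be broadcast"
        # A broadcast axis gets stride 0, so every index along it
        # lands on the same element.
        out.append(stride if old == new else 0)
    return tuple(out)

def flat_offset(index, strides):
    """Position of a multi-dimensional index in the flat buffer."""
    return sum(i * s for i, s in zip(index, strides))

shape, strides = (1, 9, 8, 1), (72, 8, 1, 1)   # compact strides of the source
new_strides = broadcast_strides(shape, strides, (9, 9, 8, 2))

# Index (m, n, p, q) in the broadcast view reads element (0, n, p, 0).
assert flat_offset((5, 3, 2, 1), new_strides) == flat_offset((0, 3, 2, 0), strides)
```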
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">broadcast_to</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_shape</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">	<span class="k">assert</span> <span class="nb">all</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">		<span class="n">new_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">		<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">	<span class="p">),</span> <span class="s2">&#34;Invalid broadcast shape&#34;</span>
</span></span><span class="line"><span class="cl">	<span class="n">new_strides</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">		<span class="bp">self</span><span class="o">.</span><span class="n">strides</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="n">new_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">else</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">	<span class="p">)</span>
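</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">NDArray</span><span class="o">.</span><span class="n">make</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">new_shape</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="n">new_strides</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">handle</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">offset</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_offset</span><span class="p">)</span>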
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>__getitem__<br>
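getitem fetches the elements at the specified indices and returns them as an <code>NDArray</code>. Note that every index is a <code>slice</code> object: the provided code has already preprocessed the indices so that each one is a standard <code>slice</code> whose <code>start</code>, <code>stop</code>, and <code>step</code> attributes are all present and lie within the range of the corresponding shape entry.</li>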
</ul>
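<p>The shape of the result is simple to compute: count how many elements each dimension's slice contains. Strides map an index to a position in the underlying one-dimensional array. If the slice step along a dimension is not 1, reaching the next element means skipping over several elements each time, so the new stride is that dimension's <code>slice.step</code> multiplied by the original stride.</p>
<p>Next comes <code>offset</code>. Because each slice has a <code>start</code> value, any element whose index in some dimension is smaller than that slice's <code>start</code> should not appear in the new <code>NDArray</code>. For example, if the <code>start</code> values are <code>(2, 3, 4, 5)</code>, the original indices <code>(1, 3, 4, 5)</code> and <code>(2, 3, 4, 1)</code> both come before the first element of the slice and should be skipped by the offset. The offset is therefore the sum over all dimensions of <code>slice.start</code> times the corresponding stride, added to any offset the array already has.</p>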
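<p>The three quantities can be worked through on a small example. The sketch below uses a plain Python list as the flat buffer and assumes positive steps and already-normalized slices; <code>slice_view</code> is a hypothetical helper, not the <code>NDArray</code> method.</p>

```python
def slice_view(shape, strides, offset, slices):
    """Shape, strides and offset of the view selected by normalized slices."""
    assert len(shape) == len(slices), "one slice per dimension in this sketch"
    # Ceiling division counts the elements each slice selects.
    new_shape = tuple(
        max(0, (s.stop - s.start + s.step - 1) // s.step) for s in slices
    )
    # Stepping by `step` in the view skips `step` elements of the source.
    new_strides = tuple(s.step * st for s, st in zip(slices, strides))
    # Skip everything that comes before the first selected element.
    new_offset = offset + sum(s.start * st for s, st in zip(slices, strides))
    return new_shape, new_strides, new_offset

data = list(range(24))                    # a compact 4 x 6 array
shape, strides, offset = slice_view(
    (4, 6), (6, 1), 0, (slice(1, 4, 2), slice(0, 6, 3))
)

# Read the view back through its strides: rows 1 and 3, columns 0 and 3.
view = [
    [data[offset + r * strides[0] + c * strides[1]] for c in range(shape[1])]
    for r in range(shape[0])
]
assert view == [[6, 9], [18, 21]]
```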
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idxs</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">	<span class="o">...</span>
</span></span><span class="line"><span class="cl">	<span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">	<span class="n">shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">max</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="p">(</span><span class="n">s</span><span class="o">.</span><span class="n">stop</span> <span class="o">-</span> <span class="n">s</span><span class="o">.</span><span class="n">start</span> <span class="o">+</span> <span class="n">s</span><span class="o">.</span><span class="n">step</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">s</span><span class="o">.</span><span class="n">step</span><span class="p">)</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">idxs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="n">strides</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">s</span><span class="o">.</span><span class="n">step</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">strides</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">idxs</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">	<span class="n">offset</span> <span class="o">=</span> <span class="n">reduce</span><span class="p">(</span><span class="n">operator</span><span class="o">.</span><span class="n">add</span><span class="p">,</span> <span class="p">(</span><span class="n">s</span><span class="o">.</span><span class="n">start</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">strides</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">idxs</span><span class="p">)),</span> <span class="bp">self</span><span class="o">.</span><span class="n">_offset</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">NDArray</span><span class="o">.</span><span class="n">make</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">strides</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">handle</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">offset</span><span class="o">=</span><span class="n">offset</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="part-2-cpu-backend---compact-and-setitem">Part 2: CPU Backend - Compact and setitem</h2>
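<p>In this part we implement the CPU versions of <code>compact</code> and <code>setitem</code>. The former creates a densely packed copy of the data in memory; the latter writes given data into an array's memory.</p>
<p>What the two have in common is that both need a loop nest of variable depth: since the number of dimensions of a given <code>NDArray</code> is not known in advance, the data cannot be traversed with a fixed n-level nested loop. My approach is to maintain an index <code>(0, 0, 0, ..., 0)</code> and manually add 1 to its last position at each step, carrying into the next-higher position whenever a position reaches the corresponding <code>shape</code> value. Once the highest position also carries, the traversal is complete.</p>
<p>I defined two helper functions here, <code>bool next_index(std::vector&lt;int32_t&gt;&amp; index, const std::vector&lt;int32_t&gt;&amp; shape)</code> and <code>size_t index_to_offset(const std::vector&lt;int32_t&gt;&amp; index, const std::vector&lt;int32_t&gt;&amp; strides, const size_t offset)</code>, which step through the indices and convert an index to a flat position, respectively. They are implemented as follows:</p>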
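<p>The carrying scheme behaves like an odometer. As a rough Python sketch of the idea (not the C++ code used in the assignment; the function names here are illustrative, and every dimension is assumed to have at least one element):</p>

```python
def next_index(index, shape):
    """Advance `index` in place like an odometer.

    Returns False once the index has wrapped past the last element.
    """
    for axis in reversed(range(len(shape))):
        index[axis] += 1
        if index[axis] != shape[axis]:
            return True        # no carry needed, the index is valid
        index[axis] = 0        # this digit overflowed: reset and carry
    return False               # the highest digit carried: traversal done

def iterate(shape):
    """List every index of `shape` in row-major order."""
    index = [0] * len(shape)
    visited = [tuple(index)]
    while next_index(index, shape):
        visited.append(tuple(index))
    return visited

order = iterate((2, 3))
```

<p>For the shape <code>(2, 3)</code> this visits the last axis fastest, which matches the order in which a compact array is laid out in memory.</p>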
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">bool</span> <span class="nf">next_index</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;&amp;</span> <span class="n">index</span><span class="p">,</span> <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;&amp;</span> <span class="n">shape</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Increment the index by one, and return true if the index is still valid
</span></span></span><span class="line"><span class="cl"><span class="cm">   * 
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Args:
</span></span></span><span class="line"><span class="cl"><span class="cm">   *  index: current index
</span></span></span><span class="line"><span class="cl"><span class="cm">   *  shape: shape of the array
</span></span></span><span class="line"><span class="cl"><span class="cm">   *  
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Returns:
</span></span></span><span class="line"><span class="cl"><span class="cm">   *  true if the index is still valid, false otherwise
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">index</span><span class="p">.</span><span class="n">size</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="nb">false</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="n">index</span><span class="p">[</span><span class="n">index</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">++</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="n">index</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="o">-</span><span class="mi">1</span><span class="p">;</span> <span class="n">i</span><span class="o">&gt;=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">--</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">index</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">&gt;=</span> <span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]){</span>
</span></span><span class="line"><span class="cl">      <span class="n">index</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="k">if</span><span class="p">(</span><span class="n">i</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">        <span class="n">index</span><span class="p">[</span><span class="n">i</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">++</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="nb">false</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="k">return</span> <span class="nb">true</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">size_t</span> <span class="nf">index_to_offset</span><span class="p">(</span><span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;&amp;</span> <span class="n">index</span><span class="p">,</span> <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;&amp;</span> <span class="n">strides</span><span class="p">,</span> <span class="k">const</span> <span class="n">size_t</span> <span class="n">offset</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Convert an index to an offset
</span></span></span><span class="line"><span class="cl"><span class="cm">   * 
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Args:
</span></span></span><span class="line"><span class="cl"><span class="cm">   *  index: index to convert
</span></span></span><span class="line"><span class="cl"><span class="cm">   *  strides: strides of the array
</span></span></span><span class="line"><span class="cl"><span class="cm">   *  offset: offset of the array
</span></span></span><span class="line"><span class="cl"><span class="cm">   *  
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Returns:
</span></span></span><span class="line"><span class="cl"><span class="cm">   *  offset of the index
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">  <span class="n">size_t</span> <span class="n">res</span> <span class="o">=</span> <span class="n">offset</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">index</span><span class="p">.</span><span class="n">size</span><span class="p">();</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="n">res</span> <span class="o">+=</span> <span class="n">index</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">strides</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="n">res</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span> 
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>compact<br>
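compact only has to write the value of every position into the preallocated <code>out</code>. Since <code>out</code> is contiguous in memory, it can be walked with <code>out_index++</code>, while the source data is accessed through the two helper functions above:</li>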
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">Compact</span><span class="p">(</span><span class="k">const</span> <span class="n">AlignedArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">AlignedArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;</span> <span class="n">shape</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;</span> <span class="n">strides</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">offset</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// BEGIN SOLUTION
</span></span></span><span class="line"><span class="cl">  <span class="k">auto</span> <span class="n">a_index</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;</span><span class="p">(</span><span class="n">shape</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">0</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">out_index</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">out_index</span> <span class="o">&lt;</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">;</span> <span class="n">out_index</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">size_t</span> <span class="n">a_offset</span> <span class="o">=</span> <span class="n">index_to_offset</span><span class="p">(</span><span class="n">a_index</span><span class="p">,</span> <span class="n">strides</span><span class="p">,</span> <span class="n">offset</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">[</span><span class="n">out_index</span><span class="p">]</span> <span class="o">=</span> <span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">[</span><span class="n">a_offset</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="n">next_index</span><span class="p">(</span><span class="n">a_index</span><span class="p">,</span> <span class="n">shape</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// END SOLUTION
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>setitem<br>
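setitem comes in two versions depending on whether the value is a scalar. Both are simple: use the two helper functions to visit each target position in turn:</li>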
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">EwiseSetitem</span><span class="p">(</span><span class="k">const</span> <span class="n">AlignedArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">AlignedArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;</span> <span class="n">shape</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;</span> <span class="n">strides</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">offset</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// BEGIN SOLUTION
</span></span></span><span class="line"><span class="cl">  <span class="k">auto</span> <span class="n">out_index</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;</span><span class="p">(</span><span class="n">shape</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">0</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">a_index</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">a_index</span> <span class="o">&lt;</span> <span class="n">a</span><span class="p">.</span><span class="n">size</span><span class="p">;</span> <span class="n">a_index</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">size_t</span> <span class="n">out_offset</span> <span class="o">=</span> <span class="n">index_to_offset</span><span class="p">(</span><span class="n">out_index</span><span class="p">,</span> <span class="n">strides</span><span class="p">,</span> <span class="n">offset</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">[</span><span class="n">out_offset</span><span class="p">]</span> <span class="o">=</span> <span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">[</span><span class="n">a_index</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="n">next_index</span><span class="p">(</span><span class="n">out_index</span><span class="p">,</span> <span class="n">shape</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// END SOLUTION
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">ScalarSetitem</span><span class="p">(</span><span class="k">const</span> <span class="n">size_t</span> <span class="n">size</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">val</span><span class="p">,</span> <span class="n">AlignedArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;</span> <span class="n">shape</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;</span> <span class="n">strides</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">offset</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// BEGIN SOLUTION
</span></span></span><span class="line"><span class="cl">  <span class="k">auto</span> <span class="n">out_index</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;</span><span class="p">(</span><span class="n">shape</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">0</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">size_t</span> <span class="n">out_offset</span> <span class="o">=</span> <span class="n">index_to_offset</span><span class="p">(</span><span class="n">out_index</span><span class="p">,</span> <span class="n">strides</span><span class="p">,</span> <span class="n">offset</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">[</span><span class="n">out_offset</span><span class="p">]</span> <span class="o">=</span> <span class="n">val</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">next_index</span><span class="p">(</span><span class="n">out_index</span><span class="p">,</span> <span class="n">shape</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// END SOLUTION
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="part-3-cpu-backend---elementwise-and-scalar-operations">Part 3: CPU Backend - Elementwise and scalar operations</h2>
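<p>In this part we implement the CPU versions of a handful of very simple operators. The main purpose is to get familiar with the workflow of registering C++ functions in pybind. The assignment notes encourage using templates, macros, and similar tools to simplify the implementation.</p>
<p>Instead of writing an explicit declaration and definition for every operator, I first implemented <code>void EwiseOp(const AlignedArray&amp; a, const AlignedArray&amp; b, AlignedArray* out, std::function&lt;scalar_t(scalar_t, scalar_t)&gt; op)</code> and <code>void ScalarOp(const AlignedArray&amp; a, scalar_t val, AlignedArray* out, std::function&lt;scalar_t(scalar_t, scalar_t)&gt; op)</code>. The former applies <code>op</code> element-wise to two arrays, and the latter applies <code>op</code> to every element of an array together with one scalar. Passing in a different <code>op</code> yields a different operation.</p>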
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">EwiseOp</span><span class="p">(</span><span class="k">const</span> <span class="n">AlignedArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">AlignedArray</span><span class="o">&amp;</span> <span class="n">b</span><span class="p">,</span> <span class="n">AlignedArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="p">(</span><span class="n">scalar_t</span><span class="p">,</span> <span class="n">scalar_t</span><span class="p">)</span><span class="o">&gt;</span> <span class="n">op</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Element-wise operation on two arrays
</span></span></span><span class="line"><span class="cl"><span class="cm">   *
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Args:
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   a: first array
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   b: second array
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   out: output array
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   op: operation to perform
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span> <span class="p">(</span><span class="n">size_t</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">a</span><span class="p">.</span><span class="n">size</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">op</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">b</span><span class="p">.</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="p">]);</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">ScalarOp</span><span class="p">(</span><span class="k">const</span> <span class="n">AlignedArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">val</span><span class="p">,</span> <span class="n">AlignedArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="p">(</span><span class="n">scalar_t</span><span class="p">,</span> <span class="n">scalar_t</span><span class="p">)</span><span class="o">&gt;</span> <span class="n">op</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Element-wise operation on an array and a scalar
</span></span></span><span class="line"><span class="cl"><span class="cm">   *
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Args:
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   a: array
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   val: scalar
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   out: output array
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   op: operation to perform
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span> <span class="p">(</span><span class="n">size_t</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">a</span><span class="p">.</span><span class="n">size</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">op</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">val</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
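</div><p>These two functions are then partially applied (loosely speaking, curried) through lambda expressions: the lambda fixes <code>op</code>, so the resulting callable takes only the arrays <code>a, b</code> and the output <code>out</code>, and can be registered in pybind.</p>
<p>For example, the complete code to register an element-wise multiplication is:</p>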
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">m</span><span class="p">.</span><span class="n">def</span><span class="p">(</span><span class="s">&#34;ewise_mul&#34;</span><span class="p">,</span> <span class="p">[](</span><span class="k">const</span> <span class="n">AlignedArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">AlignedArray</span><span class="o">&amp;</span> <span class="n">b</span><span class="p">,</span> <span class="n">AlignedArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">EwiseOp</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">out</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">multiplies</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">());</span>
</span></span><span class="line"><span class="cl"><span class="p">});</span>
</span></span></code></pre></td></tr></table>
</div>
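</div><p>Reading from the outside in: <code>m.def</code> registers a method in pybind. Its first argument is the method name, here <code>ewise_mul</code>, and its second argument is the corresponding C++ callable, which may be a function pointer, a lambda, and so on. Note that the Python side calls <code>ewise_mul</code> with arrays only (the two inputs plus the output buffer) and never passes an <code>op</code>. <code>EwiseOp</code>, which takes the extra <code>op</code> parameter, therefore has to be partially applied: <code>std::multiplies&lt;scalar_t&gt;()</code> is passed to <code>EwiseOp</code>, and the call is wrapped in a lambda.</p>
<p>Writing out a lambda for every registration is rather verbose, so this step can be abstracted into macros as well:</p>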
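<p>The partial-application idea can also be tried out without pybind. The following is only a minimal sketch under stated assumptions: a plain <code>std::vector</code> stands in for <code>AlignedArray</code>, and <code>ewise_op</code> and <code>make_ewise</code> are hypothetical names that are not part of the assignment code.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp">#include &lt;cstddef&gt;
#include &lt;functional&gt;
#include &lt;vector&gt;

using scalar_t = float;
using BinaryOp = std::function&lt;scalar_t(scalar_t, scalar_t)&gt;;

// Hypothetical stand-in for EwiseOp: std::vector replaces AlignedArray.
// Assumes b has at least as many elements as a.
void ewise_op(const std::vector&lt;scalar_t&gt;&amp; a, const std::vector&lt;scalar_t&gt;&amp; b,
              std::vector&lt;scalar_t&gt;* out, const BinaryOp&amp; op) {
  out-&gt;resize(a.size());
  for (std::size_t i = 0; i &lt; a.size(); i++) {
    (*out)[i] = op(a[i], b[i]);
  }
}

// Fix `op` once and return a callable that only needs (a, b, out).
// This mirrors what the lambda passed to m.def does.
auto make_ewise(BinaryOp op) {
  return [op](const std::vector&lt;scalar_t&gt;&amp; a, const std::vector&lt;scalar_t&gt;&amp; b,
              std::vector&lt;scalar_t&gt;* out) { ewise_op(a, b, out, op); };
}
</code></pre></div>
<p>With this sketch, <code>make_ewise(std::multiplies&lt;scalar_t&gt;())</code> returns a three-argument callable that plays the same role as the lambda registered under <code>ewise_mul</code>.</p>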
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl">  <span class="cp">#define REGISTER_EWISW_OP(NAME, OP) \
</span></span></span><span class="line"><span class="cl"><span class="cp">    m.def(NAME, [](const AlignedArray&amp; a, const AlignedArray&amp; b, AlignedArray* out) { \
</span></span></span><span class="line"><span class="cl"><span class="cp">      EwiseOp(a, b, out, OP); \
</span></span></span><span class="line"><span class="cl"><span class="cp">    });
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="cp">#define REGISTER_SCALAR_OP(NAME, OP) \
</span></span></span><span class="line"><span class="cl"><span class="cp">    m.def(NAME, [](const AlignedArray&amp; a, scalar_t val, AlignedArray* out) { \
</span></span></span><span class="line"><span class="cl"><span class="cp">      ScalarOp(a, val, out, OP); \
</span></span></span><span class="line"><span class="cl"><span class="cp">    });
</span></span></span><span class="line"><span class="cl">  <span class="cp">#define REGISTER_SINGLE_OP(NAME, OP) \
</span></span></span><span class="line"><span class="cl"><span class="cp">    m.def(NAME, [](const AlignedArray&amp; a, AlignedArray* out) { \
</span></span></span><span class="line"><span class="cl"><span class="cp">      for (size_t i = 0; i &lt; a.size; i++) { \
</span></span></span><span class="line"><span class="cl"><span class="cp">        out-&gt;ptr[i] = OP(a.ptr[i]); \
</span></span></span><span class="line"><span class="cl"><span class="cp">      } \
</span></span></span><span class="line"><span class="cl"><span class="cp">    });
</span></span></span></code></pre></td></tr></table>
</div>
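</div><p>These three macros register, respectively, element-wise binary operators, array-scalar binary operators, and unary operators in pybind.</p>
<p>Applying the macros registers all of the required methods:</p>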
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl">  <span class="n">REGISTER_EWISW_OP</span><span class="p">(</span><span class="s">&#34;ewise_mul&#34;</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">multiplies</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">  <span class="n">REGISTER_SCALAR_OP</span><span class="p">(</span><span class="s">&#34;scalar_mul&#34;</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">multiplies</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">  <span class="n">REGISTER_EWISW_OP</span><span class="p">(</span><span class="s">&#34;ewise_div&#34;</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">divides</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">  <span class="n">REGISTER_SCALAR_OP</span><span class="p">(</span><span class="s">&#34;scalar_div&#34;</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">divides</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">  <span class="n">REGISTER_SCALAR_OP</span><span class="p">(</span><span class="s">&#34;scalar_power&#34;</span><span class="p">,</span> <span class="k">static_cast</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="p">(</span><span class="o">*</span><span class="p">)(</span><span class="n">scalar_t</span><span class="p">,</span> <span class="n">scalar_t</span><span class="p">)</span><span class="o">&gt;</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">pow</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">  <span class="n">REGISTER_EWISW_OP</span><span class="p">(</span><span class="s">&#34;ewise_maximum&#34;</span><span class="p">,</span> <span class="k">static_cast</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="p">(</span><span class="o">*</span><span class="p">)(</span><span class="n">scalar_t</span><span class="p">,</span> <span class="n">scalar_t</span><span class="p">)</span><span class="o">&gt;</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">fmax</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">  <span class="n">REGISTER_SCALAR_OP</span><span class="p">(</span><span class="s">&#34;scalar_maximum&#34;</span><span class="p">,</span> <span class="k">static_cast</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="p">(</span><span class="o">*</span><span class="p">)(</span><span class="n">scalar_t</span><span class="p">,</span> <span class="n">scalar_t</span><span class="p">)</span><span class="o">&gt;</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">fmax</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">  <span class="n">REGISTER_EWISW_OP</span><span class="p">(</span><span class="s">&#34;ewise_eq&#34;</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">equal_to</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">  <span class="n">REGISTER_SCALAR_OP</span><span class="p">(</span><span class="s">&#34;scalar_eq&#34;</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">equal_to</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">  <span class="n">REGISTER_EWISW_OP</span><span class="p">(</span><span class="s">&#34;ewise_ge&#34;</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">greater_equal</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">  <span class="n">REGISTER_SCALAR_OP</span><span class="p">(</span><span class="s">&#34;scalar_ge&#34;</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">greater_equal</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">  <span class="n">REGISTER_SINGLE_OP</span><span class="p">(</span><span class="s">&#34;ewise_log&#34;</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">log</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">REGISTER_SINGLE_OP</span><span class="p">(</span><span class="s">&#34;ewise_exp&#34;</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">exp</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">REGISTER_SINGLE_OP</span><span class="p">(</span><span class="s">&#34;ewise_tanh&#34;</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">tanh</span><span class="p">);</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Note that functions such as <code>std::pow</code> have several overloads; a <code>static_cast</code> to the desired function-pointer type selects the intended one.</p>
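<p>To make the overload selection concrete, here is a minimal sketch outside of pybind; <code>apply_binary</code>, <code>pow_via_cast</code>, and <code>pow_via_lambda</code> are hypothetical helpers introduced only for illustration. One caveat: since C++20, forming a pointer to most standard-library functions is formally unspecified, even though common compilers accept it, so the sketch also shows a lambda wrapper as a more portable variant.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp">#include &lt;cmath&gt;
#include &lt;functional&gt;

using scalar_t = float;

// Hypothetical helper that accepts any binary operation on scalars.
scalar_t apply_binary(scalar_t x, scalar_t y,
                      const std::function&lt;scalar_t(scalar_t, scalar_t)&gt;&amp; op) {
  return op(x, y);
}

// std::pow names an overload set, so it cannot be passed directly;
// the cast picks the float(float, float) overload explicitly.
scalar_t pow_via_cast(scalar_t x, scalar_t y) {
  return apply_binary(
      x, y, static_cast&lt;scalar_t (*)(scalar_t, scalar_t)&gt;(std::pow));
}

// Lambda variant: overload resolution happens at the call inside the lambda,
// so no pointer to a standard-library function is ever formed.
scalar_t pow_via_lambda(scalar_t x, scalar_t y) {
  return apply_binary(
      x, y, [](scalar_t a, scalar_t b) { return std::pow(a, b); });
}
</code></pre></div>
<p>Both variants compute the same result; the lambda form could be passed as <code>OP</code> to the registration macros in the same way as the cast expression.</p>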
<h2 id="part-4-cpu-backend---reductions">Part 4: CPU Backend - Reductions</h2>
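<p>Two reduction operators, <code>max</code> and <code>sum</code>, are implemented here. To keep things simple, reduction is performed over a single dimension only. Even for a single dimension, a general reduction is fairly hard to write, so the task is simplified further: before the reduction operator is called, the dimension to be reduced is permuted to the last position, and the original layout is restored automatically afterwards. We therefore only need to implement reduction over the last dimension.</p>
<p>With these simplifications the two operators become almost trivial: each run of <code>reduce_size</code> consecutive elements is reduced with max/sum to produce one output element. Finally, remember to register both methods in pybind:</p>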
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">ReduceMax</span><span class="p">(</span><span class="k">const</span> <span class="n">AlignedArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">AlignedArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">reduce_size</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// BEGIN SOLUTION
</span></span></span><span class="line"><span class="cl">  <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">reduce_size</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">1</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">reduce_size</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">      <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">max</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">reduce_size</span> <span class="o">+</span> <span class="n">j</span><span class="p">]);</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// END SOLUTION
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">ReduceSum</span><span class="p">(</span><span class="k">const</span> <span class="n">AlignedArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">AlignedArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">reduce_size</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// BEGIN SOLUTION
</span></span></span><span class="line"><span class="cl">  <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">reduce_size</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">      <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">reduce_size</span> <span class="o">+</span> <span class="n">j</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// END SOLUTION
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="part-5-cpu-backend---matrix-multiplication">Part 5: CPU Backend - Matrix multiplication</h2>
<p>In this part we implement matrix multiplication.</p>
<ul>
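<li>Matmul<br>
The first task is the plain triple-loop matrix multiplication; the two outer loops iterate over the rows and columns of <code>out</code>. Before accumulating, remember to zero-initialize the <code>out</code> array!</li>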
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">Matmul</span><span class="p">(</span><span class="k">const</span> <span class="n">AlignedArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">AlignedArray</span><span class="o">&amp;</span> <span class="n">b</span><span class="p">,</span> <span class="n">AlignedArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="kt">uint32_t</span> <span class="n">m</span><span class="p">,</span> <span class="kt">uint32_t</span> <span class="n">n</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="kt">uint32_t</span> <span class="n">p</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span><span class="p">(</span><span class="kt">uint32_t</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">m</span><span class="o">*</span><span class="n">p</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span> <span class="p">(</span><span class="kt">uint32_t</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">m</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">uint32_t</span> <span class="n">j</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">j</span><span class="o">&lt;</span><span class="n">p</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="k">for</span> <span class="p">(</span><span class="kt">uint32_t</span> <span class="n">k</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">k</span><span class="o">&lt;</span><span class="n">n</span><span class="p">;</span> <span class="n">k</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">p</span> <span class="o">+</span> <span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">n</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">b</span><span class="p">.</span><span class="n">ptr</span><span class="p">[</span><span class="n">k</span><span class="o">*</span><span class="n">p</span> <span class="o">+</span> <span class="n">j</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
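<li>AlignedDot<br>
This function multiplies two <code>TILE*TILE</code> matrices and adds the result to the corresponding positions of <code>out</code>. The source code is written as a triple loop, and at compile time it can be optimized into vector operations.</li>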
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kr">inline</span> <span class="kt">void</span> <span class="nf">AlignedDot</span><span class="p">(</span><span class="k">const</span> <span class="kt">float</span><span class="o">*</span> <span class="n">__restrict__</span> <span class="n">a</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                       <span class="k">const</span> <span class="kt">float</span><span class="o">*</span> <span class="n">__restrict__</span> <span class="n">b</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                       <span class="kt">float</span><span class="o">*</span> <span class="n">__restrict__</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="n">a</span> <span class="o">=</span> <span class="p">(</span><span class="k">const</span> <span class="kt">float</span><span class="o">*</span><span class="p">)</span><span class="n">__builtin_assume_aligned</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">TILE</span> <span class="o">*</span> <span class="n">ELEM_SIZE</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">b</span> <span class="o">=</span> <span class="p">(</span><span class="k">const</span> <span class="kt">float</span><span class="o">*</span><span class="p">)</span><span class="n">__builtin_assume_aligned</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">TILE</span> <span class="o">*</span> <span class="n">ELEM_SIZE</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">out</span> <span class="o">=</span> <span class="p">(</span><span class="kt">float</span><span class="o">*</span><span class="p">)</span><span class="n">__builtin_assume_aligned</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="n">TILE</span> <span class="o">*</span> <span class="n">ELEM_SIZE</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="c1">/// BEGIN SOLUTION
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="k">for</span> <span class="p">(</span><span class="kt">uint32_t</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">TILE</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">uint32_t</span> <span class="n">j</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">j</span><span class="o">&lt;</span><span class="n">TILE</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="k">for</span> <span class="p">(</span><span class="kt">uint32_t</span> <span class="n">k</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">k</span><span class="o">&lt;</span><span class="n">TILE</span><span class="p">;</span> <span class="n">k</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">out</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">TILE</span> <span class="o">+</span> <span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">a</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">TILE</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">b</span><span class="p">[</span><span class="n">k</span><span class="o">*</span><span class="n">TILE</span> <span class="o">+</span> <span class="n">j</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// END SOLUTION
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
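<li>MatmulTiled<br>
Matrix multiplication is implemented here by tiling. How tiling works and why it is faster were both covered in Lecture 12, so they are not repeated here; see the notes at <a href="/notes/notes-on-cmu-10-414-deep-learning-system/#lecture-12">CMU 10-414 deep learning system study notes &gt; Lecture 12</a>.</li>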
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">MatmulTiled</span><span class="p">(</span><span class="k">const</span> <span class="n">AlignedArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">AlignedArray</span><span class="o">&amp;</span> <span class="n">b</span><span class="p">,</span> <span class="n">AlignedArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="kt">uint32_t</span> <span class="n">m</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                 <span class="kt">uint32_t</span> <span class="n">n</span><span class="p">,</span> <span class="kt">uint32_t</span> <span class="n">p</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span><span class="p">(</span><span class="kt">uint32_t</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">m</span><span class="o">*</span><span class="n">p</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span> <span class="p">(</span><span class="kt">uint32_t</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">m</span><span class="o">/</span><span class="n">TILE</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">uint32_t</span> <span class="n">j</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">j</span><span class="o">&lt;</span><span class="n">p</span><span class="o">/</span><span class="n">TILE</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="k">for</span> <span class="p">(</span><span class="kt">uint32_t</span> <span class="n">k</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">k</span><span class="o">&lt;</span><span class="n">n</span><span class="o">/</span><span class="n">TILE</span><span class="p">;</span> <span class="n">k</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">AlignedDot</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">i</span><span class="o">*</span><span class="n">n</span><span class="o">/</span><span class="n">TILE</span> <span class="o">+</span> <span class="n">k</span><span class="p">)</span><span class="o">*</span><span class="n">TILE</span><span class="o">*</span><span class="n">TILE</span><span class="p">,</span> <span class="n">b</span><span class="p">.</span><span class="n">ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span><span class="o">*</span><span class="n">p</span><span class="o">/</span><span class="n">TILE</span> <span class="o">+</span> <span class="n">j</span><span class="p">)</span><span class="o">*</span><span class="n">TILE</span><span class="o">*</span><span class="n">TILE</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">i</span><span class="o">*</span><span class="n">p</span><span class="o">/</span><span class="n">TILE</span> <span class="o">+</span> <span class="n">j</span><span class="p">)</span><span class="o">*</span><span class="n">TILE</span><span class="o">*</span><span class="n">TILE</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
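<p>The offsets passed to <code>AlignedDot</code> above only make sense if each matrix is stored tile by tile. The following is a minimal host-side sketch of that layout, assuming the <code>[rows/TILE][cols/TILE][TILE][TILE]</code> ordering; the names <code>kTile</code>, <code>TileOffset</code> and <code>ToTiled</code> are hypothetical and exist only for illustration.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical tile edge for illustration; the real TILE is defined by the backend.
constexpr uint32_t kTile = 4;

// Offset of the first element of tile (bi, bj) in a matrix with `cols` columns,
// assuming the [rows/kTile][cols/kTile][kTile][kTile] layout.
inline size_t TileOffset(uint32_t bi, uint32_t bj, uint32_t cols) {
  return (static_cast<size_t>(bi) * (cols / kTile) + bj) * kTile * kTile;
}

// Repack a row-major rows x cols matrix into the tiled layout.
// Both dimensions are assumed to be multiples of kTile.
std::vector<float> ToTiled(const std::vector<float>& rowMajor, uint32_t rows,
                           uint32_t cols) {
  std::vector<float> tiled(rowMajor.size());
  for (uint32_t r = 0; r < rows; r++) {
    for (uint32_t c = 0; c < cols; c++) {
      // Start of the tile, then the position inside the tile.
      size_t dst = TileOffset(r / kTile, c / kTile, cols) +
                   static_cast<size_t>(r % kTile) * kTile + (c % kTile);
      tiled[dst] = rowMajor[static_cast<size_t>(r) * cols + c];
    }
  }
  return tiled;
}
```

<p>With a 4 x 8 matrix and <code>kTile = 4</code>, the element at row 1, column 5 falls in tile (0, 1), which starts at offset 16, so it is stored at index 16 + 1*4 + 1 = 21. <code>TileOffset</code> is the same expression as <code>(i*n/TILE + k)*TILE*TILE</code> in <code>MatmulTiled</code>.</p>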
</div><h2 id="part-6-gpu-backend---compact-and-setitem">Part 6: GPU Backend - Compact and setitem</h2>
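<p>From this Part on, we write CUDA code. If this is your first contact with CUDA programming, this tutorial of under 5 hours is a quick way in: <a href="https://www.bilibili.com/video/BV1sM4y1x7of/?vd_source=1310bba71aaa59915676f56cad6e29d8">CUDA编程基础入门系列（持续更新）_哔哩哔哩_bilibili</a> (an introductory CUDA programming series on bilibili).</p>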
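<p>In this Part we implement the <code>compact</code> and <code>setitem</code> operators. Drawing on the CPU version written earlier, we start with a helper function that converts a logical index into a physical index:</p>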
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">__device__</span> <span class="n">size_t</span> <span class="nf">indexToMemLocation</span><span class="p">(</span><span class="n">size_t</span> <span class="n">index</span><span class="p">,</span> <span class="n">CudaVec</span> <span class="n">shape</span><span class="p">,</span> <span class="n">CudaVec</span> <span class="n">strides</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">offset</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">  <span class="n">size_t</span> <span class="n">ret</span> <span class="o">=</span> <span class="n">offset</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="n">shape</span><span class="p">.</span><span class="n">size</span><span class="o">-</span><span class="mi">1</span><span class="p">;</span> <span class="n">i</span><span class="o">&gt;=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">--</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="n">ret</span> <span class="o">+=</span> <span class="p">(</span><span class="n">index</span> <span class="o">%</span> <span class="n">shape</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="o">*</span> <span class="n">strides</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="n">index</span> <span class="o">/=</span> <span class="n">shape</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="n">ret</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
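<p>To see what the helper computes, it can be mirrored on the host and run over a small view. This is only a sketch: <code>IndexToMemLocationHost</code> and <code>LogicalOrder</code> are hypothetical names, and <code>std::vector</code> stands in for <code>CudaVec</code>.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side mirror of the device helper, using std::vector instead of CudaVec.
size_t IndexToMemLocationHost(size_t index, const std::vector<int32_t>& shape,
                              const std::vector<int32_t>& strides, size_t offset) {
  size_t ret = offset;
  // Peel off the fastest-varying (last) dimension first.
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; i--) {
    size_t dim = static_cast<size_t>(shape[i]);
    ret += (index % dim) * static_cast<size_t>(strides[i]);
    index /= dim;
  }
  return ret;
}

// Physical locations visited when a view is walked in logical order.
std::vector<size_t> LogicalOrder(const std::vector<int32_t>& shape,
                                 const std::vector<int32_t>& strides, size_t offset) {
  size_t size = 1;
  for (int32_t s : shape) size *= static_cast<size_t>(s);
  std::vector<size_t> locations;
  for (size_t gid = 0; gid < size; gid++) {
    locations.push_back(IndexToMemLocationHost(gid, shape, strides, offset));
  }
  return locations;
}
```

<p>For a transposed view with shape (2, 3), strides (1, 2) and offset 0, logical indices 0 to 5 map to physical locations 0, 2, 4, 1, 3, 5. A compact kernel reads those locations in that order and writes them to consecutive positions of <code>out</code>.</p>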
</div><p>According to the documentation, <code>CompactKernel</code> copies the element of <code>a</code> at logical index <code>gid</code> to <code>out[gid]</code>. Remember to check whether <code>gid</code> is out of bounds:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">CompactKernel</span><span class="p">(</span><span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">size</span><span class="p">,</span> <span class="n">CudaVec</span> <span class="n">shape</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                              <span class="n">CudaVec</span> <span class="n">strides</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">offset</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="n">size_t</span> <span class="n">gid</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">gid</span> <span class="o">&gt;=</span> <span class="n">size</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">size_t</span> <span class="n">memLocation</span> <span class="o">=</span> <span class="n">indexToMemLocation</span><span class="p">(</span><span class="n">gid</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">strides</span><span class="p">,</span> <span class="n">offset</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">out</span><span class="p">[</span><span class="n">gid</span><span class="p">]</span> <span class="o">=</span> <span class="n">a</span><span class="p">[</span><span class="n">memLocation</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
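<p>The bounds check is needed because a one-dimensional launch usually rounds the thread count up to a whole number of blocks. Below is a minimal sketch of that arithmetic, assuming the launch helper uses ceiling division with a fixed block size; the value 256 and the names <code>kThreadsPerBlock</code>, <code>NumBlocks</code> and <code>IdleThreads</code> are hypothetical.</p>

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical block size for illustration; the real value is fixed by the backend.
constexpr size_t kThreadsPerBlock = 256;

// Number of blocks needed so that every element gets a thread (ceiling division).
inline size_t NumBlocks(size_t size) {
  return (size + kThreadsPerBlock - 1) / kThreadsPerBlock;
}

// Threads launched beyond the last element; each of these must return early.
inline size_t IdleThreads(size_t size) {
  return NumBlocks(size) * kThreadsPerBlock - size;
}
```

<p>Under these assumptions, 1000 elements need 4 blocks, which launches 1024 threads; the last 24 have <code>gid &gt;= size</code> and would read or write out of range without the early return.</p>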
</div><p>The two setitem operators follow the same pattern and are fairly simple, so the code is given directly:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">EwiseSetitemKernel</span><span class="p">(</span><span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">size</span><span class="p">,</span> <span class="n">CudaVec</span> <span class="n">shape</span><span class="p">,</span> <span class="n">CudaVec</span> <span class="n">strides</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                              <span class="n">size_t</span> <span class="n">offset</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * 
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Args:
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   a: _compact_ array whose items will be written to out
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   out: non-compact array whose items are to be written
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   shape: shapes of each dimension for a and out
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   strides: strides of the *out* array (not a, which has compact strides)
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   offset: offset of the *out* array (not a, which has zero offset, being compact)
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">  <span class="n">size_t</span> <span class="n">gid</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span> <span class="p">(</span><span class="n">gid</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="n">size_t</span> <span class="n">memLocation</span> <span class="o">=</span> <span class="n">indexToMemLocation</span><span class="p">(</span><span class="n">gid</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">strides</span><span class="p">,</span> <span class="n">offset</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="p">[</span><span class="n">memLocation</span><span class="p">]</span> <span class="o">=</span> <span class="n">a</span><span class="p">[</span><span class="n">gid</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  
</span></span><span class="line"><span class="cl">  
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">EwiseSetitem</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;</span> <span class="n">shape</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                  <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;</span> <span class="n">strides</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">offset</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Set items in a (non-compact) array using CUDA.  You will most likely want to implement a
</span></span></span><span class="line"><span class="cl"><span class="cm">   * EwiseSetitemKernel() function, similar to those above, that will do the actual work.
</span></span></span><span class="line"><span class="cl"><span class="cm">   * 
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Args:
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   a: _compact_ array whose items will be written to out
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   out: non-compact array whose items are to be written
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   shape: shapes of each dimension for a and out
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   strides: strides of the *out* array (not a, which has compact strides)
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   offset: offset of the *out* array (not a, which has zero offset, being compact)
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// BEGIN SOLUTION
</span></span></span><span class="line"><span class="cl">  <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">EwiseSetitemKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">a</span><span class="p">.</span><span class="n">size</span><span class="p">,</span> <span class="n">VecToCuda</span><span class="p">(</span><span class="n">shape</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">                                         <span class="n">VecToCuda</span><span class="p">(</span><span class="n">strides</span><span class="p">),</span> <span class="n">offset</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// END SOLUTION
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">ScalarSetitemKernel</span><span class="p">(</span><span class="n">size_t</span> <span class="n">size</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">val</span><span class="p">,</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">CudaVec</span> <span class="n">shape</span><span class="p">,</span> 
</span></span><span class="line"><span class="cl">                                    <span class="n">CudaVec</span> <span class="n">strides</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">offset</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">  <span class="n">size_t</span> <span class="n">gid</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span> <span class="p">(</span><span class="n">gid</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="n">size_t</span> <span class="n">memLocation</span> <span class="o">=</span> <span class="n">indexToMemLocation</span><span class="p">(</span><span class="n">gid</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">strides</span><span class="p">,</span> <span class="n">offset</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="p">[</span><span class="n">memLocation</span><span class="p">]</span> <span class="o">=</span> <span class="n">val</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">ScalarSetitem</span><span class="p">(</span><span class="n">size_t</span> <span class="n">size</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">val</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;</span> <span class="n">shape</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                   <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int32_t</span><span class="o">&gt;</span> <span class="n">strides</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">offset</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Set items in a (non-compact) array
</span></span></span><span class="line"><span class="cl"><span class="cm">   * 
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Args:
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   size: number of elements to write in out array (note that this will not be the same as
</span></span></span><span class="line"><span class="cl"><span class="cm">   *         out.size, because out is a non-compact subset array);  it _will_ be the same as the 
</span></span></span><span class="line"><span class="cl"><span class="cm">   *         product of items in shape, but convenient to just pass it here.
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   val: scalar value to write to
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   out: non-compact array whose items are to be written
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   shape: shapes of each dimension of out
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   strides: strides of the out array
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   offset: offset of the out array
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// BEGIN SOLUTION
</span></span></span><span class="line"><span class="cl">  <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">ScalarSetitemKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">val</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">VecToCuda</span><span class="p">(</span><span class="n">shape</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">                                         <span class="n">VecToCuda</span><span class="p">(</span><span class="n">strides</span><span class="p">),</span> <span class="n">offset</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// END SOLUTION
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1">////////////////////////////////////////////////////////////////////////////////
</span></span></span><span class="line"><span class="cl"><span class="c1">// Elementwise and scalar operations
</span></span></span><span class="line"><span class="cl"><span class="c1">////////////////////////////////////////////////////////////////////////////////
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">EwiseAddKernel</span><span class="p">(</span><span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">b</span><span class="p">,</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">size</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="n">size_t</span> <span class="n">gid</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span> <span class="p">(</span><span class="n">gid</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">)</span> <span class="n">out</span><span class="p">[</span><span class="n">gid</span><span class="p">]</span> <span class="o">=</span> <span class="n">a</span><span class="p">[</span><span class="n">gid</span><span class="p">]</span> <span class="o">+</span> <span class="n">b</span><span class="p">[</span><span class="n">gid</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">EwiseAdd</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">b</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Add together two CUDA arrays
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">  <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">EwiseAddKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">b</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="part-7-cuda-backend---elementwise-and-scalar-operations">Part 7: CUDA Backend - Elementwise and scalar operations</h2>
<p>This part implements a series of fairly simple unary and binary operators. The focus here is on how to keep the code concise.</p>
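<p>In the CPU version, we passed <code>Op</code> in dynamically as a <code>std::function</code> to implement the different operations. <code>std::function</code>, like most of the host standard library, cannot be used inside a CUDA kernel, so here we switch to templates instead.</p>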
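<p>As a rough host-only illustration of the difference between the two dispatch styles (plain C++, no GPU required), the sketch below contrasts a <code>std::function</code> parameter with a template parameter. The names <code>EwiseRuntime</code>, <code>EwiseStatic</code>, and <code>AddOp</code> are invented for this sketch and are not part of the assignment code.</p>

```cpp
#include <cstddef>
#include <functional>
#include <vector>

using scalar_t = float;

// Runtime dispatch: op is type-erased and called indirectly for every element.
void EwiseRuntime(const std::vector<scalar_t>& a, const std::vector<scalar_t>& b,
                  std::vector<scalar_t>* out,
                  const std::function<scalar_t(scalar_t, scalar_t)>& op) {
  for (std::size_t i = 0; i < out->size(); ++i) (*out)[i] = op(a[i], b[i]);
}

// Compile-time dispatch: Op is a template parameter, so each functor type
// gets its own instantiation and the call to op can be inlined.
template <typename Op>
void EwiseStatic(const std::vector<scalar_t>& a, const std::vector<scalar_t>& b,
                 std::vector<scalar_t>* out, Op op) {
  for (std::size_t i = 0; i < out->size(); ++i) (*out)[i] = op(a[i], b[i]);
}

// Hypothetical functor used only in this sketch.
struct AddOp {
  scalar_t operator()(scalar_t x, scalar_t y) const { return x + y; }
};
```

<p>Both functions compute the same result for <code>AddOp{}</code>. The difference is that <code>EwiseStatic</code> fixes the operation at compile time, which is the property the CUDA version depends on.</p>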
<p>We write one templated kernel for elementwise operations and one for scalar operations:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">Op</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="n">EwiseKernel</span><span class="p">(</span><span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">b</span><span class="p">,</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">size</span><span class="p">,</span> <span class="n">Op</span> <span class="n">op</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">size_t</span> <span class="n">gid</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">gid</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">)</span> <span class="n">out</span><span class="p">[</span><span class="n">gid</span><span class="p">]</span> <span class="o">=</span> <span class="n">op</span><span class="p">(</span><span class="n">a</span><span class="p">[</span><span class="n">gid</span><span class="p">],</span> <span class="n">b</span><span class="p">[</span><span class="n">gid</span><span class="p">]);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">Op</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="n">ScalarKernel</span><span class="p">(</span><span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">val</span><span class="p">,</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">size</span><span class="p">,</span> <span class="n">Op</span> <span class="n">op</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">size_t</span> <span class="n">gid</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">gid</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">)</span> <span class="n">out</span><span class="p">[</span><span class="n">gid</span><span class="p">]</span> <span class="o">=</span> <span class="n">op</span><span class="p">(</span><span class="n">a</span><span class="p">[</span><span class="n">gid</span><span class="p">],</span> <span class="n">val</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
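</div><p>Any function called from a CUDA kernel must itself be device code, that is, a <code>__device__</code> function. We therefore wrap each operator in a struct whose overloaded <code>()</code> operator is marked <code>__device__</code>, and use these structs to instantiate the two templated kernels above:</p>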
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">Add</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">__device__</span> <span class="n">scalar_t</span> <span class="nf">operator</span><span class="p">()(</span><span class="n">scalar_t</span> <span class="n">x</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">y</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span> <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">;</span> <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">Mul</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">__device__</span> <span class="n">scalar_t</span> <span class="nf">operator</span><span class="p">()(</span><span class="n">scalar_t</span> <span class="n">x</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">y</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span> <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">y</span><span class="p">;</span> <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">Div</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">__device__</span> <span class="n">scalar_t</span> <span class="nf">operator</span><span class="p">()(</span><span class="n">scalar_t</span> <span class="n">x</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">y</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span> <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="n">y</span><span class="p">;</span> <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">Maximum</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">__device__</span> <span class="n">scalar_t</span> <span class="nf">operator</span><span class="p">()(</span><span class="n">scalar_t</span> <span class="n">x</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">y</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span> <span class="k">return</span> <span class="n">max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">);</span> <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">Eq</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">__device__</span> <span class="n">scalar_t</span> <span class="nf">operator</span><span class="p">()(</span><span class="n">scalar_t</span> <span class="n">x</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">y</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span> <span class="k">return</span> <span class="n">x</span> <span class="o">==</span> <span class="n">y</span><span class="p">;</span> <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">Ge</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">__device__</span> <span class="n">scalar_t</span> <span class="nf">operator</span><span class="p">()(</span><span class="n">scalar_t</span> <span class="n">x</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">y</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span> <span class="k">return</span> <span class="n">x</span> <span class="o">&gt;=</span> <span class="n">y</span><span class="p">;</span> <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">Power</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">scalar_t</span> <span class="n">val</span><span class="p">;</span>  <span class="c1">// exponent stored at construction; the second call argument is ignored</span>
</span></span><span class="line"><span class="cl">    <span class="n">Power</span><span class="p">(</span><span class="n">scalar_t</span> <span class="n">v</span><span class="p">)</span> <span class="o">:</span> <span class="n">val</span><span class="p">(</span><span class="n">v</span><span class="p">)</span> <span class="p">{}</span>
</span></span><span class="line"><span class="cl">    <span class="n">__device__</span> <span class="n">scalar_t</span> <span class="nf">operator</span><span class="p">()(</span><span class="n">scalar_t</span> <span class="n">x</span><span class="p">,</span> <span class="n">scalar_t</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span> <span class="k">return</span> <span class="n">pow</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">val</span><span class="p">);</span> <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Next, define the host-side wrappers so that they can be registered with pybind11:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span><span class="lnt">50
</span><span class="lnt">51
</span><span class="lnt">52
</span><span class="lnt">53
</span><span class="lnt">54
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">EwiseMul</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">b</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">EwiseKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">b</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">,</span> <span class="n">Mul</span><span class="p">());</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">ScalarMul</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">val</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">ScalarKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">val</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">,</span> <span class="n">Mul</span><span class="p">());</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">EwiseDiv</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">b</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">EwiseKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">b</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">,</span> <span class="n">Div</span><span class="p">());</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">ScalarDiv</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">val</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">ScalarKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">val</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">,</span> <span class="n">Div</span><span class="p">());</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">ScalarPower</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">val</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">ScalarKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">val</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">,</span> <span class="n">Power</span><span class="p">(</span><span class="n">val</span><span class="p">));</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">EwiseMaximum</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">b</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">EwiseKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">b</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">,</span> <span class="n">Maximum</span><span class="p">());</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">ScalarMaximum</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">val</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">ScalarKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">val</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">,</span> <span class="n">Maximum</span><span class="p">());</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">EwiseEq</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">b</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">EwiseKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">b</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">,</span> <span class="n">Eq</span><span class="p">());</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">ScalarEq</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">val</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">ScalarKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">val</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">,</span> <span class="n">Eq</span><span class="p">());</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">EwiseGe</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">b</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">EwiseKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">b</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">,</span> <span class="n">Ge</span><span class="p">());</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">ScalarGe</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span> <span class="n">val</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">ScalarKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">val</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">,</span> <span class="n">Ge</span><span class="p">());</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
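</div><p>That covers the binary operators; next come the unary operators. These could also be implemented with a template, just like the binary ones, but Copilot generated the code below directly and I did not bother refactoring it:</p>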
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">EwiseLogKernel</span><span class="p">(</span><span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">size</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">size_t</span> <span class="n">gid</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">gid</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">)</span> <span class="n">out</span><span class="p">[</span><span class="n">gid</span><span class="p">]</span> <span class="o">=</span> <span class="n">log</span><span class="p">(</span><span class="n">a</span><span class="p">[</span><span class="n">gid</span><span class="p">]);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">EwiseLog</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">EwiseLogKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">EwiseExpKernel</span><span class="p">(</span><span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">size</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">size_t</span> <span class="n">gid</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">gid</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">)</span> <span class="n">out</span><span class="p">[</span><span class="n">gid</span><span class="p">]</span> <span class="o">=</span> <span class="n">exp</span><span class="p">(</span><span class="n">a</span><span class="p">[</span><span class="n">gid</span><span class="p">]);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">EwiseExp</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">EwiseExpKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">EwiseTanhKernel</span><span class="p">(</span><span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">size</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">size_t</span> <span class="n">gid</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">gid</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">)</span> <span class="n">out</span><span class="p">[</span><span class="n">gid</span><span class="p">]</span> <span class="o">=</span> <span class="n">tanh</span><span class="p">(</span><span class="n">a</span><span class="p">[</span><span class="n">gid</span><span class="p">]);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">EwiseTanh</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">EwiseTanhKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Finally, uncomment the lines beginning with <code>m.def</code> at the end of this file so that the corresponding interfaces are registered with pybind11.</p>
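<p>For comparison, here is a minimal host-side sketch of the template alternative mentioned for the unary operators. It mirrors the functor-passing pattern the binary operators use. The names <code>EwiseUnary</code>, <code>Log</code>, <code>Exp</code> and <code>Tanh</code> are hypothetical and not part of the assignment code; the loop stands in for the CUDA threads, and in a real <code>.cu</code> file each <code>operator()</code> would also need to be marked <code>__device__</code>.</p>

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using scalar_t = float;  // assumption: same element type as the backend

// Hypothetical functors, one per unary operator.
struct Log  { scalar_t operator()(scalar_t x) const { return std::log(x); } };
struct Exp  { scalar_t operator()(scalar_t x) const { return std::exp(x); } };
struct Tanh { scalar_t operator()(scalar_t x) const { return std::tanh(x); } };

// Host-side stand-in for a single templated kernel: each loop iteration
// plays the role of one CUDA thread with global index gid.
template <typename Op>
void EwiseUnary(const std::vector<scalar_t>& a, std::vector<scalar_t>* out, Op op) {
    assert(out->size() == a.size());
    for (std::size_t gid = 0; gid < out->size(); ++gid) {
        (*out)[gid] = op(a[gid]);
    }
}
```

<p>With this shape, adding another unary operator only requires a new functor rather than a new kernel and launcher pair.</p>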
<h2 id="part-8-cuda-backend---reductions">Part 8: CUDA Backend - Reductions</h2>
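<p>This part implements two reduction operators, <code>sum</code> and <code>max</code>.</p>
<p>As in the CPU version, the elements to be reduced are laid out contiguously in memory. In CUDA, each thread handles one reduction task and covers the half-open range <code>[gid*size, min(gid*size+size, a_size))</code>, where <code>size</code> is the number of elements a single thread reduces and <code>a_size</code> is the length of the input.</p>
<p>Inside the kernel, simply compute the sum or the maximum, depending on the reduction operator:</p>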
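<p>The per-thread indexing can be checked on the host with a small reference implementation. This is only a sketch: <code>ReduceSumReference</code> and <code>ReduceMaxReference</code> are hypothetical names, the outer loop stands in for the CUDA threads, and the end index is treated as exclusive.</p>

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using scalar_t = float;  // assumption: same element type as the backend

// Iteration gid mimics the thread with that global index: it reduces the
// elements in [gid * reduce_size, min(gid * reduce_size + reduce_size, a.size())).
std::vector<scalar_t> ReduceSumReference(const std::vector<scalar_t>& a, std::size_t reduce_size) {
    assert(reduce_size > 0 && a.size() % reduce_size == 0);
    std::vector<scalar_t> out(a.size() / reduce_size);
    for (std::size_t gid = 0; gid < out.size(); ++gid) {
        std::size_t start = gid * reduce_size;
        std::size_t end = std::min(start + reduce_size, a.size());  // exclusive
        scalar_t acc = 0;
        for (std::size_t i = start; i < end; ++i) acc += a[i];
        out[gid] = acc;
    }
    return out;
}

std::vector<scalar_t> ReduceMaxReference(const std::vector<scalar_t>& a, std::size_t reduce_size) {
    assert(reduce_size > 0 && a.size() % reduce_size == 0);
    std::vector<scalar_t> out(a.size() / reduce_size);
    for (std::size_t gid = 0; gid < out.size(); ++gid) {
        std::size_t start = gid * reduce_size;
        std::size_t end = std::min(start + reduce_size, a.size());  // exclusive
        scalar_t best = a[start];  // seed with the first element of the run
        for (std::size_t i = start + 1; i < end; ++i) best = std::max(best, a[i]);
        out[gid] = best;
    }
    return out;
}
```

<p>Comparing the GPU output against such a reference is a quick way to catch off-by-one mistakes in the range.</p>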
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span><span class="lnt">50
</span><span class="lnt">51
</span><span class="lnt">52
</span><span class="lnt">53
</span><span class="lnt">54
</span><span class="lnt">55
</span><span class="lnt">56
</span><span class="lnt">57
</span><span class="lnt">58
</span><span class="lnt">59
</span><span class="lnt">60
</span><span class="lnt">61
</span><span class="lnt">62
</span><span class="lnt">63
</span><span class="lnt">64
</span><span class="lnt">65
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">ReduceMaxKernel</span><span class="p">(</span><span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">size</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">a_size</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Reduce each run of `size` contiguous elements of a
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">  <span class="n">size_t</span> <span class="n">gid</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">size_t</span> <span class="n">start</span> <span class="o">=</span> <span class="n">gid</span> <span class="o">*</span> <span class="n">size</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">size_t</span> <span class="n">end</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">start</span> <span class="o">+</span> <span class="n">size</span><span class="p">,</span> <span class="n">a_size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">start</span> <span class="o">&lt;</span> <span class="n">end</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="n">scalar_t</span> <span class="n">max_val</span> <span class="o">=</span> <span class="n">a</span><span class="p">[</span><span class="n">start</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">i</span><span class="o">=</span><span class="n">start</span><span class="o">+</span><span class="mi">1</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">end</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">      <span class="n">max_val</span> <span class="o">=</span> <span class="n">max</span><span class="p">(</span><span class="n">max_val</span><span class="p">,</span> <span class="n">a</span><span class="p">[</span><span class="n">i</span><span class="p">]);</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="p">[</span><span class="n">gid</span><span class="p">]</span> <span class="o">=</span> <span class="n">max_val</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">ReduceMax</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">reduce_size</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Reduce by taking maximum over `reduce_size` contiguous blocks.  Even though it is inefficient,
</span></span></span><span class="line"><span class="cl"><span class="cm">   * for simplicity you can perform each reduction in a single CUDA thread.
</span></span></span><span class="line"><span class="cl"><span class="cm">   * 
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Args:
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   a: compact array of size a.size = out.size * reduce_size to reduce over
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   out: compact array to write into
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   reduce_size: size of the dimension to reduce over
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// BEGIN SOLUTION
</span></span></span><span class="line"><span class="cl">  <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">ReduceMaxKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">reduce_size</span><span class="p">,</span> <span class="n">a</span><span class="p">.</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// END SOLUTION
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">ReduceSumKernel</span><span class="p">(</span><span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">a</span><span class="p">,</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">size</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">a_size</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Reduce each run of `size` contiguous elements of a
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">  <span class="n">size_t</span> <span class="n">gid</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">size_t</span> <span class="n">start</span> <span class="o">=</span> <span class="n">gid</span> <span class="o">*</span> <span class="n">size</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">size_t</span> <span class="n">end</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">start</span> <span class="o">+</span> <span class="n">size</span><span class="p">,</span> <span class="n">a_size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">start</span> <span class="o">&gt;=</span> <span class="n">end</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="n">out</span><span class="p">[</span><span class="n">gid</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="c1">// Only threads that actually have work may initialize out[gid]; otherwise the write goes out of bounds
</span></span></span><span class="line"><span class="cl">  <span class="k">for</span><span class="p">(</span><span class="n">size_t</span> <span class="n">i</span><span class="o">=</span><span class="n">start</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">end</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span><span class="p">[</span><span class="n">gid</span><span class="p">]</span> <span class="o">+=</span> <span class="n">a</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">ReduceSum</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">reduce_size</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Reduce by taking summation over `reduce_size` contiguous blocks.  Again, for simplicity you 
</span></span></span><span class="line"><span class="cl"><span class="cm">   * can perform each reduction in a single CUDA thread.
</span></span></span><span class="line"><span class="cl"><span class="cm">   * 
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Args:
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   a: compact array of size a.size = out.size * reduce_size to reduce over
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   out: compact array to write into
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   reduce_size: size of the dimension to reduce over
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// BEGIN SOLUTION
</span></span></span><span class="line"><span class="cl">  <span class="n">CudaDims</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">CudaOneDim</span><span class="p">(</span><span class="n">out</span><span class="o">-&gt;</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">ReduceSumKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">dim</span><span class="p">.</span><span class="n">grid</span><span class="p">,</span> <span class="n">dim</span><span class="p">.</span><span class="n">block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">reduce_size</span><span class="p">,</span> <span class="n">a</span><span class="p">.</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// END SOLUTION
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="part-9-cuda-backend---matrix-multiplication">Part 9: CUDA Backend - Matrix multiplication</h2>
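<p>This is the last task and also the hardest. As the assignment notes, a basic matrix multiplication operator is straightforward: let each thread compute one element of the result. Using cooperative fetching and block shared-memory register tiling, and in particular following the pseudocode from the lecture, is considerably harder.</p>
<p>Here is the pseudocode from the lecture:</p>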
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">mm</span><span class="p">(</span><span class="kt">float</span> <span class="n">A</span><span class="p">[</span><span class="n">N</span><span class="p">][</span><span class="n">N</span><span class="p">],</span> <span class="kt">float</span> <span class="n">B</span><span class="p">[</span><span class="n">N</span><span class="p">][</span><span class="n">N</span><span class="p">],</span> <span class="kt">float</span> <span class="n">C</span><span class="p">[</span><span class="n">N</span><span class="p">][</span><span class="n">N</span><span class="p">])</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">sA</span><span class="p">[</span><span class="n">S</span><span class="p">][</span><span class="n">L</span><span class="p">],</span> <span class="n">sB</span><span class="p">[</span><span class="n">S</span><span class="p">][</span><span class="n">L</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="n">c</span><span class="p">[</span><span class="n">V</span><span class="p">][</span><span class="n">V</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="mi">0</span><span class="p">};</span>
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="n">a</span><span class="p">[</span><span class="n">V</span><span class="p">],</span> <span class="n">b</span><span class="p">[</span><span class="n">V</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">yblock</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">xblock</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">ko</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">ko</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="n">ko</span> <span class="o">+=</span> <span class="n">S</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="c1">// needs to be implemented by thread cooperative fetching
</span></span></span><span class="line"><span class="cl">        <span class="n">sA</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">ko</span> <span class="o">:</span> <span class="n">ko</span> <span class="o">+</span> <span class="n">S</span><span class="p">,</span> <span class="n">yblock</span> <span class="o">*</span> <span class="nl">L</span> <span class="p">:</span> <span class="n">yblock</span> <span class="o">*</span> <span class="n">L</span> <span class="o">+</span> <span class="n">L</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="n">sB</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span><span class="p">[</span><span class="n">ko</span> <span class="o">:</span> <span class="n">ko</span> <span class="o">+</span> <span class="n">S</span><span class="p">,</span> <span class="n">xblock</span> <span class="o">*</span> <span class="nl">L</span> <span class="p">:</span> <span class="n">xblock</span> <span class="o">*</span> <span class="n">L</span> <span class="o">+</span> <span class="n">L</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="n">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">ki</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">ki</span> <span class="o">&lt;</span> <span class="n">S</span><span class="p">;</span> <span class="o">++</span><span class="n">ki</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="n">a</span><span class="p">[</span><span class="o">:</span><span class="p">]</span> <span class="o">=</span> <span class="n">sA</span><span class="p">[</span><span class="n">ki</span><span class="p">,</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">V</span> <span class="o">:</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">V</span> <span class="o">+</span> <span class="n">V</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">            <span class="n">b</span><span class="p">[</span><span class="o">:</span><span class="p">]</span> <span class="o">=</span> <span class="n">sB</span><span class="p">[</span><span class="n">ki</span><span class="p">,</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">V</span> <span class="o">:</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">V</span> <span class="o">+</span> <span class="n">V</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">y</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">y</span> <span class="o">&lt;</span> <span class="n">V</span><span class="p">;</span> <span class="o">++</span><span class="n">y</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">x</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">x</span> <span class="o">&lt;</span> <span class="n">V</span><span class="p">;</span> <span class="o">++</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                    <span class="n">c</span><span class="p">[</span><span class="n">y</span><span class="p">][</span><span class="n">x</span><span class="p">]</span> <span class="o">+=</span> <span class="n">a</span><span class="p">[</span><span class="n">y</span><span class="p">]</span> <span class="o">*</span> <span class="n">b</span><span class="p">[</span><span class="n">x</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">                <span class="p">}</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">ybase</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">y</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">xbase</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">C</span><span class="p">[</span><span class="n">ybase</span> <span class="o">*</span> <span class="nl">V</span> <span class="p">:</span> <span class="n">ybase</span> <span class="o">*</span> <span class="n">V</span> <span class="o">+</span> <span class="n">V</span><span class="p">,</span> <span class="n">xbase</span> <span class="o">*</span> <span class="nl">V</span> <span class="p">:</span> <span class="n">xbase</span> <span class="o">*</span> <span class="n">V</span> <span class="o">+</span> <span class="n">V</span><span class="p">]</span> <span class="o">=</span> <span class="n">c</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="o">:</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202408251127789.webp?x-oss-process=image/quality,q_90/format,webp"><br>
As the figure shows, we are multiplying two N×N square matrices. The result matrix C is partitioned into (L,L) submatrices, and each block computes one of them.</p>
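<p>To compute its submatrix, the block with index <code>block_x, block_y</code> needs the data <code>A'=A[L*block_x:L*block_x+L,:]</code> and <code>B'=B[:,L*block_y:L*block_y+L]</code>. A&rsquo; and B&rsquo; can be large, so each is split again along the other dimension into N/S chunks of length S, giving chunks of shape (L,S) and (S,L). The product of one pair of chunks has shape (L,L), and summing the N/S products yields the submatrix this block is responsible for.</p>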
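<p>The chunked accumulation described above can be sketched on the CPU. This is a minimal reference under stated assumptions, not the kernel from the assignment: the function name <code>tile_of_product</code>, its parameters, and the row-major <code>std::vector</code> layout are hypothetical and chosen only for illustration.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp">#include &lt;algorithm&gt;
#include &lt;cstddef&gt;
#include &lt;vector&gt;

// Hypothetical CPU reference (not the assignment's kernel).
// A and B are row-major n x n matrices. (tile_row, tile_col) plays the role
// of the block index and selects one (L, L) tile of C = A * B.
std::vector&lt;float&gt; tile_of_product(const std::vector&lt;float&gt;&amp; A,
                                   const std::vector&lt;float&gt;&amp; B,
                                   std::size_t n, std::size_t L, std::size_t S,
                                   std::size_t tile_row, std::size_t tile_col) {
  std::vector&lt;float&gt; tile(L * L, 0.0f);
  for (std::size_t k0 = 0; k0 &lt; n; k0 += S) {    // one (L,S) x (S,L) chunk pair
    const std::size_t k1 = std::min(k0 + S, n);   // the last chunk may be shorter
    for (std::size_t i = 0; i &lt; L; ++i) {
      const std::size_t row = tile_row * L + i;
      if (row &gt;= n) break;                        // tile extends past the matrix edge
      for (std::size_t j = 0; j &lt; L; ++j) {
        const std::size_t col = tile_col * L + j;
        if (col &gt;= n) break;
        float partial = 0.0f;
        for (std::size_t k = k0; k &lt; k1; ++k) {
          partial += A[row * n + k] * B[k * n + col];
        }
        tile[i * L + j] += partial;               // accumulate across the chunks
      }
    }
  }
  return tile;
}
</code></pre></div>
<p>Each iteration of the outer loop corresponds to one pair of chunks; on the GPU, that pair is what a block would stage in shared memory before computing.</p>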
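<p>From here on, a matrix is referred to by its shape.</p>
<p>When computing a single (L,S) by (S,L) product, each block loads its data, the shaded regions of A and B in the figure, into the shared memory visible to all threads in the block.</p>
<p>The (L,S) by (S,L) product is computed with outer products. In short, take the k-th column of (L,S) and the k-th row of (S,L) and form their outer product; summing these outer products over all S values of k gives the matrix product.</p>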
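<p>The outer-product formulation can be illustrated with a short CPU sketch. The helper name <code>matmul_by_outer_products</code> and the row-major layout are assumptions made for this example only.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp">#include &lt;cstddef&gt;
#include &lt;vector&gt;

// Hypothetical helper: lhs is (L, S) row-major, rhs is (S, L) row-major.
// For each k, column k of lhs and row k of rhs form a rank-1 (L, L) update;
// summing the S updates gives lhs * rhs.
std::vector&lt;float&gt; matmul_by_outer_products(const std::vector&lt;float&gt;&amp; lhs,
                                            const std::vector&lt;float&gt;&amp; rhs,
                                            std::size_t L, std::size_t S) {
  std::vector&lt;float&gt; out(L * L, 0.0f);
  for (std::size_t k = 0; k &lt; S; ++k) {          // one outer product per k
    for (std::size_t i = 0; i &lt; L; ++i) {
      for (std::size_t j = 0; j &lt; L; ++j) {
        out[i * L + j] += lhs[i * S + k] * rhs[k * L + j];
      }
    }
  }
  return out;
}
</code></pre></div>
<p>Compared with the usual row-times-column loop order, only the loop nesting changes: the reduction index k is outermost, so each step touches one column and one row of the operands.</p>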
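<p>Each outer product is computed jointly by the threads in the block; as the figure shows, each thread computes a smaller (V,V) matrix. Concretely, a column taken from (L,S) has shape (L,1) and a row taken from (S,L) has shape (1,L). Both are split again into pieces of length V, giving matrices of shape (V,1) and (1,V), and one thread computes the outer product of one such pair to obtain a (V,V) result.</p>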
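<p>The division of one outer product among threads can be emulated on the CPU as follows. This is only a sketch: the function <code>outer_product_by_threads</code> and the names <code>t_row</code> and <code>t_col</code> (standing in for a thread's position in the block) are hypothetical.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp">#include &lt;cstddef&gt;
#include &lt;vector&gt;

// Hypothetical CPU emulation of one block's threads (names are illustrative).
// col has length L (one column of the (L,S) chunk) and row has length L (one
// row of the (S,L) chunk). Thread (t_row, t_col) adds its (V, V) patch of the
// outer product into acc, an (L, L) row-major accumulator.
void outer_product_by_threads(const std::vector&lt;float&gt;&amp; col,
                              const std::vector&lt;float&gt;&amp; row,
                              std::size_t L, std::size_t V,
                              std::vector&lt;float&gt;&amp; acc) {
  const std::size_t threads_per_side = (L + V - 1) / V;
  for (std::size_t t_row = 0; t_row &lt; threads_per_side; ++t_row) {
    for (std::size_t t_col = 0; t_col &lt; threads_per_side; ++t_col) {
      // Work of a single thread: a (V,1) piece times a (1,V) piece.
      for (std::size_t y = 0; y &lt; V; ++y) {
        for (std::size_t x = 0; x &lt; V; ++x) {
          const std::size_t i = t_row * V + y;
          const std::size_t j = t_col * V + x;
          if (i &lt; L &amp;&amp; j &lt; L) acc[i * L + j] += col[i] * row[j];
        }
      }
    }
  }
}
</code></pre></div>
<p>On the GPU the two outer loops disappear, because every (t_row, t_col) pair runs as its own thread and keeps its (V,V) patch in registers.</p>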
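<p>That is the algorithm behind the lecture pseudocode. Turning it into CUDA code requires handling several cases, with the following points to watch:</p>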
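<ul>
<li>Wherever the theory assumes a tiling, the dimensions may not divide evenly in practice and can leave a remainder, so accesses need a bounds check;</li>
<li>The result submatrix a block computes is determined by the block's position in the grid, and the part of the outer product a thread computes is determined by the thread's position in the block;</li>
<li>In the code, S and L from the theory both take the value of the macro constant <code>TILE 4</code>, and V takes the value of the macro constant <code>V 2</code>.</li>
</ul>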
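<p>The remainder handling in the first point usually comes down to two small calculations, sketched below. Both helpers, <code>ceil_div</code> and <code>valid_extent</code>, are hypothetical names introduced for illustration; the launch configuration is not shown in this section, so using <code>ceil_div</code> to size the grid is an assumption.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp">#include &lt;cstdint&gt;

// Hypothetical helpers (not part of the assignment code).

// Number of tiles needed to cover n elements when n may not divide evenly.
inline std::uint32_t ceil_div(std::uint32_t n, std::uint32_t tile) {
  return (n + tile - 1) / tile;
}

// Number of in-bounds elements in tile number tile_idx along a dimension of
// length n; returns 0 if the tile starts past the end.
inline std::uint32_t valid_extent(std::uint32_t n, std::uint32_t tile,
                                  std::uint32_t tile_idx) {
  const std::uint32_t start = tile_idx * tile;
  if (start &gt;= n) return 0;
  return (n - start &lt; tile) ? (n - start) : tile;
}
</code></pre></div>
<p>For example, with a tile size of 4 and a dimension of length 10, three tiles are needed and the last one has only two valid elements. The kernel below applies the same idea when it computes <code>min(TILE, N-start)</code>.</p>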
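<p>The code is commented in detail. This part is fairly involved and hard to explain in prose alone, so feel free to leave a comment if you have questions.</p>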
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">  1
</span><span class="lnt">  2
</span><span class="lnt">  3
</span><span class="lnt">  4
</span><span class="lnt">  5
</span><span class="lnt">  6
</span><span class="lnt">  7
</span><span class="lnt">  8
</span><span class="lnt">  9
</span><span class="lnt"> 10
</span><span class="lnt"> 11
</span><span class="lnt"> 12
</span><span class="lnt"> 13
</span><span class="lnt"> 14
</span><span class="lnt"> 15
</span><span class="lnt"> 16
</span><span class="lnt"> 17
</span><span class="lnt"> 18
</span><span class="lnt"> 19
</span><span class="lnt"> 20
</span><span class="lnt"> 21
</span><span class="lnt"> 22
</span><span class="lnt"> 23
</span><span class="lnt"> 24
</span><span class="lnt"> 25
</span><span class="lnt"> 26
</span><span class="lnt"> 27
</span><span class="lnt"> 28
</span><span class="lnt"> 29
</span><span class="lnt"> 30
</span><span class="lnt"> 31
</span><span class="lnt"> 32
</span><span class="lnt"> 33
</span><span class="lnt"> 34
</span><span class="lnt"> 35
</span><span class="lnt"> 36
</span><span class="lnt"> 37
</span><span class="lnt"> 38
</span><span class="lnt"> 39
</span><span class="lnt"> 40
</span><span class="lnt"> 41
</span><span class="lnt"> 42
</span><span class="lnt"> 43
</span><span class="lnt"> 44
</span><span class="lnt"> 45
</span><span class="lnt"> 46
</span><span class="lnt"> 47
</span><span class="lnt"> 48
</span><span class="lnt"> 49
</span><span class="lnt"> 50
</span><span class="lnt"> 51
</span><span class="lnt"> 52
</span><span class="lnt"> 53
</span><span class="lnt"> 54
</span><span class="lnt"> 55
</span><span class="lnt"> 56
</span><span class="lnt"> 57
</span><span class="lnt"> 58
</span><span class="lnt"> 59
</span><span class="lnt"> 60
</span><span class="lnt"> 61
</span><span class="lnt"> 62
</span><span class="lnt"> 63
</span><span class="lnt"> 64
</span><span class="lnt"> 65
</span><span class="lnt"> 66
</span><span class="lnt"> 67
</span><span class="lnt"> 68
</span><span class="lnt"> 69
</span><span class="lnt"> 70
</span><span class="lnt"> 71
</span><span class="lnt"> 72
</span><span class="lnt"> 73
</span><span class="lnt"> 74
</span><span class="lnt"> 75
</span><span class="lnt"> 76
</span><span class="lnt"> 77
</span><span class="lnt"> 78
</span><span class="lnt"> 79
</span><span class="lnt"> 80
</span><span class="lnt"> 81
</span><span class="lnt"> 82
</span><span class="lnt"> 83
</span><span class="lnt"> 84
</span><span class="lnt"> 85
</span><span class="lnt"> 86
</span><span class="lnt"> 87
</span><span class="lnt"> 88
</span><span class="lnt"> 89
</span><span class="lnt"> 90
</span><span class="lnt"> 91
</span><span class="lnt"> 92
</span><span class="lnt"> 93
</span><span class="lnt"> 94
</span><span class="lnt"> 95
</span><span class="lnt"> 96
</span><span class="lnt"> 97
</span><span class="lnt"> 98
</span><span class="lnt"> 99
</span><span class="lnt">100
</span><span class="lnt">101
</span><span class="lnt">102
</span><span class="lnt">103
</span><span class="lnt">104
</span><span class="lnt">105
</span><span class="lnt">106
</span><span class="lnt">107
</span><span class="lnt">108
</span><span class="lnt">109
</span><span class="lnt">110
</span><span class="lnt">111
</span><span class="lnt">112
</span><span class="lnt">113
</span><span class="lnt">114
</span><span class="lnt">115
</span><span class="lnt">116
</span><span class="lnt">117
</span><span class="lnt">118
</span><span class="lnt">119
</span><span class="lnt">120
</span><span class="lnt">121
</span><span class="lnt">122
</span><span class="lnt">123
</span><span class="lnt">124
</span><span class="lnt">125
</span><span class="lnt">126
</span><span class="lnt">127
</span><span class="lnt">128
</span><span class="lnt">129
</span><span class="lnt">130
</span><span class="lnt">131
</span><span class="lnt">132
</span><span class="lnt">133
</span><span class="lnt">134
</span><span class="lnt">135
</span><span class="lnt">136
</span><span class="lnt">137
</span><span class="lnt">138
</span><span class="lnt">139
</span><span class="lnt">140
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">MatmulKernel</span><span class="p">(</span><span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">b</span><span class="p">,</span> <span class="n">scalar_t</span><span class="o">*</span> <span class="n">c</span><span class="p">,</span> <span class="kt">uint32_t</span> <span class="n">M</span><span class="p">,</span> <span class="kt">uint32_t</span> <span class="n">N</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="kt">uint32_t</span> <span class="n">P</span><span class="p">){</span>
</span></span><span class="line"><span class="cl"><span class="cp">#define V 2
</span></span></span><span class="line"><span class="cl"><span class="cp">#define TILE 4
</span></span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Tiled matrix multiplication, with tiles of size TILE
</span></span></span><span class="line"><span class="cl"><span class="cm">   * a: M x N
</span></span></span><span class="line"><span class="cl"><span class="cm">   * b: N x P
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">  <span class="kt">int</span> <span class="n">block_x</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="kt">int</span> <span class="n">block_y</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="kt">int</span> <span class="n">thread_x</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="kt">int</span> <span class="n">thread_y</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="kt">int</span> <span class="n">thread_id</span> <span class="o">=</span> <span class="n">thread_x</span> <span class="o">+</span> <span class="n">thread_y</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="kt">int</span> <span class="n">nthreads</span> <span class="o">=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Each block computes one submatrix of the result, namely c[block_x*TILE: (block_x+1)*TILE, block_y*TILE: (block_y+1)*TILE]
</span></span></span><span class="line"><span class="cl">  <span class="c1">// The submatrix is the accumulated sum of &#34;outer products&#34;; both operands of each product are stripes of the row/column blocks
</span></span></span><span class="line"><span class="cl">  <span class="c1">// e.g. splitting a by rows gives blocks of shape (TILE, N); one stripe of such a block has shape (TILE, TILE)
</span></span></span><span class="line"><span class="cl">  <span class="c1">// The outer product advances in steps of TILE rather than 1
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="n">__shared__</span> <span class="n">scalar_t</span> <span class="n">a_shared</span><span class="p">[</span><span class="n">TILE</span><span class="p">][</span><span class="n">TILE</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">  <span class="n">__shared__</span> <span class="n">scalar_t</span> <span class="n">b_shared</span><span class="p">[</span><span class="n">TILE</span><span class="p">][</span><span class="n">TILE</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">  <span class="n">scalar_t</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">V</span><span class="p">][</span><span class="n">V</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="mi">0</span><span class="p">};</span>
</span></span><span class="line"><span class="cl">  <span class="n">scalar_t</span> <span class="n">a_reg</span><span class="p">[</span><span class="n">V</span><span class="p">]</span><span class="o">=</span><span class="p">{</span><span class="mi">0</span><span class="p">},</span> <span class="n">b_reg</span><span class="p">[</span><span class="n">V</span><span class="p">]</span><span class="o">=</span><span class="p">{</span><span class="mi">0</span><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">start</span><span class="o">&lt;</span><span class="n">N</span><span class="p">;</span> <span class="n">start</span><span class="o">+=</span><span class="n">TILE</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="n">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// TILE * TILE elements must be loaded; each thread handles (TILE * TILE+nthreads-1)/nthreads of them on average
</span></span></span><span class="line"><span class="cl">    <span class="c1">// for (int i=0; i&lt;(TILE * TILE+nthreads-1)/nthreads; i++){
</span></span></span><span class="line"><span class="cl">    <span class="c1">//   int idx = thread_id + i * nthreads; // flat index into shared memory
</span></span></span><span class="line"><span class="cl">    <span class="c1">//   int x = idx / TILE; // row index in shared memory
</span></span></span><span class="line"><span class="cl">    <span class="c1">//   int y = idx % TILE; // column index in shared memory
</span></span></span><span class="line"><span class="cl">    <span class="c1">//   // (x, y) in a_shared corresponds to (x+block_x*TILE, y+start) in a
</span></span></span><span class="line"><span class="cl">    <span class="c1">//   // (x, y) in b_shared corresponds to (x+start, y+block_y*TILE) in b
</span></span></span><span class="line"><span class="cl">    <span class="c1">//   if(x+block_x*TILE &lt; M &amp;&amp; y+start &lt; N){
</span></span></span><span class="line"><span class="cl">    <span class="c1">//     a_shared[x][y] = a[(x+block_x*TILE)*N + y+start];
</span></span></span><span class="line"><span class="cl">    <span class="c1">//   }
</span></span></span><span class="line"><span class="cl">    <span class="c1">//   if(x+start &lt; N &amp;&amp; y+block_y*TILE &lt; P){
</span></span></span><span class="line"><span class="cl">    <span class="c1">//     b_shared[x][y] = b[(x+start)*P + y+block_y*TILE];
</span></span></span><span class="line"><span class="cl">    <span class="c1">//   }
</span></span></span><span class="line"><span class="cl">    <span class="c1">// }
</span></span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">idx</span> <span class="o">=</span> <span class="n">thread_id</span><span class="p">;</span> <span class="n">idx</span> <span class="o">&lt;</span> <span class="n">TILE</span> <span class="o">*</span> <span class="n">TILE</span><span class="p">;</span> <span class="n">idx</span> <span class="o">+=</span> <span class="n">nthreads</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">      <span class="kt">int</span> <span class="n">x</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">/</span> <span class="n">TILE</span><span class="p">;</span> <span class="c1">// row index in shared memory
</span></span></span><span class="line"><span class="cl">      <span class="kt">int</span> <span class="n">y</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">%</span> <span class="n">TILE</span><span class="p">;</span> <span class="c1">// column index in shared memory
</span></span></span><span class="line"><span class="cl">      <span class="c1">// (x, y) in a_shared corresponds to (x+block_x*TILE, y+start) in a
</span></span></span><span class="line"><span class="cl">      <span class="c1">// (x, y) in b_shared corresponds to (x+start, y+block_y*TILE) in b
</span></span></span><span class="line"><span class="cl">      <span class="k">if</span><span class="p">(</span><span class="n">x</span><span class="o">+</span><span class="n">block_x</span><span class="o">*</span><span class="n">TILE</span> <span class="o">&lt;</span> <span class="n">M</span> <span class="o">&amp;&amp;</span> <span class="n">y</span><span class="o">+</span><span class="n">start</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">        <span class="n">a_shared</span><span class="p">[</span><span class="n">x</span><span class="p">][</span><span class="n">y</span><span class="p">]</span> <span class="o">=</span> <span class="n">a</span><span class="p">[(</span><span class="n">x</span><span class="o">+</span><span class="n">block_x</span><span class="o">*</span><span class="n">TILE</span><span class="p">)</span><span class="o">*</span><span class="n">N</span> <span class="o">+</span> <span class="n">y</span><span class="o">+</span><span class="n">start</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="k">if</span><span class="p">(</span><span class="n">x</span><span class="o">+</span><span class="n">start</span> <span class="o">&lt;</span> <span class="n">N</span> <span class="o">&amp;&amp;</span> <span class="n">y</span><span class="o">+</span><span class="n">block_y</span><span class="o">*</span><span class="n">TILE</span> <span class="o">&lt;</span> <span class="n">P</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">        <span class="n">b_shared</span><span class="p">[</span><span class="n">x</span><span class="p">][</span><span class="n">y</span><span class="p">]</span> <span class="o">=</span> <span class="n">b</span><span class="p">[(</span><span class="n">x</span><span class="o">+</span><span class="n">start</span><span class="p">)</span><span class="o">*</span><span class="n">P</span> <span class="o">+</span> <span class="n">y</span><span class="o">+</span><span class="n">block_y</span><span class="o">*</span><span class="n">TILE</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="n">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// Next, compute the outer products
</span></span></span><span class="line"><span class="cl">    <span class="c1">// Iterate over the columns of a_shared and the rows of b_shared, i.e. column stripe_i of a_shared and row stripe_i of b_shared
</span></span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">stripe_cnt</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">TILE</span><span class="p">,</span> <span class="n">N</span><span class="o">-</span><span class="n">start</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">stripe_i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">stripe_i</span><span class="o">&lt;</span><span class="n">stripe_cnt</span><span class="p">;</span> <span class="n">stripe_i</span><span class="o">++</span><span class="p">){</span>
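</span></span><span class="line"><span class="cl">    <span class="c1">// The block's nthreads threads share this outer product: stripe_a and stripe_b are split into runs of V consecutive rows/columns, one pair per thread
</span></span></span><span class="line"><span class="cl">    <span class="c1">// Load the data needed for this thread's V*V outer product into the register arrays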
</span></span></span><span class="line"><span class="cl">      <span class="k">if</span><span class="p">(</span><span class="n">thread_x</span> <span class="o">*</span> <span class="n">V</span> <span class="o">&gt;=</span> <span class="n">TILE</span> <span class="o">||</span> <span class="n">thread_y</span> <span class="o">*</span> <span class="n">V</span> <span class="o">&gt;=</span> <span class="n">TILE</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">continue</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">reg_x</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">reg_x</span><span class="o">&lt;</span><span class="n">V</span><span class="p">;</span> <span class="n">reg_x</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">        <span class="kt">int</span> <span class="n">shared_x</span> <span class="o">=</span> <span class="n">reg_x</span> <span class="o">+</span> <span class="n">thread_x</span> <span class="o">*</span> <span class="n">V</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">shared_x</span> <span class="o">&gt;=</span> <span class="n">TILE</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">          <span class="k">break</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="n">a_reg</span><span class="p">[</span><span class="n">reg_x</span><span class="p">]</span> <span class="o">=</span> <span class="n">a_shared</span><span class="p">[</span><span class="n">shared_x</span><span class="p">][</span><span class="n">stripe_i</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="c1">// b_reg[reg_x] = b_shared[stripe_i][shared_x];
</span></span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">reg_y</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">reg_y</span><span class="o">&lt;</span><span class="n">V</span><span class="p">;</span> <span class="n">reg_y</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">        <span class="kt">int</span> <span class="n">shared_y</span> <span class="o">=</span> <span class="n">reg_y</span> <span class="o">+</span> <span class="n">thread_y</span> <span class="o">*</span> <span class="n">V</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">shared_y</span> <span class="o">&gt;=</span> <span class="n">TILE</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">          <span class="n">printf</span><span class="p">(</span><span class="s">&#34;quit: thread id: %d, shared_y: %d, TILE: %d</span><span class="se">\n</span><span class="s">&#34;</span><span class="p">,</span> <span class="n">thread_id</span><span class="p">,</span> <span class="n">shared_y</span><span class="p">,</span> <span class="n">TILE</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">          <span class="k">break</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="c1">// a_reg[reg_y] = a_shared[stripe_i][shared_y];
</span></span></span><span class="line"><span class="cl">        <span class="n">b_reg</span><span class="p">[</span><span class="n">reg_y</span><span class="p">]</span> <span class="o">=</span> <span class="n">b_shared</span><span class="p">[</span><span class="n">stripe_i</span><span class="p">][</span><span class="n">shared_y</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">V</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">j</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">j</span><span class="o">&lt;</span><span class="n">V</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">          <span class="c1">// Out-of-range entries can probably be ignored here; they are handled when c_reg is written to the result
</span></span></span><span class="line"><span class="cl">          <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">a_reg</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">b_reg</span><span class="p">[</span><span class="n">j</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="c1">// Write the results in c_reg into c
</span></span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">thread_x</span> <span class="o">*</span> <span class="n">V</span> <span class="o">&gt;=</span> <span class="n">TILE</span> <span class="o">||</span> <span class="n">thread_y</span> <span class="o">*</span> <span class="n">V</span> <span class="o">&gt;=</span> <span class="n">TILE</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">V</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">j</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">j</span><span class="o">&lt;</span><span class="n">V</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">      <span class="kt">int</span> <span class="n">x</span> <span class="o">=</span> <span class="n">block_x</span> <span class="o">*</span> <span class="n">TILE</span> <span class="o">+</span> <span class="n">thread_x</span> <span class="o">*</span> <span class="n">V</span> <span class="o">+</span> <span class="n">i</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="kt">int</span> <span class="n">y</span> <span class="o">=</span> <span class="n">block_y</span> <span class="o">*</span> <span class="n">TILE</span> <span class="o">+</span> <span class="n">thread_y</span> <span class="o">*</span> <span class="n">V</span> <span class="o">+</span> <span class="n">j</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="k">if</span><span class="p">(</span><span class="n">x</span> <span class="o">&lt;</span> <span class="n">M</span> <span class="o">&amp;&amp;</span> <span class="n">y</span> <span class="o">&lt;</span> <span class="n">P</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">        <span class="n">c</span><span class="p">[</span><span class="n">x</span><span class="o">*</span><span class="n">P</span> <span class="o">+</span> <span class="n">y</span><span class="p">]</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="k">break</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">Matmul</span><span class="p">(</span><span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">CudaArray</span><span class="o">&amp;</span> <span class="n">b</span><span class="p">,</span> <span class="n">CudaArray</span><span class="o">*</span> <span class="n">out</span><span class="p">,</span> <span class="kt">uint32_t</span> <span class="n">M</span><span class="p">,</span> <span class="kt">uint32_t</span> <span class="n">N</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="kt">uint32_t</span> <span class="n">P</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="cm">/**
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Multiply two (compact) matrices into an output (also compact) matrix.  You will want to look
</span></span></span><span class="line"><span class="cl"><span class="cm">   * at the lecture and notes on GPU-based linear algebra to see how to do this.  Since ultimately
</span></span></span><span class="line"><span class="cl"><span class="cm">   * mugrade is just evaluating correctness, you _can_ implement a version that simply parallelizes
</span></span></span><span class="line"><span class="cl"><span class="cm">   * over (i,j) entries in the output array.  However, to really get the full benefit of this
</span></span></span><span class="line"><span class="cl"><span class="cm">   * problem, we would encourage you to use cooperative fetching, shared memory register tiling, 
</span></span></span><span class="line"><span class="cl"><span class="cm">   * and other ideas covered in the class notes.  Note that unlike the tiled matmul function in
</span></span></span><span class="line"><span class="cl"><span class="cm">   * the CPU backend, here you should implement a single function that works across all size
</span></span></span><span class="line"><span class="cl"><span class="cm">   * matrices, whether or not they are a multiple of a tile size.  As with previous CUDA
</span></span></span><span class="line"><span class="cl"><span class="cm">   * implementations, this function here will largely just set up the kernel call, and you should
</span></span></span><span class="line"><span class="cl"><span class="cm">   * implement the logic in a separate MatmulKernel() call.
</span></span></span><span class="line"><span class="cl"><span class="cm">   * 
</span></span></span><span class="line"><span class="cl"><span class="cm">   *
</span></span></span><span class="line"><span class="cl"><span class="cm">   * Args:
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   a: compact 2D array of size m x n
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   b: compact 2D array of size n x p
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   out: compact 2D array of size m x p to write the output to
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   M: rows of a / out
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   N: columns of a / rows of b
</span></span></span><span class="line"><span class="cl"><span class="cm">   *   P: columns of b / out
</span></span></span><span class="line"><span class="cl"><span class="cm">   */</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="c1">/// BEGIN SOLUTION
</span></span></span><span class="line"><span class="cl">  <span class="c1">// The result has shape M x P; each block computes one TILE x TILE sub-matrix
</span></span></span><span class="line"><span class="cl">  <span class="n">dim3</span> <span class="n">grid_dim</span> <span class="o">=</span> <span class="n">dim3</span><span class="p">((</span><span class="n">M</span> <span class="o">+</span> <span class="n">TILE</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">TILE</span><span class="p">,</span> <span class="p">(</span><span class="n">P</span> <span class="o">+</span> <span class="n">TILE</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">TILE</span><span class="p">,</span> <span class="mi">1</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">dim3</span> <span class="n">block_dim</span> <span class="o">=</span> <span class="n">dim3</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">1</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// dim3 block_dim = dim3(2, 2, 1);
</span></span></span><span class="line"><span class="cl">  <span class="n">MatmulKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">grid_dim</span><span class="p">,</span> <span class="n">block_dim</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">b</span><span class="p">.</span><span class="n">ptr</span><span class="p">,</span> <span class="n">out</span><span class="o">-&gt;</span><span class="n">ptr</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">P</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="c1">/// END SOLUTION
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
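</div><h2 id="hw3-小结">hw3 summary</h2>
<p>This homework mainly covers the low-level CPU and GPU implementations of the operators. Since this was my first time writing CUDA code, I spent a lot of time debugging the GPU matrix multiplication, to the point where the headache nearly put me to sleep. Fortunately the effort paid off: the idea came in a flash and the problem suddenly cleared up. Special thanks to <a href="https://www.albresky.cn/">my friend</a> for explaining how the matrix multiplication works and for patiently debugging the code with me late into the night.</p>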
<h1 id="hw4">hw4</h1>
<p>In this assignment we first implement a few operators, then build a CNN and an RNN and train them on datasets.</p>
<h2 id="part-1-nd-backend">Part 1: ND Backend</h2>
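<p>First, copy the methods left unimplemented in <code>src/*</code>, <code>autograd.py</code> and <code>ndarray.py</code> over from the previous homework. Then implement the ops from earlier assignments in <code>ops_*.py</code>; most of them can simply be copied and pasted.</p>
<p>A few pitfalls I ran into <sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup>:</p>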
<ul>
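<li>The imports at the top of <code>autograd.py</code> should look like the following, so that the backend is selected automatically from the <code>NEEDLE_BACKEND</code> environment variable rather than being fixed to the NumPy backend.</li>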
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">needle</span>
</span></span><span class="line"><span class="cl"><span class="c1"># from .backend_numpy import Device, cpu, all_devices</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">NamedTuple</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">namedtuple</span>
</span></span><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">numpy</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">needle</span> <span class="kn">import</span> <span class="n">init</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># needle version</span>
</span></span><span class="line"><span class="cl"><span class="n">LAZY_MODE</span> <span class="o">=</span> <span class="kc">False</span>
</span></span><span class="line"><span class="cl"><span class="n">TENSOR_COUNTER</span> <span class="o">=</span> <span class="mi">0</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">.backend_selection</span> <span class="kn">import</span> <span class="n">array_api</span><span class="p">,</span> <span class="n">NDArray</span><span class="p">,</span> <span class="n">default_device</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">.backend_selection</span> <span class="kn">import</span> <span class="n">Device</span><span class="p">,</span> <span class="n">cpu</span><span class="p">,</span> <span class="n">all_devices</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>The <code>sum</code> and <code>max</code> reduction functions in <code>ndarray.py</code> do not support multiple axes; they need to be modified to accept them.</li>
</ul>
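<p>One way to extend a single-axis reduction to several axes is to apply it repeatedly, feeding each intermediate result into the next step. The sketch below illustrates this in plain Python on a row-major flat list. <code>reduce_axis</code> and <code>reduce_axes</code> are hypothetical helpers written for this note, not needle functions, and the axes are assumed to be non-negative. Axes are processed from the highest index down so that, without keepdims, the remaining axis numbers stay valid.</p>

```python
from itertools import product


def reduce_axis(flat, shape, axis):
    """Sum a row-major flat list over a single axis (hypothetical helper)."""
    out_shape = shape[:axis] + shape[axis + 1:]
    out = []
    for idx in product(*(range(s) for s in out_shape)):
        total = 0
        for k in range(shape[axis]):
            # Re-insert the reduced axis to get the full index.
            full = idx[:axis] + (k,) + idx[axis:]
            offset = 0
            for i, s in zip(full, shape):
                offset = offset * s + i  # row-major offset
            total += flat[offset]
        out.append(total)
    return out, out_shape


def reduce_axes(flat, shape, axes):
    """Reduce over several non-negative axes, one at a time, highest first."""
    for axis in sorted(axes, reverse=True):
        # Each step consumes the previous result, not the original data.
        flat, shape = reduce_axis(flat, shape, axis)
    return flat, shape


values, out_shape = reduce_axes(list(range(24)), (2, 3, 4), (0, 1))
print(out_shape, values)  # (4,) [60, 66, 72, 78]
```

<p>The key point is that every step reduces the output of the previous step; reducing the original array again on each iteration would only keep the effect of the last axis.</p>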
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">sum</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">		<span class="n">view</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduce_view_out</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="n">keepdims</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">		<span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">view</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">out</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">view</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">	<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">)):</span>
</span></span><span class="line"><span class="cl">		<span class="n">out</span> <span class="o">=</span> <span class="bp">self</span>  <span class="c1"># reduce one axis at a time, highest index first (assumes non-negative axes)</span>
</span></span><span class="line"><span class="cl">		<span class="k">for</span> <span class="n">axis_</span> <span class="ow">in</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="n">reverse</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">			<span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axis_</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="n">keepdims</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">		<span class="n">view</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduce_view_out</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="n">keepdims</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">		<span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">view</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">out</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">view</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">	
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">out</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">max</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">		<span class="n">view</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduce_view_out</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="n">keepdims</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">		<span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">reduce_max</span><span class="p">(</span><span class="n">view</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">out</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">view</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">	<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">)):</span>
</span></span><span class="line"><span class="cl">		<span class="n">out</span> <span class="o">=</span> <span class="bp">self</span>  <span class="c1"># reduce one axis at a time, highest index first (assumes non-negative axes)</span>
</span></span><span class="line"><span class="cl">		<span class="k">for</span> <span class="n">axis_</span> <span class="ow">in</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="n">reverse</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">			<span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">axis_</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="n">keepdims</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">		<span class="n">view</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduce_view_out</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="n">keepdims</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">		<span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">reduce_max</span><span class="p">(</span><span class="n">view</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">out</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">view</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">	
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">out</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
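<li>Call compact before reshape.</li>
<li>When creating a Tensor, make sure it is on the same device as the other data.</li>
<li><code>autograd.py</code> contains the line <code>__rsub__ = __sub__</code>, which redirects Tensor's rsub method to sub. Subtraction is not commutative, so this line is wrong. Comment it out and define the rsub function yourself.</li>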
</ul>
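<p>To see why the alias gives the wrong answer, here is a minimal stand-alone sketch. The <code>Num</code> and <code>BrokenNum</code> classes are hypothetical and exist only for this illustration; they are not part of needle. Python calls <code>b.__rsub__(a)</code> for <code>a - b</code> when <code>a</code> does not know how to handle <code>b</code>, so aliasing it to <code>__sub__</code> silently computes <code>b - a</code>.</p>

```python
class Num:
    """Hypothetical scalar wrapper, used only to illustrate reflected subtraction."""

    def __init__(self, v):
        self.v = v

    def __sub__(self, other):
        # self - other
        return Num(self.v - getattr(other, "v", other))

    def __rsub__(self, other):
        # other - self, written as other + (-self)
        return Num(other + (-self.v))


class BrokenNum(Num):
    # The problematic alias: `other - self` is evaluated as `self - other`.
    __rsub__ = Num.__sub__


print((1 - BrokenNum(3)).v)  # 2, the operands were swapped
print((1 - Num(3)).v)        # -2, the expected result
```

<p>The Tensor version below follows the same <code>other + (-self)</code> idea.</p>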
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="fm">__rsub__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">other</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">needle</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">AddScalar</span><span class="p">(</span><span class="n">other</span><span class="p">)(</span><span class="n">needle</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">Negate</span><span class="p">()(</span><span class="bp">self</span><span class="p">))</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Next, we implement the three new ops.</p>
<ul>
<li>tanh<br>
Our backend already provides a tanh interface, so the forward pass simply calls it. The derivative used in the backward pass is:</li>
</ul>


<div>$$

\tanh^\prime(x) = 1-\tanh^2(x)

$$</div>

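<p>In the backward pass, the local derivative is just 1 minus the square of node, since node already holds tanh(x). Note that this runs into the pitfall mentioned above: <code>1 - node ** 2</code> relies on the custom rsub function.</p>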
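<p>The formula can be sanity-checked numerically. The following is a minimal sketch using only Python's <code>math</code> module; <code>tanh_grad</code> and <code>central_diff</code> are hypothetical helper names, not part of needle. It compares the closed form with a central finite difference at a few points.</p>

```python
import math


def tanh_grad(x):
    """Closed-form derivative: 1 - tanh(x)^2."""
    y = math.tanh(x)
    return 1.0 - y * y


def central_diff(f, x, eps=1e-6):
    """Central finite-difference approximation of f'(x)."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


for x in (-2.0, -0.5, 0.0, 0.7, 3.0):
    analytic = tanh_grad(x)
    numeric = central_diff(math.tanh, x)
    # The two values should agree up to finite-difference error.
    assert abs(analytic - numeric) < 1e-6
    print(x, analytic, numeric)
```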
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Tanh</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">node</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>stack<br>
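stack joins several Tensors of the same shape along a new dimension. The forward pass first allocates a Tensor of the target shape, then copies each input into its position by slice assignment. The preallocated Tensor must be placed on the same device as the input Tensors. The backward pass calls the inverse operation, split.</li>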
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Stack</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axis</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">        Concatenates a sequence of arrays along a new dimension.
</span></span></span><span class="line"><span class="cl"><span class="s2">        Parameters:
</span></span></span><span class="line"><span class="cl"><span class="s2">        axis - dimension to concatenate along
</span></span></span><span class="line"><span class="cl"><span class="s2">        All arrays need to be of the same size.
</span></span></span><span class="line"><span class="cl"><span class="s2">        &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">axis</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">args</span><span class="p">:</span> <span class="n">TensorTuple</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">args</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">shape</span> <span class="o">=</span> <span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="n">arg</span> <span class="ow">in</span> <span class="n">args</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">                <span class="k">assert</span> <span class="n">arg</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">shape</span><span class="p">,</span> <span class="s2">&#34;The shape of all tensors should be the same&#34;</span>
</span></span><span class="line"><span class="cl">            <span class="n">ret_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">ret_shape</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">args</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">            <span class="n">ret</span> <span class="o">=</span> <span class="n">array_api</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">ret_shape</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">arg</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">args</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                <span class="n">slices</span> <span class="o">=</span> <span class="p">[</span><span class="nb">slice</span><span class="p">(</span><span class="kc">None</span><span class="p">)]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">ret_shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">                <span class="n">slices</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">]</span> <span class="o">=</span> <span class="n">i</span>
</span></span><span class="line"><span class="cl">                <span class="n">ret</span><span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="n">slices</span><span class="p">)]</span> <span class="o">=</span> <span class="n">arg</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="n">ret</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">split</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
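<li>split<br>
The split method takes the given dimension apart completely. Note that the split dimension is dropped rather than kept as a size-1 axis (no keep dim), so each slice needs a reshape, and compact must be called explicitly before that reshape. The backward pass simply calls stack.</li>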
</ul>
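<p>A NumPy sketch of the same slice, compact, reshape sequence may help. It rests on assumptions: <code>split_all</code> is a hypothetical name, the slice <code>i:i+1</code> is used to mimic an index that keeps the axis with size 1, and <code>np.ascontiguousarray</code> stands in for needle's <code>compact()</code>.</p>

```python
import numpy as np


def split_all(a, axis):
    """Split `a` into one array per index along `axis`, dropping that axis."""
    out_shape = list(a.shape)
    out_shape.pop(axis)
    pieces = []
    for i in range(a.shape[axis]):
        index = [slice(None)] * a.ndim
        index[axis] = slice(i, i + 1)  # the sliced axis survives with size 1
        view = a[tuple(index)]
        # Copy into dense memory first (the role compact() plays), then drop the axis.
        pieces.append(np.ascontiguousarray(view).reshape(out_shape))
    return tuple(pieces)


a = np.arange(24).reshape(2, 3, 4)
pieces = split_all(a, axis=1)
print(len(pieces), pieces[0].shape)  # 3 (2, 4)
```

<p>Stacking the pieces back along the same axis should reproduce the input, which is why stack and split can serve as each other's gradient.</p>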
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Split</span><span class="p">(</span><span class="n">TensorTupleOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axis</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">        Splits a tensor along an axis into a tuple of tensors.
</span></span></span><span class="line"><span class="cl"><span class="s2">        (The &#34;inverse&#34; of Stack)
</span></span></span><span class="line"><span class="cl"><span class="s2">        Parameters:
</span></span></span><span class="line"><span class="cl"><span class="s2">        axis - dimension to split
</span></span></span><span class="line"><span class="cl"><span class="s2">        &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">axis</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">A</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">ret</span> <span class="o">=</span> <span class="p">[]</span>
</span></span><span class="line"><span class="cl">        <span class="n">ret_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">A</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">ret_shape</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">A</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">]):</span>
</span></span><span class="line"><span class="cl">            <span class="n">slices</span> <span class="o">=</span> <span class="p">[</span><span class="nb">slice</span><span class="p">(</span><span class="kc">None</span><span class="p">)]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">A</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">slices</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">]</span> <span class="o">=</span> <span class="n">i</span>
</span></span><span class="line"><span class="cl">            <span class="n">ret</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">A</span><span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="n">slices</span><span class="p">)])</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">ret_shape</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">ret</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">stack</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">split</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">axis</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">Split</span><span class="p">(</span><span class="n">axis</span><span class="p">)(</span><span class="n">a</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="part-2-cifar-10-dataset">Part 2: CIFAR-10 dataset</h2>
<p>This part parses the CIFAR-10 dataset. First copy <code>python/needle/data/data_transforms.py</code> and <code>python/needle/data/data_basic.py</code> from the previous hw, then change the <code>DataLoader::__next__</code> method in <code>data_basic</code> to:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="fm">__next__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">index</span> <span class="o">&gt;=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ordering</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">		<span class="k">raise</span> <span class="ne">StopIteration</span>
</span></span><span class="line"><span class="cl">	<span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">		<span class="n">batch</span> <span class="o">=</span> <span class="p">[</span><span class="n">Tensor</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">ordering</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">index</span><span class="p">]]]</span>
</span></span><span class="line"><span class="cl">		<span class="bp">self</span><span class="o">.</span><span class="n">index</span> <span class="o">+=</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">		<span class="k">return</span> <span class="n">batch</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The previous hw implemented this with <code>Tensor.make_const</code>, but that does not automatically switch the data structure of cached_data to match the current backend.</p>
<p>For the CIFAR-10 data format, see <a href="https://web.archive.org/web/20240827001314/https://www.cs.toronto.edu/~kriz/cifar.html">CIFAR-10 and CIFAR-100 datasets</a>. In short, the data is laid out as <code>batch, channel, height, width</code>. The <code>__init__</code> method reads the dataset with the code already given on that site, then reshapes and normalizes it; the other two methods can be written directly.</p>
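<p>To see why a plain reshape is enough, here is a small sketch on synthetic data (no real CIFAR files are read, and the pixel values are made up for illustration). Each row of the python-version batches holds 3072 uint8 values: 1024 red, then 1024 green, then 1024 blue, each plane in row-major order.</p>

```python
import numpy as np

# Two fake "images" in the CIFAR-10 row layout: 3072 uint8 values per row,
# 1024 red values, then 1024 green, then 1024 blue (each plane is 32x32).
rows = np.zeros((2, 3072), dtype=np.uint8)
rows[0, :1024] = 255   # image 0: red plane fully saturated
rows[1, 2048:] = 51    # image 1: blue plane at 51/255 = 0.2

# Reshape to (batch, channel, height, width) and scale into the 0-1 range.
X = rows.reshape(-1, 3, 32, 32) / 255.0
print(X.shape)  # (2, 3, 32, 32)
```

<p>Because the channel planes are already contiguous within each row, no transpose is needed to reach the <code>batch, channel, height, width</code> layout.</p>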
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">CIFAR10Dataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">base_folder</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">train</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">p</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">transforms</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
</span></span><span class="line"><span class="cl">    <span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">        Parameters:
</span></span></span><span class="line"><span class="cl"><span class="s2">        base_folder - cifar-10-batches-py folder filepath
</span></span></span><span class="line"><span class="cl"><span class="s2">        train - bool, if True load training dataset, else load test dataset
</span></span></span><span class="line"><span class="cl"><span class="s2">        Divide pixel values by 255. so that images are in 0-1 range.
</span></span></span><span class="line"><span class="cl"><span class="s2">        Attributes:
</span></span></span><span class="line"><span class="cl"><span class="s2">        X - numpy array of images
</span></span></span><span class="line"><span class="cl"><span class="s2">        y - numpy array of labels
</span></span></span><span class="line"><span class="cl"><span class="s2">        &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">train_names</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;data_batch_1&#39;</span><span class="p">,</span> <span class="s1">&#39;data_batch_2&#39;</span><span class="p">,</span> <span class="s1">&#39;data_batch_3&#39;</span><span class="p">,</span> <span class="s1">&#39;data_batch_4&#39;</span><span class="p">,</span> <span class="s1">&#39;data_batch_5&#39;</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">test_names</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;test_batch&#39;</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">names</span> <span class="o">=</span> <span class="n">train_names</span> <span class="k">if</span> <span class="n">train</span> <span class="k">else</span> <span class="n">test_names</span>
</span></span><span class="line"><span class="cl">        <span class="n">dicts</span> <span class="o">=</span> <span class="p">[]</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">names</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">base_folder</span><span class="p">,</span> <span class="n">name</span><span class="p">),</span> <span class="s1">&#39;rb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">                <span class="n">dicts</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">pickle</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s1">&#39;bytes&#39;</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">d</span><span class="p">[</span><span class="sa">b</span><span class="s1">&#39;data&#39;</span><span class="p">]</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">dicts</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">X</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">X</span> <span class="o">/</span> <span class="mf">255.0</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">d</span><span class="p">[</span><span class="sa">b</span><span class="s1">&#39;labels&#39;</span><span class="p">]</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">dicts</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">object</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">        Returns the image, label at given index
</span></span></span><span class="line"><span class="cl"><span class="s2">        Image should be of shape (3, 32, 32)
</span></span></span><span class="line"><span class="cl"><span class="s2">        &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">[</span><span class="n">index</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">y</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">        Returns the total number of examples in the dataset
</span></span></span><span class="line"><span class="cl"><span class="s2">        &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="part-3-convolutional-neural-network">Part 3: Convolutional neural network</h2>
<p>This part first implements a few operators, then builds a CNN and trains it on the CIFAR dataset.</p>
<ul>
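<li>pad<br>
The pad operation works as follows: compute the shape of out, create an all-zero Tensor of that shape, then assign the original data into the matching region through slicing.</li>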
</ul>
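<p>The same three steps can be checked against <code>np.pad</code> in a few lines of NumPy. This is a sketch, not the needle implementation: <code>pad_zeros</code> is a hypothetical name, and <code>axes</code> is assumed to hold one <code>(before, after)</code> pair per dimension, as in the code that follows.</p>

```python
import numpy as np


def pad_zeros(a, axes):
    """Zero-pad `a`; axes[i] is a (before, after) pair for dimension i."""
    out_shape = tuple(n + before + after for n, (before, after) in zip(a.shape, axes))
    out = np.zeros(out_shape, dtype=a.dtype)
    # The original data starts `before` elements into each dimension.
    region = tuple(slice(before, before + n) for n, (before, _) in zip(a.shape, axes))
    out[region] = a
    return out


a = np.arange(6).reshape(2, 3)
padded = pad_zeros(a, ((1, 0), (2, 1)))
print(padded.shape)  # (3, 6)
```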
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">pad</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axes</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">	<span class="n">out_shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)))</span>
</span></span><span class="line"><span class="cl">	<span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">full</span><span class="p">(</span><span class="n">out_shape</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="n">slices</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)))</span>
</span></span><span class="line"><span class="cl">	<span class="n">out</span><span class="p">[</span><span class="n">slices</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">out</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
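<li>flip<br>
flip can be implemented with negative strides plus a positive offset. Concretely, negate the stride of every dimension being flipped, and set offset to the sum of stride times (shape - 1) over the flipped dimensions. The reason is that index i along a flipped dimension must read the element at shape - 1 - i, whose position is (shape - 1) * stride + i * (-stride). The code makes this concrete:</li>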
</ul>
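<p>The stride arithmetic can be verified independently of needle by reading elements out of a flat buffer by hand. The sketch below is only an illustration: <code>flip_by_strides</code> is a hypothetical helper, strides are counted in elements rather than bytes, and the result is compared with <code>np.flip</code>.</p>

```python
import itertools

import numpy as np


def flip_by_strides(flat, shape, strides, axes):
    """Read a flipped array out of a flat buffer using negated strides and an offset."""
    new_strides = [-s if d in axes else s for d, s in enumerate(strides)]
    # Start at the last element of every flipped dimension.
    offset = sum((shape[d] - 1) * strides[d] for d in axes)
    out = np.empty(shape, dtype=flat.dtype)
    for idx in itertools.product(*(range(n) for n in shape)):
        position = offset + sum(i * s for i, s in zip(idx, new_strides))
        out[idx] = flat[position]
    return out


flat = np.arange(24)
shape, strides = (2, 3, 4), (12, 4, 1)  # compact row-major layout
flipped = flip_by_strides(flat, shape, strides, axes=(0, 2))
```

<p>For example, index <code>(0, 0, 0)</code> of the flipped view lands on position 12 + 3 = 15 of the buffer, which is element <code>[1, 0, 3]</code> of the original array.</p>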
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="c1"># ndarray.py</span>
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">flip</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axes</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">	<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">axes</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">),</span> <span class="s2">&#34;axes must be a tuple&#34;</span>
</span></span><span class="line"><span class="cl">	
</span></span><span class="line"><span class="cl">	<span class="n">strides</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">strides</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">axes</span> <span class="k">else</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">strides</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)))</span>
</span></span><span class="line"><span class="cl">	<span class="kn">import</span> <span class="nn">builtins</span>  <span class="c1"># the name sum is shadowed inside ndarray.py, so use the builtin explicitly</span>
</span></span><span class="line"><span class="cl">	<span class="n">offset</span> <span class="o">=</span> <span class="n">builtins</span><span class="o">.</span><span class="n">sum</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">strides</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="n">out</span> <span class="o">=</span> <span class="n">NDArray</span><span class="o">.</span><span class="n">make</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="n">strides</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">handle</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">offset</span><span class="o">=</span><span class="n">offset</span><span class="p">)</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">out</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># ops_mathematic.py</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Flip</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">axes</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="n">axes</span> <span class="o">=</span> <span class="p">(</span><span class="n">axes</span><span class="p">,)</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">axes</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="n">axes</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">flip</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
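</div><p>That manipulating offset and strides yields a flip can be justified mathematically: index <code>i</code> on a flipped axis must read the element at <code>shape - 1 - i</code>, whose address is <code>(shape - 1) * stride - i * stride</code>, i.e. a constant offset plus a negated stride. A formal proof is omitted here.</p>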
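<p>The stride arithmetic can be tried outside needle on a plain Python list. The sketch below is only an illustration under stated assumptions: <code>strided_view</code> and <code>flip_params</code> are hypothetical helpers (not part of needle), the array is 2-D, and strides are counted in elements rather than bytes.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">def strided_view(buffer, shape, strides, offset):
    """Materialize a 2-D strided view of a flat buffer as nested lists."""
    rows, cols = shape
    return [
        [buffer[offset + r * strides[0] + c * strides[1]] for c in range(cols)]
        for r in range(rows)
    ]


def flip_params(shape, strides, axes):
    """Return (strides, offset) describing the view flipped along `axes`."""
    new_strides = tuple(-s if i in axes else s for i, s in enumerate(strides))
    # Element 0 of a flipped axis is the old last element, which sits
    # (shape[i] - 1) * strides[i] slots past the start of the buffer.
    offset = sum((shape[i] - 1) * strides[i] for i in axes)
    return new_strides, offset


# Hypothetical data: row-major storage of a 2 x 3 matrix with rows 1 2 3 and 4 5 6.
buffer = [1, 2, 3, 4, 5, 6]
shape, strides = (2, 3), (3, 1)

new_strides, offset = flip_params(shape, strides, axes=(1,))
flipped = strided_view(buffer, shape, new_strides, offset)
print(flipped)
</code></pre></div>
<p>Flipping axis 1 gives strides <code>(3, -1)</code> and offset <code>2</code>, so the view starts at the last element of the first row and walks backwards through each row. No data is copied until the view is materialized, which is the role <code>compact()</code> plays in the needle code.</p>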
<ul>
<li>dilate/undilate<br>
I had not come across the dilate operation before, but the following formula illustrates it well:</li>
</ul>


<div>$$

\begin{bmatrix}
1 &amp; 2 \\
3 &amp; 4
\end{bmatrix}
\Longrightarrow
\begin{bmatrix}
1 &amp; 0 &amp; 2 &amp; 0 \\
0 &amp; 0 &amp; 0 &amp; 0 \\
3 &amp; 0 &amp; 4 &amp; 0 \\
0 &amp; 0 &amp; 0 &amp; 0
\end{bmatrix}

$$</div>

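<p>The <code>dilation</code> parameter is the number of zeros inserted after each element along every dilated axis (1 in the example above).</p>
<p>The implementation closely follows that of pad: compute the shape of out, create a zero-filled array of that shape, then use slicing to write the input into the target positions:</p>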
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span><span class="lnt">50
</span><span class="lnt">51
</span><span class="lnt">52
</span><span class="lnt">53
</span><span class="lnt">54
</span><span class="lnt">55
</span><span class="lnt">56
</span><span class="lnt">57
</span><span class="lnt">58
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Dilate</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axes</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">dilation</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span> <span class="o">=</span> <span class="n">dilation</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="n">a</span>
</span></span><span class="line"><span class="cl">        <span class="n">out_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">out_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span> <span class="o">+</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">        <span class="n">out</span> <span class="o">=</span> <span class="n">array_api</span><span class="o">.</span><span class="n">full</span><span class="p">(</span><span class="n">out_shape</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">slices</span> <span class="o">=</span> <span class="p">[</span><span class="nb">slice</span><span class="p">(</span><span class="kc">None</span><span class="p">)]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">dim</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">slices</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">out</span><span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="n">slices</span><span class="p">)]</span> <span class="o">=</span> <span class="n">a</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">undilate</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">dilate</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">axes</span><span class="p">,</span> <span class="n">dilation</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">Dilate</span><span class="p">(</span><span class="n">axes</span><span class="p">,</span> <span class="n">dilation</span><span class="p">)(</span><span class="n">a</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">UnDilate</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axes</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">dilation</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span> <span class="o">=</span> <span class="n">dilation</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="n">a</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># For any input produced by dilate, the strided slice below already</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># yields the undilated shape (a.shape[i] // (dilation + 1) on each</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># dilated axis), so no output buffer has to be allocated up front.</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># Keeping every (dilation + 1)-th element inverts Dilate.compute.</span>
</span></span><span class="line"><span class="cl">        <span class="n">slices</span> <span class="o">=</span> <span class="p">[</span><span class="nb">slice</span><span class="p">(</span><span class="kc">None</span><span class="p">)]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">dim</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">slices</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">out</span> <span class="o">=</span> <span class="n">a</span><span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="n">slices</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">dilate</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">undilate</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">axes</span><span class="p">,</span> <span class="n">dilation</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">UnDilate</span><span class="p">(</span><span class="n">axes</span><span class="p">,</span> <span class="n">dilation</span><span class="p">)(</span><span class="n">a</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>dilate and undilate are inverse operations, so the gradient of each one simply calls the other.</p>
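<p>The matrix example above can be reproduced with nested lists. This is a minimal pure-Python sketch, assuming a 2-D input dilated along both axes; <code>dilate_2d</code> and <code>undilate_2d</code> are hypothetical names used only for illustration and are not the needle ops.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">def dilate_2d(matrix, dilation):
    """Insert `dilation` zeros after every element along both axes."""
    step = dilation + 1
    rows, cols = len(matrix), len(matrix[0])
    out = [[0] * (cols * step) for _ in range(rows * step)]
    for r in range(rows):
        for c in range(cols):
            # Same effect as the strided assignment out[::step, ::step] = matrix.
            out[r * step][c * step] = matrix[r][c]
    return out


def undilate_2d(matrix, dilation):
    """Keep every (dilation + 1)-th row and column, undoing dilate_2d."""
    step = dilation + 1
    return [row[::step] for row in matrix[::step]]


original = [[1, 2], [3, 4]]
dilated = dilate_2d(original, dilation=1)
for row in dilated:
    print(row)
print(undilate_2d(dilated, dilation=1) == original)
</code></pre></div>
<p>With <code>dilation=1</code> the 2 x 2 input grows to 4 x 4, matching the formula, and undilating the result recovers the original matrix.</p>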
<ul>
<li>conv<br>
Padding comes first. Note that padding can be factored out of the convolution, i.e. the following two lines are equivalent:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">conv</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="n">n</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">conv</span><span class="p">(</span><span class="n">pad</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">n</span><span class="p">),</span> <span class="n">W</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>So the first step is to pad X and use the result as the new X. After that, im2col and stride manipulation turn X and W into matrices, so that the convolution becomes a matrix multiplication. The principle is covered in the course notes: <a href="https://www.zhouxin.space/notes/notes-on-cmu-10-414-deep-learning-system/#%e9%80%9a%e8%bf%87-im2col-%e6%9d%a5%e5%ae%9e%e7%8e%b0%e5%8d%b7%e7%a7%af-convolutions-via-im2col">Notes on CMU 10-414 Deep Learning System | 周鑫的个人博客</a>.</p>
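<p>The padding equivalence can be checked numerically. The sketch below is an assumption-laden toy, not the needle implementation: it handles a single channel, stride 1 and a square kernel, uses a naive loop instead of im2col, and <code>pad2d</code> and <code>conv2d</code> are hypothetical helpers.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">def pad2d(x, n):
    """Surround a 2-D list with n rows and n columns of zeros on every side."""
    width = len(x[0]) + 2 * n
    top = [[0] * width for _ in range(n)]
    bottom = [[0] * width for _ in range(n)]
    body = [[0] * n + list(row) + [0] * n for row in x]
    return top + body + bottom


def conv2d(x, w, padding=0):
    """Naive stride-1 convolution; inputs outside x are treated as zero."""
    height, width, k = len(x), len(x[0]), len(w)
    out_h = height + 2 * padding - k + 1
    out_w = width + 2 * padding - k + 1
    out = [[0] * out_w for _ in range(out_h)]
    for i in range(out_h):
        for j in range(out_w):
            for di in range(k):
                for dj in range(k):
                    r, c = i + di - padding, j + dj - padding
                    # Implicit zero padding: skip positions outside the input.
                    if r in range(height) and c in range(width):
                        out[i][j] += x[r][c] * w[di][dj]
    return out


X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
W = [[1, 0], [0, -1]]

direct = conv2d(X, W, padding=1)
padded_first = conv2d(pad2d(X, 1), W, padding=0)
print(direct == padded_first)
</code></pre></div>
<p>Both calls produce the same 4 x 4 output, which is why the implementation can pad X once up front and then treat the rest of the convolution as if <code>padding=0</code>.</p>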
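<p>For the derivation of the backward pass, see the post <a href="https://www.zhouxin.space/notes/2d-convolution-gradient-derivation-and-implementation/">2D Convolution Gradient Derivation and Implementation | 周鑫的个人博客</a>.</p>
<p>The Conv implementation relies on many permute operations. Expressing them with transpose would be cumbersome, so it is simpler to implement a dedicated permute TensorOp:</p>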
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Permute</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axes</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">a</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># Build the inverse permutation: input axis axes[i] became output axis i.</span>
</span></span><span class="line"><span class="cl">        <span class="n">index</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">)):</span>
</span></span><span class="line"><span class="cl">            <span class="n">index</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="n">i</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">permute</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">index</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">permute</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">axes</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">Permute</span><span class="p">(</span><span class="n">axes</span><span class="p">)(</span><span class="n">a</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The final implementation is:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span><span class="lnt">50
</span><span class="lnt">51
</span><span class="lnt">52
</span><span class="lnt">53
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Conv</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">stride</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">stride</span> <span class="o">=</span> <span class="n">stride</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">padding</span> <span class="o">=</span> <span class="n">padding</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">A</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">4</span><span class="p">,</span> <span class="s2">&#34;The input tensor should be 4D&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">B</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">4</span><span class="p">,</span> <span class="s2">&#34;The kernel tensor should be 4D&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="n">A</span> <span class="o">=</span> <span class="n">A</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">B</span> <span class="o">=</span> <span class="n">B</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">batch_size</span><span class="p">,</span> <span class="n">in_height</span><span class="p">,</span> <span class="n">in_width</span><span class="p">,</span> <span class="n">in_channel</span> <span class="o">=</span> <span class="n">A</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">bs</span><span class="p">,</span> <span class="n">hs</span><span class="p">,</span> <span class="n">ws</span><span class="p">,</span> <span class="n">cs</span> <span class="o">=</span> <span class="n">A</span><span class="o">.</span><span class="n">strides</span>
</span></span><span class="line"><span class="cl">        <span class="n">kernel_height</span><span class="p">,</span> <span class="n">kernel_width</span><span class="p">,</span> <span class="n">in_channel</span><span class="p">,</span> <span class="n">out_channel</span> <span class="o">=</span> <span class="n">B</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="n">pad_A</span> <span class="o">=</span> <span class="n">A</span><span class="o">.</span><span class="n">pad</span><span class="p">(((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">),</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)))</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">batch_size</span><span class="p">,</span> <span class="n">in_height</span><span class="p">,</span> <span class="n">in_width</span><span class="p">,</span> <span class="n">in_channel</span> <span class="o">=</span> <span class="n">pad_A</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">bs</span><span class="p">,</span> <span class="n">hs</span><span class="p">,</span> <span class="n">ws</span><span class="p">,</span> <span class="n">cs</span> <span class="o">=</span> <span class="n">pad_A</span><span class="o">.</span><span class="n">strides</span>
</span></span><span class="line"><span class="cl">        <span class="n">receiptive_field_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">(</span><span class="n">in_height</span> <span class="o">-</span> <span class="n">kernel_height</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">in_width</span> <span class="o">-</span> <span class="n">kernel_width</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">kernel_height</span><span class="p">,</span> <span class="n">kernel_width</span><span class="p">,</span> <span class="n">in_channel</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">receiptive_field_strides</span> <span class="o">=</span> <span class="p">(</span><span class="n">bs</span><span class="p">,</span> <span class="n">hs</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span><span class="p">,</span> <span class="n">ws</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span><span class="p">,</span> <span class="n">hs</span><span class="p">,</span> <span class="n">ws</span><span class="p">,</span> <span class="n">cs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">receiptive_field</span> <span class="o">=</span> <span class="n">pad_A</span><span class="o">.</span><span class="n">as_strided</span><span class="p">(</span><span class="n">receiptive_field_shape</span><span class="p">,</span> <span class="n">receiptive_field_strides</span><span class="p">)</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">reveiptive_vector</span> <span class="o">=</span> <span class="n">receiptive_field</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">receiptive_field</span><span class="o">.</span><span class="n">size</span> <span class="o">//</span><span class="p">(</span><span class="n">kernel_height</span> <span class="o">*</span> <span class="n">kernel_width</span> <span class="o">*</span> <span class="n">in_channel</span><span class="p">),</span> <span class="n">kernel_height</span> <span class="o">*</span> <span class="n">kernel_width</span> <span class="o">*</span> <span class="n">in_channel</span><span class="p">))</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">kernel_vector</span> <span class="o">=</span> <span class="n">B</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">kernel_height</span> <span class="o">*</span> <span class="n">kernel_width</span> <span class="o">*</span> <span class="n">in_channel</span><span class="p">,</span> <span class="n">out_channel</span><span class="p">))</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">out</span> <span class="o">=</span> <span class="n">reveiptive_vector</span> <span class="o">@</span> <span class="n">kernel_vector</span>
</span></span><span class="line"><span class="cl">        <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">(</span><span class="n">in_height</span> <span class="o">-</span> <span class="n">kernel_height</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">in_width</span> <span class="o">-</span> <span class="n">kernel_width</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">out_channel</span><span class="p">))</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">X</span><span class="p">,</span> <span class="n">W</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">inputs</span>
</span></span><span class="line"><span class="cl">        <span class="n">s</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">W</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1"># Compute X_grad</span>
</span></span><span class="line"><span class="cl">        <span class="n">W_flipped</span> <span class="o">=</span> <span class="n">flip</span><span class="p">(</span><span class="n">W</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">W_flipped_permuted</span> <span class="o">=</span> <span class="n">transpose</span><span class="p">(</span><span class="n">W_flipped</span><span class="p">,</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="c1"># transpose only supports swapping two axes</span>
</span></span><span class="line"><span class="cl">        <span class="n">outgrad_dilated</span> <span class="o">=</span> <span class="n">dilate</span><span class="p">(</span><span class="n">out_grad</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">X_grad</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">outgrad_dilated</span><span class="p">,</span> <span class="n">W_flipped_permuted</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="n">s</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1"># Compute W_grad</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># outgrad_dilated = dilate(out_grad, (1, 2), self.stride - 1)</span>
</span></span><span class="line"><span class="cl">        <span class="n">outgrad_dilated_permuted</span> <span class="o">=</span> <span class="n">permute</span><span class="p">(</span><span class="n">outgrad_dilated</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">X_permuted</span> <span class="o">=</span> <span class="n">permute</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">W_grad</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">X_permuted</span><span class="p">,</span> <span class="n">outgrad_dilated_permuted</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">W_grad</span> <span class="o">=</span> <span class="n">permute</span><span class="p">(</span><span class="n">W_grad</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">X_grad</span><span class="p">,</span> <span class="n">W_grad</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">conv</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">Conv</span><span class="p">(</span><span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="p">)(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
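<li>nn.Conv<br>
Here we implement a convolutional layer with the following requirements: inputs and outputs use the (N, C, H, W) layout; the padding must keep the output the same size as the input when stride=1; and a bias term is supported.</li>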
</ul>
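<p>Because the <code>Conv</code> op above works on NHWC tensors while the layer has to accept and return NCHW, a layout round-trip is needed. The following is only a minimal sketch in plain NumPy, which stands in for the course's NDArray backend; the helper names <code>nchw_to_nhwc</code> and <code>nhwc_to_nchw</code> are hypothetical and do not come from the assignment code.</p>

```python
import numpy as np


def nchw_to_nhwc(x):
    # (N, C, H, W) to (N, H, W, C): move the channel axis to the end.
    return np.transpose(x, (0, 2, 3, 1))


def nhwc_to_nchw(x):
    # (N, H, W, C) to (N, C, H, W): move the channel axis back to position 1.
    return np.transpose(x, (0, 3, 1, 2))


# Small example tensor with N=2, C=3, H=4, W=5.
x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
y = nchw_to_nhwc(x)
roundtrip = nhwc_to_nchw(y)

print(y.shape)          # (2, 4, 5, 3)
print(roundtrip.shape)  # (2, 3, 4, 5)
```

<p>The two axis orders are inverses of each other, so applying one after the other returns the original tensor.</p>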
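<p>First, modify the Kaiming uniform implementation so that it can initialize convolution kernels. Add a branch on whether the <code>shape</code> argument is None, and pass the corresponding shape to the rand function:</p>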
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">kaiming_uniform</span><span class="p">(</span><span class="n">fan_in</span><span class="p">,</span> <span class="n">fan_out</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">nonlinearity</span><span class="o">=</span><span class="s2">&#34;relu&#34;</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">assert</span> <span class="n">nonlinearity</span> <span class="o">==</span> <span class="s2">&#34;relu&#34;</span><span class="p">,</span> <span class="s2">&#34;Only relu supported currently&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="n">nonlinearity</span> <span class="o">==</span> <span class="s2">&#34;relu&#34;</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">gain</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="n">bound</span> <span class="o">=</span> <span class="n">gain</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">3</span> <span class="o">/</span> <span class="n">fan_in</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="n">shape</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">rand</span><span class="p">(</span><span class="n">fan_in</span><span class="p">,</span> <span class="n">fan_out</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">rand</span><span class="p">(</span><span class="o">*</span><span class="n">shape</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
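</div><p>In the hw4 code, the implementation of <code>NDArray.sum</code> has a problem: when the reduction axes are given as an empty tuple, no summation should take place, but the original code does not handle this case correctly. The branch where axis is a list or tuple needs an extra check: if the list or tuple is empty, the output equals the input:</p>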
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">sum</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">		<span class="n">view</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduce_view_out</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="n">keepdims</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">		<span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">view</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">out</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">view</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">	<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">)):</span>
</span></span><span class="line"><span class="cl">		<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">axis</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">			<span class="n">out</span> <span class="o">=</span> <span class="bp">self</span>
</span></span><span class="line"><span class="cl">		<span class="k">for</span> <span class="n">axis_</span> <span class="ow">in</span> <span class="n">axis</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">			<span class="n">view</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduce_view_out</span><span class="p">(</span><span class="n">axis_</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="n">keepdims</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">			<span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">view</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">out</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">view</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">	<span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">		<span class="n">view</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduce_view_out</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="n">keepdims</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">		<span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">view</span><span class="o">.</span><span class="n">compact</span><span class="p">()</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">out</span><span class="o">.</span><span class="n">_handle</span><span class="p">,</span> <span class="n">view</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">	
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">out</span>
</span></span></code></pre></td></tr></table>
</div>
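</div><p>With everything in place, the convolutional layer only needs to call the functions above. In the initialization, set up the weight and the bias term as described in the documentation. For a stride-1 convolution, the output shrinks by k-1 rows and k-1 columns; to keep the shape unchanged, the input must be padded by (k-1)/2 on every side. Since k is conventionally odd, this is equivalent to padding by k/2 with integer division.</p>
<p>In the forward pass, first permute X to the NHWC layout, then apply the convolution. If there is a bias term, broadcast it and add it to the result. Finally, permute the result back to the NCHW layout and return it. The complete code is:</p>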
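<p>The padding arithmetic can be checked with a few lines of Python. This is only an illustrative sketch: <code>conv_out_size</code> is a hypothetical helper, not part of the assignment, that applies the same output-size formula as the shape computation in <code>Conv.compute</code>.</p>

```python
def conv_out_size(n, k, stride=1, padding=0):
    # Same formula as in Conv.compute: (padded size - kernel size) // stride + 1.
    return (n + 2 * padding - k) // stride + 1


n = 32
for k in (1, 3, 5, 7):
    # Without padding, a stride-1 convolution shrinks each side by k - 1.
    assert conv_out_size(n, k) == n - (k - 1)
    # For odd k, (k - 1) // 2 equals k // 2, and this padding keeps the size.
    assert (k - 1) // 2 == k // 2
    assert conv_out_size(n, k, padding=k // 2) == n

# For an even kernel, padding k // 2 overshoots by one.
even_case = conv_out_size(n, 4, padding=2)
print(even_case)  # 33
```

<p>The even-kernel case shows why the k/2 shortcut relies on the kernel size being odd.</p>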
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Conv</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">    Multi-channel 2D convolutional layer
</span></span></span><span class="line"><span class="cl"><span class="s2">    IMPORTANT: Accepts inputs in NCHW format, outputs also in NCHW format
</span></span></span><span class="line"><span class="cl"><span class="s2">    Only supports padding=same
</span></span></span><span class="line"><span class="cl"><span class="s2">    No grouped convolution or dilation
</span></span></span><span class="line"><span class="cl"><span class="s2">    Only supports square kernels
</span></span></span><span class="line"><span class="cl"><span class="s2">    &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">kernel_size</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="n">kernel_size</span> <span class="o">=</span> <span class="n">kernel_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">stride</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="n">stride</span> <span class="o">=</span> <span class="n">stride</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">in_channels</span> <span class="o">=</span> <span class="n">in_channels</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">out_channels</span> <span class="o">=</span> <span class="n">out_channels</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span> <span class="o">=</span> <span class="n">kernel_size</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">stride</span> <span class="o">=</span> <span class="n">stride</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">kaiming_uniform</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">in_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">kernel_size</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">bias_bound</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">in_channels</span> <span class="o">*</span> <span class="n">kernel_size</span> <span class="o">*</span> <span class="n">kernel_size</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bias_bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bias_bound</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">padding</span> <span class="o">=</span> <span class="n">kernel_size</span> <span class="o">//</span> <span class="mi">2</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="c1"># convert NCHW to NHWC</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">        <span class="n">conv_x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">broadcasted_bias</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_channels</span><span class="p">)),</span> <span class="n">conv_x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">conv_x</span> <span class="o">=</span> <span class="n">conv_x</span> <span class="o">+</span> <span class="n">broadcasted_bias</span>
</span></span><span class="line"><span class="cl">        <span class="n">out</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="n">conv_x</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
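<li>ResNet 9<br>
When implementing a TensorOp subclass, always specify the device whenever you create a new Tensor. My earlier ReLU implementation built its mask without specifying a device, which causes backpropagation to fail, so it is fixed here:</li>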
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">ReLU</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">node</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">relu_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">cached_data</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">node</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out_grad</span> <span class="o">*</span> <span class="n">relu_mask</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Similarly, my earlier SoftmaxLoss implementation did not specify a device when generating the one-hot labels, so it needs the same fix:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">SoftmaxLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">batch_size</span><span class="p">,</span> <span class="n">label_size</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">one_hot_y</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">label_size</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">logits</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">true_logits</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">summation</span><span class="p">(</span><span class="n">logits</span> <span class="o">*</span> <span class="n">one_hot_y</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,))</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">logsumexp</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">))</span> <span class="o">-</span> <span class="n">true_logits</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">/</span><span class="n">batch_size</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
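</div><p>In addition, I found that the reshape operation may be applied to an array that has not been compacted. I changed its implementation directly so that it calls compact before calling array_api:</p>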
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Reshape</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">expect_size</span> <span class="o">=</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">expect_size</span> <span class="o">*=</span> <span class="n">i</span>
</span></span><span class="line"><span class="cl">        <span class="n">real_size</span> <span class="o">=</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">real_size</span> <span class="o">*=</span> <span class="n">i</span>
</span></span><span class="line"><span class="cl">        <span class="k">assert</span> <span class="n">expect_size</span> <span class="o">==</span> <span class="n">real_size</span> <span class="p">,</span> <span class="s2">&#34;The reshape size is not compatible&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">array_api</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">compact</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
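</div><p>After these small fixes, our code is robust enough to build ResNet 9 🎉. The ResNet 9 architecture is shown below. There is no need to be hard on ourselves about the bugs that crept in while writing the code; even two experts as accomplished as these make the occasional slip. One layer of the ResNet 9 in the figure below is specified incorrectly, and I have marked it on the original figure.<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409132043485.png?x-oss-process=image/quality,q_90/format,webp"><br>
First, implement ConvBN, whose four arguments are, in order, channels_in, channels_out, kernel_size, and stride. The hw4 starter code already provides BatchNorm2d, so do not simply overwrite <code>nn_basic.py</code> when copying your old file over. The rest is straightforward: assemble the blocks according to the diagram, run the code, and fill in whatever raises a Not Implemented Error. The complete code is:</p>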
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">ResNet9</span><span class="p">(</span><span class="n">ndl</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">bias</span> <span class="o">=</span> <span class="kc">True</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION ###</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">ConvBN</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">7</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">ConvBN</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">res</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Residual</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="n">ndl</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">                <span class="n">ConvBN</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">                <span class="n">ConvBN</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">conv3</span> <span class="o">=</span> <span class="n">ConvBN</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">conv4</span> <span class="o">=</span> <span class="n">ConvBN</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">res2</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Residual</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="n">ndl</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">                <span class="n">ConvBN</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">                <span class="n">ConvBN</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">flatten</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Flatten</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">relu</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">linear2</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv3</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv4</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">x</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
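</div><p>Unfortunately, on my machine the code above does not pass the ResNet 9 test: the error is 0.09, well above the tolerance of 0.01. Yet it does pass the later test that trains for 2 epochs on the CIFAR-10 training set, with an error of 5e-5, far below the tolerance of 0.01. I suspect the reference data for the former test is faulty.</p>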
<h2 id="part-4-recurrent-neural-network">Part 4: Recurrent neural network</h2>
<ul>
<li>RNN Cell<br>
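The RNN cell does not seem to have any pitfalls: initialize the parameters as documented and implement the forward pass by following the formula.</li>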
</ul>
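<p>For reference, the formula in question can be written as follows. This is a sketch in row-vector form matching the shapes used in the code (<code>X</code> is (bs, input_size), <code>h</code> is (bs, hidden_size)); the symbols $b_{ih}$ and $b_{hh}$ stand for <code>bias_ih</code> and <code>bias_hh</code>, $h'$ is my label for the new hidden state, and $\tanh$ is swapped for ReLU when <code>nonlinearity</code> is <code>'relu'</code>:</p>
<p>$$h' = \tanh\left(X W_{ih} + b_{ih} + h W_{hh} + b_{hh}\right)$$</p>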
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">RNNCell</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">nonlinearity</span><span class="o">=</span><span class="s1">&#39;tanh&#39;</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">bound</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">W_ih</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">W_hh</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">bias_ih</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">bias_hh</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">nonlinearity</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">tanh</span> <span class="k">if</span> <span class="n">nonlinearity</span> <span class="o">==</span> <span class="s1">&#39;tanh&#39;</span> <span class="k">else</span> <span class="n">ops</span><span class="o">.</span><span class="n">relu</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">h</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">h</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">h</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">W_hh</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">X</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">X</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">Z</span> <span class="o">=</span> <span class="n">X</span><span class="nd">@self.W_ih</span> <span class="o">+</span> <span class="n">h</span><span class="nd">@self.W_hh</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias_ih</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias_ih</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias_hh</span>
</span></span><span class="line"><span class="cl">            <span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
</span></span><span class="line"><span class="cl">            <span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">Z</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">Z</span> <span class="o">+=</span> <span class="n">bias</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">nonlinearity</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>RNN<br>
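This task is to build a multi-layer RNN, i.e. RNN layers stacked on top of each other, as shown in the figure below. The <code>num_layers</code> parameter sets the number of layers, and <code>input_size</code> is the size of the input x fed to the bottom layer. Every cell above the bottom layer takes the previous layer's output as its input, so for those cells input_size = hidden_size.<br>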
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409041908220.png?x-oss-process=image/quality,q_90/format,webp"></li>
</ul>
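<p>As the figure shows, the input to each layer keeps changing, so I maintain a list <code>X_input</code> that holds the vertical input of every cell not yet computed. Likewise, a list <code>h_input</code> holds the horizontal input of every cell not yet computed. Concretely, when the cell being computed is $h_i^j$ (time step $i$, layer $j$), its inputs are <code>X_input[i]</code> and <code>h_input[j]</code>, and once it has been computed both <code>X_input[i]</code> and <code>h_input[j]</code> must be updated to that cell's output.</p>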
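<p>A stacked RNN like this can be evaluated in either of two orders: finish each layer across all time steps before moving up to the next layer, or finish each time step across all layers before moving right to the next step. I use the second order: vertical first, then horizontal.</p>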
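<p>The bookkeeping above can be sketched in plain NumPy to check that both traversal orders give the same result. This is only an illustration under stated assumptions: NumPy arrays stand in for needle Tensors (so ordinary indexing is used), the cells are bias-free tanh cells, and the helper names <code>cell</code>, <code>run_time_major</code> and <code>run_layer_major</code> are hypothetical.</p>

```python
import numpy as np

def cell(x, h, W_ih, W_hh):
    # One bias-free tanh RNN cell: (bs, in) @ (in, hid) + (bs, hid) @ (hid, hid).
    return np.tanh(x @ W_ih + h @ W_hh)

def run_time_major(X, h0, weights):
    # Outer loop over time steps, inner loop over layers (vertical first).
    X_input = [X[t] for t in range(X.shape[0])]
    h_input = [h0[layer] for layer in range(h0.shape[0])]
    for i in range(len(X_input)):
        for j in range(len(h_input)):
            X_input[i] = cell(X_input[i], h_input[j], *weights[j])
            h_input[j] = X_input[i]
    return np.stack(X_input), np.stack(h_input)

def run_layer_major(X, h0, weights):
    # Outer loop over layers, inner loop over time steps (horizontal first).
    X_input = [X[t] for t in range(X.shape[0])]
    h_input = [h0[layer] for layer in range(h0.shape[0])]
    for j in range(len(h_input)):
        for i in range(len(X_input)):
            X_input[i] = cell(X_input[i], h_input[j], *weights[j])
            h_input[j] = X_input[i]
    return np.stack(X_input), np.stack(h_input)

rng = np.random.default_rng(0)
seq_len, bs, input_size, hidden_size, num_layers = 5, 3, 4, 6, 2
X = rng.standard_normal((seq_len, bs, input_size))
h0 = np.zeros((num_layers, bs, hidden_size))
# The bottom layer maps input_size to hidden_size; every other layer maps hidden_size to hidden_size.
weights = [
    (rng.standard_normal((input_size if layer == 0 else hidden_size, hidden_size)),
     rng.standard_normal((hidden_size, hidden_size)))
    for layer in range(num_layers)
]
out_a, h_a = run_time_major(X, h0, weights)
out_b, h_b = run_layer_major(X, h0, weights)
```

<p>In this sketch <code>out_a</code> and <code>out_b</code> agree, as do <code>h_a</code> and <code>h_b</code>, and the last time step of the stacked outputs coincides with the final hidden state of the top layer.</p>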
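<p>The model returns two values. The first is the output of the last layer, i.e. the set of y values in the diagram. The last layer's output is exactly the vertical input that a hypothetical next layer would receive, which is the <code>X_input</code> list we have been maintaining all along. The second is the last column of hidden states, which in the same way is the horizontal input <code>h_input</code> we have been maintaining. Everything falls into place.</p>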
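<p>Note that Tensor does not implement <code>__getitem__</code> or <code>__setitem__</code>, so whenever a slice has to be read or written, use the split and stack operations implemented earlier.</p>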
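<p>The split-then-stack pattern can be illustrated with NumPy stand-ins. The two helpers below are hypothetical and only mimic the behaviour assumed here for needle's <code>ops.split</code> and <code>ops.stack</code> (split removes the chosen axis and returns one piece per index; stack joins pieces along a new axis); they are not the needle implementation.</p>

```python
import numpy as np

def split(a, axis):
    # Stand-in for a needle-style split: one array per index along `axis`, with that axis removed.
    pieces = np.split(a, a.shape[axis], axis=axis)
    return tuple(np.squeeze(piece, axis=axis) for piece in pieces)

def stack(arrays, axis):
    # Stand-in for a needle-style stack: join equal-shaped arrays along a new axis.
    return np.stack(arrays, axis=axis)

X = np.arange(24, dtype="float32").reshape(4, 2, 3)  # (seq_len, bs, input_size)

steps = list(split(X, 0))            # read access: steps[t] plays the role of X[t]
steps[1] = np.zeros_like(steps[1])   # write access: replace one slice inside the list
X_new = stack(steps, 0)              # rebuild a single array with the original shape
```

<p>Keeping the pieces in a Python list between the split and the final stack is what lets the loop overwrite individual entries without any indexing support on the tensor itself.</p>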
<p>The complete code:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">RNN</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">nonlinearity</span><span class="o">=</span><span class="s1">&#39;tanh&#39;</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">rnn_cells</span> <span class="o">=</span> <span class="p">[]</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">rnn_cells</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">RNNCell</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">nonlinearity</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">rnn_cells</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">RNNCell</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">nonlinearity</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">h0</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">seq_len</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">layer_num</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">rnn_cells</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">h0</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">h0</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">rnn_cells</span><span class="p">),</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnn_cells</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">W_hh</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">X</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">X</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">h_input</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">h0</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span> <span class="c1"># list length = num_layers, element shape = (bs, hidden_size)</span>
</span></span><span class="line"><span class="cl">        <span class="n">X_input</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span> <span class="c1"># list length = seq_len, element shape = (bs, input_size)</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">layer_num</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                <span class="n">X_input</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnn_cells</span><span class="p">[</span><span class="n">j</span><span class="p">](</span><span class="n">X_input</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">h_input</span><span class="p">[</span><span class="n">j</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">                <span class="n">h_input</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">X_input</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">output</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">X_input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="c1"># output features of last layer == input X of last+1 layer</span>
</span></span><span class="line"><span class="cl">        <span class="n">h_n</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">h_input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">h_n</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">            
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="part-5-lstm">Part 5: LSTM</h2>
<p>This section implements the LSTM. Its logic is the same as the RNN above and amounts to transcribing the formulas, so I simply give the code.</p>
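<p>For reference, the formulas being transcribed can be sketched as a single LSTM cell step in NumPy. This is a minimal illustration, not the needle code: it assumes the four gates are laid out along the last axis in the order i, f, g, o (the PyTorch convention), uses NumPy slicing in place of split, and the names <code>sigmoid</code> and <code>lstm_cell</code> are hypothetical.</p>

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_cell(X, h0, c0, W_ih, W_hh, bias_ih=None, bias_hh=None):
    hidden_size = W_hh.shape[0]
    Z = X @ W_ih + h0 @ W_hh                 # (bs, 4*hidden_size)
    if bias_ih is not None:
        Z = Z + bias_ih + bias_hh            # biases broadcast over the batch axis
    # Assumed gate layout along the last axis: i, f, g, o.
    i, f, g, o = (Z[:, k * hidden_size:(k + 1) * hidden_size] for k in range(4))
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    g = np.tanh(g)
    c1 = f * c0 + i * g                      # new cell state
    h1 = o * np.tanh(c1)                     # new hidden state
    return h1, c1

bs, input_size, hidden_size = 2, 3, 4
X = np.ones((bs, input_size))
h0 = np.zeros((bs, hidden_size))
c0 = np.ones((bs, hidden_size))
# All-zero weights make every sigmoid gate 0.5 and the candidate g equal to 0.
W_ih = np.zeros((input_size, 4 * hidden_size))
W_hh = np.zeros((hidden_size, 4 * hidden_size))
h1, c1 = lstm_cell(X, h0, c0, W_ih, W_hh)
```

<p>With all-zero weights the cell state is simply halved by the forget gate, which makes a convenient sanity check before comparing against a reference implementation.</p>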
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">LSTMCell</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">bound</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">W_ih</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="mi">4</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">W_hh</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="mi">4</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">bias_ih</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">4</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">bias_hh</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">4</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">sigmoid</span> <span class="o">=</span> <span class="n">Sigmoid</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">h</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">bs</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">W_hh</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">h</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">h0</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">bs</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">X</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">X</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">c0</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">bs</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">X</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">X</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">h0</span><span class="p">,</span> <span class="n">c0</span> <span class="o">=</span> <span class="n">h</span>
</span></span><span class="line"><span class="cl">        <span class="n">Z</span> <span class="o">=</span> <span class="n">X</span><span class="nd">@self.W_ih</span> <span class="o">+</span> <span class="n">h0</span><span class="nd">@self.W_hh</span> <span class="c1"># [bs, 4*hidden_size]</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias_ih</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias_ih</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias_hh</span>
</span></span><span class="line"><span class="cl">            <span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
</span></span><span class="line"><span class="cl">            <span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">Z</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">Z</span> <span class="o">+=</span> <span class="n">bias</span>
</span></span><span class="line"><span class="cl">        <span class="n">stripes</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">i</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">stripes</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span> <span class="n">hidden_size</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">f</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">stripes</span><span class="p">[</span><span class="n">hidden_size</span><span class="p">:</span> <span class="mi">2</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">g</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">stripes</span><span class="p">[</span><span class="mi">2</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">:</span> <span class="mi">3</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">o</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">stripes</span><span class="p">[</span><span class="mi">3</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">:</span> <span class="mi">4</span><span class="o">*</span><span class="n">hidden_size</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">c</span> <span class="o">=</span> <span class="n">f</span> <span class="o">*</span> <span class="n">c0</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="n">g</span>
</span></span><span class="line"><span class="cl">        <span class="n">h</span> <span class="o">=</span> <span class="n">o</span> <span class="o">*</span> <span class="n">ops</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">c</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">h</span><span class="p">,</span> <span class="n">c</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">LSTM</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">lstm_cells</span> <span class="o">=</span> <span class="p">[]</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">lstm_cells</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">lstm_cells</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">h</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">seq_len</span><span class="p">,</span> <span class="n">bs</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">num_layers</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lstm_cells</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lstm_cells</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">W_hh</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">h</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">h0</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">num_layers</span><span class="p">,</span> <span class="n">bs</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">X</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">X</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">c0</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">num_layers</span><span class="p">,</span> <span class="n">bs</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">X</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">X</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">h0</span><span class="p">,</span> <span class="n">c0</span> <span class="o">=</span> <span class="n">h</span>
</span></span><span class="line"><span class="cl">        <span class="n">h_input</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">h0</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">c_input</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">c0</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">X_input</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                <span class="n">X_input</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">c_input</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lstm_cells</span><span class="p">[</span><span class="n">j</span><span class="p">](</span><span class="n">X_input</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">(</span><span class="n">h_input</span><span class="p">[</span><span class="n">j</span><span class="p">],</span> <span class="n">c_input</span><span class="p">[</span><span class="n">j</span><span class="p">]))</span>
</span></span><span class="line"><span class="cl">                <span class="n">h_input</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">X_input</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">output</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">X_input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">h_n</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">h_input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">c_n</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">c_input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="p">(</span><span class="n">h_n</span><span class="p">,</span> <span class="n">c_n</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="part-6-penn-treebank-dataset">Part 6: Penn Treebank dataset</h2>
<ul>
<li>Dictionary<br>
This class builds a two-way mapping between words and ids: <code>word2idx</code> is a <code>dict</code> lookup from word to id, and <code>idx2word</code> is a <code>list</code> indexed by id:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Dictionary</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">word2idx</span> <span class="o">=</span> <span class="p">{}</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">idx2word</span> <span class="o">=</span> <span class="p">[]</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">add_word</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">word</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">word2idx</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">word</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">word2idx</span><span class="p">[</span><span class="n">word</span><span class="p">]</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">idx2word</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">idx2word</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">word</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">word2idx</span><span class="p">[</span><span class="n">word</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">idx2word</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>Corpus<br>
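This class plays a role similar to a DataLoader: it reads the raw text from a file and tokenizes it with <code>Dictionary</code>. The companion function <code>batchify</code> splits the token stream into batches (here a batch means that the input x holds several sequences side by side), and <code>get_batch</code> cuts those sequences into chunks of at most <code>bptt</code> time steps (the LSTM is only unrolled over a limited number of steps, so it can take at most that many inputs at once).</li>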
</ul>
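<p>As a rough illustration of the two levels of batching described above, here is a minimal NumPy sketch on 26 made-up token ids. The helper names <code>toy_batchify</code> and <code>toy_get_batch</code> are hypothetical stand-ins, not part of the assignment code, and the window clamp in <code>toy_get_batch</code> is an assumption borrowed from the usual language-model data loading pattern:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

def toy_batchify(ids, batch_size):
    # Drop the tail that does not fill a column, then lay the stream out
    # column by column: each column is one contiguous slice of the corpus.
    nbatch = len(ids) // batch_size
    trimmed = np.array(ids[:nbatch * batch_size])
    return trimmed.reshape(batch_size, -1).T

def toy_get_batch(batches, i, bptt):
    # Clamp the window so data and target have the same number of rows.
    seq_len = min(bptt, len(batches) - 1 - i)
    data = batches[i:i + seq_len, :]
    target = batches[i + 1:i + 1 + seq_len, :]
    return data, target.flatten()

ids = list(range(26))                 # stand-in for 26 token ids
batches = toy_batchify(ids, batch_size=4)
data, target = toy_get_batch(batches, i=0, bptt=2)
print(batches.shape)                  # (6, 4)
print(data.tolist())                  # [[0, 6, 12, 18], [1, 7, 13, 19]]
print(target.tolist())                # [1, 7, 13, 19, 2, 8, 14, 20]
</code></pre></div>
<p>Each column of <code>batches</code> is one contiguous slice of the corpus, and the target is the data shifted down by one row, flattened, presumably so that it can be compared against per-token predictions.</p>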
<p>For the implementation, just follow the docstrings; the diagrams in them make the layout clear at a glance. The complete code:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Corpus</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">base_dir</span><span class="p">,</span> <span class="n">max_lines</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">dictionary</span> <span class="o">=</span> <span class="n">Dictionary</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">train</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokenize</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">base_dir</span><span class="p">,</span> <span class="s1">&#39;train.txt&#39;</span><span class="p">),</span> <span class="n">max_lines</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">test</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokenize</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">base_dir</span><span class="p">,</span> <span class="s1">&#39;test.txt&#39;</span><span class="p">),</span> <span class="n">max_lines</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">tokenize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">,</span> <span class="n">max_lines</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">ids</span> <span class="o">=</span> <span class="p">[]</span>
</span></span><span class="line"><span class="cl">            <span class="n">line_idx</span> <span class="o">=</span> <span class="mi">0</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="n">f</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">                <span class="k">if</span> <span class="n">max_lines</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">line_idx</span> <span class="o">&gt;=</span> <span class="n">max_lines</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">                    <span class="k">break</span>
</span></span><span class="line"><span class="cl">                <span class="n">words</span> <span class="o">=</span> <span class="n">line</span><span class="o">.</span><span class="n">split</span><span class="p">()</span> <span class="o">+</span> <span class="p">[</span><span class="s1">&#39;&lt;eos&gt;&#39;</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                <span class="k">for</span> <span class="n">word</span> <span class="ow">in</span> <span class="n">words</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">                    <span class="n">ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dictionary</span><span class="o">.</span><span class="n">add_word</span><span class="p">(</span><span class="n">word</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">                <span class="n">line_idx</span> <span class="o">+=</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">ids</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">batchify</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="n">data_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">nbatch</span> <span class="o">=</span> <span class="n">data_len</span> <span class="o">//</span> <span class="n">batch_size</span>
</span></span><span class="line"><span class="cl">    <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="p">[:</span><span class="n">nbatch</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">data</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">T</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">get_batch</span><span class="p">(</span><span class="n">batches</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">bptt</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="n">data</span> <span class="o">=</span> <span class="n">batches</span><span class="p">[</span><span class="n">i</span><span class="p">:</span> <span class="nb">min</span><span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="n">bptt</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">batches</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="p">:]</span>  <span class="c1"># leave one row for the shifted target</span>
</span></span><span class="line"><span class="cl">    <span class="n">target</span> <span class="o">=</span> <span class="n">batches</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">+</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="p">:]</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">flatten</span><span class="p">(),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="part-7-training-a-word-level-language-model">Part 7: Training a word-level language model</h2>
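<p>There is a major pitfall here: the matrix multiplication implemented by <code>ndarray</code> does not support batched matmul. To multiply a 3-D tensor by a 2-D matrix, you have to reshape it to 2-D by hand, multiply, and then reshape the result back.</p>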
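<p>The reshape, multiply, reshape workaround can be sketched in plain NumPy as follows. This is only an illustration: <code>matmul_3d_2d</code> is a hypothetical helper name, and NumPy itself does support batched matmul, so the sketch merely mirrors what has to be done by hand on the course's <code>ndarray</code> backend:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

def matmul_3d_2d(x, w):
    # x: (seq_len, bs, in_dim), w: (in_dim, out_dim)
    seq_len, bs, in_dim = x.shape
    flat = x.reshape(seq_len * bs, in_dim)        # fold the two leading axes into one
    out = flat @ w                                # ordinary 2-D matmul
    return out.reshape(seq_len, bs, w.shape[1])   # restore the leading axes

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
w = np.ones((4, 5), dtype=np.float32)
y = matmul_3d_2d(x, w)
print(y.shape)  # (2, 3, 5)
</code></pre></div>
<p>Folding the leading axes is safe because every row of the flattened matrix is multiplied by the same weight independently of the others.</p>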
<ul>
<li>Embedding<br>
This Module maps each token to a vector through a single linear transformation, and that operation involves a batched matmul.</li>
</ul>
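<p>One way to read "linear transformation" here, assuming the tokens are first turned into one-hot vectors: multiplying the one-hot matrix by the weight picks out one row of the weight per token. A small NumPy sketch with made-up sizes (all names and shapes below are illustrative assumptions, not taken from the assignment):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

num_embeddings, embedding_dim = 5, 3
weight = np.arange(num_embeddings * embedding_dim, dtype=np.float32)
weight = weight.reshape(num_embeddings, embedding_dim)
tokens = np.array([[0, 2], [4, 1]])                    # (seq_len, bs) token ids

# One-hot encode: (seq_len, bs, num_embeddings)
one_hot = np.eye(num_embeddings, dtype=np.float32)[tokens]

# 3-D times 2-D done by hand: flatten, multiply, restore the leading axes
seq_len, bs = tokens.shape
flat = one_hot.reshape(seq_len * bs, num_embeddings)
emb = (flat @ weight).reshape(seq_len, bs, embedding_dim)

print(np.array_equal(emb, weight[tokens]))             # True
</code></pre></div>
<p>The matmul form is equivalent to a row lookup, but it is expressed with operations that already have gradients defined. The author's module follows.</p>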
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Embedding</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_embeddings</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">init</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">num_embeddings</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">one_hot</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">x</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">seq_len</span><span class="p">,</span> <span class="n">bs</span><span class="p">,</span> <span class="n">num_embeddings</span> <span class="o">=</span> <span class="n">one_hot</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">one_hot</span> <span class="o">=</span> <span class="n">one_hot</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">seq_len</span><span class="o">*</span><span class="n">bs</span><span class="p">,</span> <span class="n">num_embeddings</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">one_hot</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">bs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>LanguageModel<br>
This is just stacking building blocks, and it also involves a batched matmul:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">LanguageModel</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embedding_size</span><span class="p">,</span> <span class="n">output_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                 <span class="n">seq_model</span><span class="o">=</span><span class="s1">&#39;rnn&#39;</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">(</span><span class="n">LanguageModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">embedding</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">output_size</span><span class="p">,</span> <span class="n">embedding_size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">seq_model</span> <span class="o">==</span> <span class="s1">&#39;rnn&#39;</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RNN</span><span class="p">(</span><span class="n">embedding_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">elif</span> <span class="n">seq_model</span> <span class="o">==</span> <span class="s1">&#39;lstm&#39;</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LSTM</span><span class="p">(</span><span class="n">embedding_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">output_size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">h</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># (seq_len, bs, embedding_size)</span>
</span></span><span class="line"><span class="cl">        <span class="n">out</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">seq_len</span><span class="p">,</span> <span class="n">bs</span><span class="p">,</span> <span class="n">hidden_size</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">seq_len</span> <span class="o">*</span> <span class="n">bs</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out</span><span class="p">,</span> <span class="n">h</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
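<li>epoch_general_ptb<br>
The flow is very close to the epoch function implemented in hw2. <code>iter_num = n_batch - seq_len</code> because each sequence has length n_batch and the dataset is loaded with a sliding window of length seq_len; in addition, the last word of a sequence cannot be used as input, since no target follows it.</li>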
</ul>
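<p>To make the window count concrete, here is a small NumPy sketch. The helper name <code>window_at</code> and the toy sizes are hypothetical. The target slice mirrors the one visible in <code>get_batch</code> (shifted by one row), while the input slice is an assumption about the part of that function not shown here:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np


def window_at(data, i, seq_len):
    """Return (inputs, flattened targets) for the window starting at row i."""
    x = data[i: i + seq_len, :]               # assumed input slice
    target = data[i + 1: i + 1 + seq_len, :]  # targets are shifted by one row
    return x, target.flatten()


n_batch, batch_size, seq_len = 10, 2, 4
data = np.arange(n_batch * batch_size).reshape((n_batch, batch_size))

# Count the start positions whose target window is still full length.
full_windows = 0
for i in range(n_batch):
    x, target = window_at(data, i, seq_len)
    if target.shape[0] == seq_len * batch_size:
        full_windows += 1

print(full_windows)  # 6, i.e. n_batch - seq_len
</code></pre></div>
<p>Note that each flattened target holds <code>seq_len * batch_size</code> labels, which matters when normalising an error count into an accuracy.</p>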
<p>If you hit a not-implemented exception, copy the missing implementation over from hw2.</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">epoch_general_ptb</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">seq_len</span><span class="o">=</span><span class="mi">40</span><span class="p">,</span> <span class="n">loss_fn</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">SoftmaxLoss</span><span class="p">(),</span> <span class="n">opt</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">clip</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="n">opt</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">total_loss</span> <span class="o">=</span> <span class="mi">0</span>
</span></span><span class="line"><span class="cl">    <span class="n">total_error</span> <span class="o">=</span> <span class="mi">0</span>
</span></span><span class="line"><span class="cl">    <span class="n">n_batch</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">    <span class="n">iter_num</span> <span class="o">=</span> <span class="n">n_batch</span> <span class="o">-</span> <span class="n">seq_len</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">iter_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">iter_num</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">X</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">get_batch</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">iter_idx</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">opt</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">opt</span><span class="o">.</span><span class="n">reset_grad</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">pred</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="n">opt</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">opt</span><span class="o">.</span><span class="n">reset_grad</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">            <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">            <span class="k">if</span> <span class="n">clip</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">                <span class="n">opt</span><span class="o">.</span><span class="n">clip_grad_norm</span><span class="p">(</span><span class="n">clip</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">opt</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">total_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="n">total_error</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">pred</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">!=</span><span class="n">target</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</span></span><span class="line"><span class="cl">    <span class="n">avg_loss</span> <span class="o">=</span> <span class="n">total_loss</span> <span class="o">/</span> <span class="n">iter_num</span>
</span></span><span class="line"><span class="cl">    <span class="n">avg_acc</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">total_error</span> <span class="o">/</span> <span class="p">(</span><span class="n">iter_num</span> <span class="o">*</span> <span class="n">seq_len</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">avg_acc</span><span class="p">,</span> <span class="n">avg_loss</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
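<li>train/evaluate ptb<br>
There is a pitfall here: the loss function is passed to these two functions as a class, but it must be instantiated before being handed to the epoch function above.</li>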
</ul>
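<p>The difference between passing a class and passing an instance can be sketched in plain Python. <code>SquaredLoss</code>, <code>run_epoch</code>, and <code>evaluate</code> below are hypothetical stand-ins, not the needle API:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">class SquaredLoss:
    """Toy loss: calling an instance computes the mean squared difference."""

    def __call__(self, pred, target):
        return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def run_epoch(pred, target, loss_fn):
    # Expects an already constructed loss object that can be called directly.
    return loss_fn(pred, target)


def evaluate(pred, target, loss_fn=SquaredLoss):
    # Receives the class itself, so it is instantiated before being passed on.
    return run_epoch(pred, target, loss_fn())


print(evaluate([1.0, 2.0], [1.0, 4.0]))  # 2.0
</code></pre></div>
<p>Handing the bare class to <code>run_epoch</code> in this sketch would try to construct <code>SquaredLoss(pred, target)</code> and raise a <code>TypeError</code>, which is the same kind of failure the pitfall describes.</p>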
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">train_ptb</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">seq_len</span><span class="o">=</span><span class="mi">40</span><span class="p">,</span> <span class="n">n_epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">ndl</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">          <span class="n">lr</span><span class="o">=</span><span class="mf">4.0</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">loss_fn</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">SoftmaxLoss</span><span class="p">,</span> <span class="n">clip</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">          <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_epochs</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">avg_acc</span><span class="p">,</span> <span class="n">avg_loss</span> <span class="o">=</span> <span class="n">epoch_general_ptb</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">(),</span> <span class="n">optimizer</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span><span class="p">),</span> <span class="n">clip</span><span class="o">=</span><span class="n">clip</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">avg_acc</span><span class="p">,</span> <span class="n">avg_loss</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">evaluate_ptb</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">seq_len</span><span class="o">=</span><span class="mi">40</span><span class="p">,</span> <span class="n">loss_fn</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">SoftmaxLoss</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">    <span class="n">avg_acc</span><span class="p">,</span> <span class="n">avg_loss</span> <span class="o">=</span> <span class="n">epoch_general_ptb</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">(),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">avg_acc</span><span class="p">,</span> <span class="n">avg_loss</span>
</span></span><span class="line"><span class="cl">    <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
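</div><h2 id="hw4-小结">hw4 summary</h2>
<p>The hardest part of this homework was deriving the backward pass of convolution, which was a real headache at the time. The rest was mostly stacking building blocks and patching earlier implementations, which was also rather tedious.</p>
<p>Finally done 🎉</p>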
<h1 id="hw4_extra">hw4_extra</h1>
<p>Fine, there is one more assignment. Let's keep going!</p>
<h2 id="part-1-implementing-the-multi-head-attention-activation-layer">Part 1: Implementing the Multi-Head Attention Activation Layer</h2>
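<p>This part implements the forward pass of a multi-head self-attention layer. The class provides a number of helper functions, so read through them first.</p>
<p>Two points are not mentioned in the documentation:</p>
<ul>
<li><code>self.causal</code> determines whether masking is applied</li>
<li><code>self.matmul</code> computes <code>A@B.T</code> rather than <code>A@B</code></li>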
</ul>
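<p>Both points can be illustrated with a NumPy sketch. The function <code>causal_scores</code>, the use of <code>-inf</code> as the mask value, the standard <code>sqrt(dim)</code> scaling, and the shapes are assumptions for illustration; the course class has its own helpers for the mask and the softmax:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np


def causal_scores(q, k, causal=True):
    """Scaled attention scores for q, k of shape (batch, heads, length, dim)."""
    dim = q.shape[-1]
    # Q @ K.T on the last two axes, i.e. the A @ B.T form noted above.
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(dim)
    if causal:
        q_len, k_len = scores.shape[-2], scores.shape[-1]
        # The strict upper triangle marks future positions that must be hidden.
        future = np.triu(np.ones((q_len, k_len), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    return scores


rng = np.random.default_rng(0)
q = rng.standard_normal((1, 2, 3, 4))
k = rng.standard_normal((1, 2, 3, 4))
s = causal_scores(q, k)
print(s.shape)  # (1, 2, 3, 3)
</code></pre></div>
<p>After a softmax over the last axis, the masked entries become zero weights, so each query position only attends to itself and earlier positions.</p>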
<p>The <code>dropout</code> operator implemented earlier has a small problem: it did not specify <code>dtype</code> and <code>device</code>, so it needs to be modified:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Dropout</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">p</span> <span class="o">=</span> <span class="n">p</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="k">return</span> <span class="n">x</span>
</span></span><span class="line"><span class="cl">        <span class="n">mask</span> <span class="o">=</span> <span class="n">init</span><span class="o">.</span><span class="n">randb</span><span class="p">(</span><span class="o">*</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">mask</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
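</div><p>Because the incoming Q, K, and V already carry the heads as a separate dimension, multi-head self-attention is much simpler to implement: treat it like the single-head case and apply the formula directly (the call pattern suggests that <code>self.matmul</code> expects its second operand pre-transposed, which is why <code>k</code> is passed as-is and <code>v</code> is transposed first):</p>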
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">batch_size</span><span class="p">,</span> <span class="n">num_head</span><span class="p">,</span> <span class="n">queries_len</span><span class="p">,</span> <span class="n">q_dim</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">keys_values_len</span><span class="p">,</span> <span class="n">k_dim</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">v_dim</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="k">assert</span> <span class="n">q_dim</span> <span class="o">==</span> <span class="n">k_dim</span> <span class="o">==</span> <span class="n">v_dim</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="n">result</span> <span class="o">=</span> <span class="kc">None</span>
</span></span><span class="line"><span class="cl">        <span class="n">probs</span> <span class="o">=</span> <span class="kc">None</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">sqrt_d</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">q_dim</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">Z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">/</span> <span class="n">sqrt_d</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">causal</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">create_causal_mask</span><span class="p">(</span><span class="n">queries_len</span><span class="p">,</span> <span class="n">keys_values_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="n">Z</span> <span class="o">=</span> <span class="n">Z</span> <span class="o">+</span> <span class="n">mask</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">Z</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">probs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">probs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">probs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">transpose</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">result</span><span class="p">,</span> <span class="n">probs</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="part-2-implementing-the-self-attention-layer-with-trainable-parameters">Part 2 Implementing the Self-Attention Layer with trainable parameters</h2>
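<p>This part implements a multi-head self-attention layer: apply preNorm to Q, K, and V, project them, split them into heads, call the forward pass implemented earlier, merge the heads, and apply a final linear projection.</p>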
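<p>To make the split, attend, merge data flow concrete, here is a minimal NumPy sketch. It is not the needle code from this homework: the helper names <code>split_heads</code>, <code>merge_heads</code>, and <code>attention</code> are made up for illustration, and the preNorm and linear projection steps are left out.</p>

```python
import numpy as np


def split_heads(x, num_head, dim_head):
    """(batch, seq, num_head * dim_head) -> (batch, num_head, seq, dim_head)."""
    batch, seq, _ = x.shape
    return x.reshape(batch, seq, num_head, dim_head).transpose(0, 2, 1, 3)


def merge_heads(x):
    """(batch, num_head, seq, dim_head) -> (batch, seq, num_head * dim_head)."""
    batch, num_head, seq, dim_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(batch, seq, num_head * dim_head)


def attention(q, k, v, causal=True):
    """Scaled dot-product attention, applied to every head independently."""
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    if causal:
        # Entries strictly above the diagonal are future positions: mask them.
        future = np.triu(np.ones(scores.shape[-2:], dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return probs @ v, probs


rng = np.random.default_rng(0)
batch, seq, num_head, dim_head = 2, 5, 4, 8
x = rng.standard_normal((batch, seq, num_head * dim_head))

q = k = v = split_heads(x, num_head, dim_head)
out, probs = attention(q, k, v)
merged = merge_heads(out)
print(merged.shape)  # (2, 5, 32)
```

<p>Note that the sequence dimension of the attention output is the query length, so the merge step must reshape with the query length rather than the key/value length.</p>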
<p>First, modify <code>class MatMul</code> so that it supports batched matmul when A has leading batch dimensions (B stays 2D):</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">MatMul</span><span class="p">(</span><span class="n">TensorOp</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">a_shape</span> <span class="o">=</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">2</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                <span class="n">batch_size</span> <span class="o">*=</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">            <span class="n">a</span> <span class="o">=</span> <span class="n">a</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">a_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]))</span>
</span></span><span class="line"><span class="cl">        <span class="n">out</span> <span class="o">=</span> <span class="n">a</span><span class="nd">@b</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">a_shape</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">2</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="o">*</span><span class="n">a_shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]))</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">out</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span></code></pre></td></tr></table>
</div>
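</div><p>The <code>LayerNorm1d</code> implemented earlier only supports inputs of shape (batch_size, hidden_size), so either reshape manually before calling preNorm or modify the LayerNorm implementation directly.</p>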
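<p>The manual-reshape option can be sketched in plain NumPy as follows. This is an illustration under assumptions, not the needle implementation: <code>layer_norm_2d</code> is a hypothetical stand-in for a LayerNorm that only accepts 2D input, and the learnable scale and shift are omitted. The wrapper uses the same flatten-then-restore trick as the <code>MatMul</code> change.</p>

```python
import numpy as np


def layer_norm_2d(x, eps=1e-5):
    """Stand-in for a LayerNorm that only accepts (batch_size, hidden_size)."""
    assert x.ndim == 2
    mean = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def layer_norm_nd(x, eps=1e-5):
    """Flatten all leading dims, run the 2D version, then restore the shape."""
    hidden_size = x.shape[-1]
    flat = x.reshape(-1, hidden_size)
    return layer_norm_2d(flat, eps).reshape(x.shape)


x = np.random.default_rng(0).standard_normal((2, 5, 16))
y = layer_norm_nd(x)
print(y.shape)  # (2, 5, 16)
```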
<p>The Linear module implemented earlier has a bug: it still tries to access the bias even when there is none. It needs to be fixed:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Linear</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">y</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_features</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_features</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">            <span class="n">y</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">y</span>
</span></span></code></pre></td></tr></table>
</div>
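</div><p>Splitting into heads is just a reshape followed by a permute, an operation that has come up several times in earlier homeworks and should be familiar by now. The overall implementation is simple and takes fewer than ten lines of code:</p>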
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">	<span class="bp">self</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">	<span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">v</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
</span></span><span class="line"><span class="cl"><span class="p">):</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span> <span class="n">k</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">		<span class="n">k</span> <span class="o">=</span> <span class="n">q</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span> <span class="n">v</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">		<span class="n">v</span> <span class="o">=</span> <span class="n">q</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">	<span class="n">batch_size</span><span class="p">,</span> <span class="n">queries_len</span><span class="p">,</span> <span class="n">q_dim</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">	<span class="n">_</span><span class="p">,</span> <span class="n">keys_values_len</span><span class="p">,</span> <span class="n">k_dim</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">	<span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">v_dim</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">	<span class="n">result</span> <span class="o">=</span> <span class="kc">None</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">	<span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">	<span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prenorm_q</span><span class="p">(</span><span class="n">q</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">prenorm_k</span><span class="p">(</span><span class="n">k</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">prenorm_v</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_projection</span><span class="p">(</span><span class="n">q</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">k_projection</span><span class="p">(</span><span class="n">k</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">v_projection</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="n">q</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="n">q</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">queries_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_head</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dim_head</span><span class="p">)),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">	<span class="n">k</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="n">k</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">keys_values_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_head</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dim_head</span><span class="p">)),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">	<span class="n">v</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">keys_values_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_head</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dim_head</span><span class="p">)),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">	<span class="n">attn_res</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="n">attn_res</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="n">attn_res</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">queries_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_head</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">dim_head</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">	<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_projection</span><span class="p">(</span><span class="n">attn_res</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">result</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="part-3-implementing-a-prenorm-residual-transformer-layer">Part 3 Implementing a prenorm residual Transformer Layer</h2>
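<p>This section completes a residual Transformer layer. There is nothing difficult here; it is purely a matter of assembling building blocks. Before assembling, as usual, one of the blocks needs a patch: the Linear layer modified in the previous part still has a problem, namely that the bias does not support multiple batch dimensions. Change it to the following:</p>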
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">	<span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">	<span class="n">y</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">		<span class="n">broadcast_shape</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">out_features</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">		<span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">broadcast_shape</span><span class="p">)</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">		<span class="n">y</span> <span class="o">+=</span> <span class="n">bias</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">y</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Now the building blocks can be snapped together:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">TransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">q_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">num_head</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">dim_head</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="o">*</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">dropout</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">causal</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">device</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">dtype</span> <span class="o">=</span> <span class="s2">&#34;float32&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="p">):</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span> <span class="o">=</span> <span class="n">Sequential</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="n">AttentionLayer</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">                <span class="n">q_features</span><span class="o">=</span><span class="n">q_features</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                <span class="n">num_head</span><span class="o">=</span><span class="n">num_head</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                <span class="n">dim_head</span><span class="o">=</span><span class="n">dim_head</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                <span class="n">out_features</span><span class="o">=</span><span class="n">q_features</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                <span class="n">dropout</span><span class="o">=</span><span class="n">dropout</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                <span class="n">causal</span><span class="o">=</span><span class="n">causal</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span>
</span></span><span class="line"><span class="cl">            <span class="p">),</span>
</span></span><span class="line"><span class="cl">            <span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">        <span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span> <span class="o">=</span> <span class="n">Sequential</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="n">LayerNorm1d</span><span class="p">(</span><span class="n">q_features</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">            <span class="n">Linear</span><span class="p">(</span><span class="n">q_features</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">            <span class="n">ReLU</span><span class="p">(),</span>
</span></span><span class="line"><span class="cl">            <span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">            <span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">q_features</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">            <span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">        <span class="p">)</span>
</span></span><span class="line"><span class="cl">            
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span>
</span></span><span class="line"><span class="cl">    <span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">x_dim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">x</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">x</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">x</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="part-4-implementing-the-transformer-model">Part 4 Implementing the Transformer model</h2>
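<p>This part builds a complete Transformer network. The handout says to embed each token's position index within the sequence, so the constructor has to create an extra embedding layer, and its output is added to the input before the data enters the Transformer layers. Everything else is just stacking building blocks:</p>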
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
</td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Transformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">embedding_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> 
</span></span><span class="line"><span class="cl">        <span class="o">*</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">num_head</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">dim_head</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">dropout</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">causal</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">device</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">dtype</span> <span class="o">=</span> <span class="s2">&#34;float32&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">batch_first</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">sequence_len</span> <span class="o">=</span> <span class="mi">2048</span>
</span></span><span class="line"><span class="cl">    <span class="p">):</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">batch_first</span> <span class="o">=</span> <span class="n">batch_first</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">embedding</span> <span class="o">=</span> <span class="n">Embedding</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="n">num_embeddings</span><span class="o">=</span><span class="n">sequence_len</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">embedding_dim</span><span class="o">=</span><span class="n">embedding_size</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span>
</span></span><span class="line"><span class="cl">        <span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">layers</span> <span class="o">=</span> <span class="p">[</span><span class="n">TransformerLayer</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="n">q_features</span><span class="o">=</span><span class="n">embedding_size</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">num_head</span><span class="o">=</span><span class="n">num_head</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">dim_head</span><span class="o">=</span><span class="n">dim_head</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">hidden_size</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">dropout</span><span class="o">=</span><span class="n">dropout</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">causal</span><span class="o">=</span><span class="n">causal</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span>
</span></span><span class="line"><span class="cl">        <span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span><span class="p">,</span> <span class="n">h</span><span class="o">=</span><span class="kc">None</span>
</span></span><span class="line"><span class="cl">    <span class="p">):</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_first</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="c1">### BEGIN YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">        <span class="n">bs</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">input_dim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">        <span class="n">time</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">seq_len</span><span class="p">),</span> <span class="n">bs</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">bs</span><span class="p">))</span><span class="o">.</span><span class="n">T</span>
</span></span><span class="line"><span class="cl">        <span class="n">time</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">time</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">time</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding</span><span class="p">(</span><span class="n">time</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">time</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="c1">### END YOUR SOLUTION</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_first</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">init</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
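</div><p>Batched matmul in <code>ops.matmul</code> has too many pitfalls. I had only patched the forward pass earlier, and the backward pass still does not support batched matmul, so in the end I could not train the Transformer on a dataset, which is a bit of a pity.</p>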
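<p>For reference, the gradient rule that a batched matmul backward pass needs can be sketched in plain NumPy. This is not the needle implementation from the homework: the function name <code>batched_matmul_backward</code> is hypothetical, both operands are assumed to be at least 2-D, and broadcasting is only handled for the case where one operand has fewer leading batch axes than the other.</p>

```python
import numpy as np

def batched_matmul_backward(a, b, out_grad):
    """Gradients of out = a @ b with respect to a and b, given d(loss)/d(out)."""
    # Transpose only the last two axes; leading axes are batch dimensions.
    grad_a = out_grad @ np.swapaxes(b, -1, -2)
    grad_b = np.swapaxes(a, -1, -2) @ out_grad
    # If an operand had fewer leading axes than the output, it was broadcast
    # across the batch, so its gradient is summed over those extra axes.
    extra_a = grad_a.ndim - a.ndim
    if extra_a:
        grad_a = grad_a.sum(axis=tuple(range(extra_a)))
    extra_b = grad_b.ndim - b.ndim
    if extra_b:
        grad_b = grad_b.sum(axis=tuple(range(extra_b)))
    return grad_a, grad_b

a = np.arange(24, dtype=np.float64).reshape(2, 3, 4)   # batch of two 3x4 matrices
b = np.arange(20, dtype=np.float64).reshape(4, 5)      # one 4x5 matrix shared by the batch
out_grad = np.ones((2, 3, 5))                          # gradient of loss = (a @ b).sum()
grad_a, grad_b = batched_matmul_backward(a, b, out_grad)
print(grad_a.shape, grad_b.shape)
```

<p>Each gradient comes back with the same shape as its operand, which is the property an autodiff engine relies on when accumulating gradients.</p>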
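<h2 id="hw4_extra-小结">hw4_extra summary</h2>
<p>hw4_extra is much easier than hw4. After all, we were not asked to derive the Transformer's backward pass by hand, which would have been another bloodbath.</p>
<p>This time it really is finished. 🎉</p>
<h1 id="参考文档">References</h1>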
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p><a href="https://zhuanlan.zhihu.com/p/579465666">zhuanlan.zhihu.com/p/579465666</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p><a href="https://github.com/woaixiaoxiao/CMU10414/tree/main/hw4">CMU10414/hw4 at main · woaixiaoxiao/CMU10414 · GitHub</a>&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>Rendering Math with KaTeX in Hugo</title>
      <link>https://www.zhouxin.space/notes/using-katex-to-render-math-in-hugo/</link>
      <pubDate>Wed, 05 Jun 2024 15:35:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/using-katex-to-render-math-in-hugo/</guid>
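      <description>&lt;h1 id=&#34;前言&#34;&gt;Preface&lt;/h1&gt;
&lt;p&gt;Inserting formulas into blog posts is a fairly common need, yet for some reason Hugo has no native support for rendering them 😞. Two solutions turn up online, KaTeX and MathJax, and the former is said to perform somewhat better. This blog uses KaTeX.&lt;/p&gt;</description>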
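      <content:encoded><![CDATA[<h1 id="前言">Preface</h1>
<p>Inserting formulas into blog posts is a fairly common need, yet for some reason Hugo has no native support for rendering them 😞. Two solutions turn up online, KaTeX and MathJax, and the former is said to perform somewhat better. This blog uses KaTeX.</p>
<p>There is plenty of material online, but most of it only scratches the surface. I ran into quite a few errors while integrating KaTeX with Obsidian, and after an afternoon plus an evening of tinkering it finally works. For a demo, see the post <a href="https://www.zhouxin.space/notes/notes-on-cmu-10-414-deep-learning-system/">Notes on CMU 10-414 Deep Learning Systems | 周鑫的个人博客</a>, which contains a large number of formulas.</p>
<h1 id="技术方案">Technical approach</h1>
<p>The current workflow for posts containing math is:<br>
Edit the post in Obsidian -&gt; regex replacement via the Obsidian github publisher plugin -&gt; upload to GitHub via Obsidian github publisher -&gt; deploy on the server</p>
<h2 id="引入-katex-样式表和-js-文件">Including the KaTeX stylesheet and JS files</h2>
<p>To render formulas in posts, the KaTeX stylesheet and scripts have to be included <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup>. Concretely, create a <code>math.html</code> file under <code>&lt;your_hugo_site&gt;/layouts/partials/</code> with the following content:</p>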
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
</td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-html" data-lang="html"><span class="line"><span class="cl"><span class="p">&lt;</span><span class="nt">link</span>
</span></span><span class="line"><span class="cl">    <span class="na">rel</span><span class="o">=</span><span class="s">&#34;stylesheet&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="na">href</span><span class="o">=</span><span class="s">&#34;https://cdn.jsdelivr.net/npm/katex@0.16.10/dist/katex.min.css&#34;</span> 
</span></span><span class="line"><span class="cl">    <span class="na">integrity</span><span class="o">=</span><span class="s">&#34;sha384-wcIxkf4k558AjM3Yz3BBFQUbk/zgIYC2R0QpeeYb+TwlBVMrlgLqwRjRtGZiK7ww&#34;</span> 
</span></span><span class="line"><span class="cl">    <span class="na">crossorigin</span><span class="o">=</span><span class="s">&#34;anonymous&#34;</span>
</span></span><span class="line"><span class="cl"><span class="p">/&gt;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">&lt;</span><span class="nt">script</span> <span class="na">defer</span> <span class="na">src</span><span class="o">=</span><span class="s">&#34;https://cdn.jsdelivr.net/npm/katex@0.16.10/dist/katex.min.js&#34;</span> <span class="na">integrity</span><span class="o">=</span><span class="s">&#34;sha384-hIoBPJpTUs74ddyc4bFZSM1TVlQDA60VBbJS0oA934VSz82sBx1X7kSx2ATBDIyd&#34;</span> <span class="na">crossorigin</span><span class="o">=</span><span class="s">&#34;anonymous&#34;</span><span class="p">&gt;&lt;/</span><span class="nt">script</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c">&lt;!-- To automatically render math in text elements, include the auto-render extension: --&gt;</span>
</span></span><span class="line"><span class="cl"><span class="p">&lt;</span><span class="nt">script</span> <span class="na">defer</span> <span class="na">src</span><span class="o">=</span><span class="s">&#34;https://cdn.jsdelivr.net/npm/katex@0.16.10/dist/contrib/auto-render.min.js&#34;</span> <span class="na">integrity</span><span class="o">=</span><span class="s">&#34;sha384-43gviWU0YVjaDtb/GhzOouOXtZMP/7XUzwPTstBeZFe/+rCMvRwr4yROQP43s0Xk&#34;</span> <span class="na">crossorigin</span><span class="o">=</span><span class="s">&#34;anonymous&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="na">onload</span><span class="o">=</span><span class="s">&#34;
</span></span></span><span class="line"><span class="cl"><span class="s">    window.addEventListener(&#39;DOMContentLoaded&#39;, function() {
</span></span></span><span class="line"><span class="cl"><span class="s">        renderMathInElement(document.body, {
</span></span></span><span class="line"><span class="cl"><span class="s">            delimiters: [
</span></span></span><span class="line"><span class="cl"><span class="s">                {left: &#39;$$&#39;, right: &#39;$$&#39;, display: true},
</span></span></span><span class="line"><span class="cl"><span class="s">                {left: &#39;$&#39;, right: &#39;$&#39;, display: false},
</span></span></span><span class="line"><span class="cl"><span class="s">                {left: &#39;\\(&#39;, right: &#39;\\)&#39;, display: false},
</span></span></span><span class="line"><span class="cl"><span class="s">                {left: &#39;\\[&#39;, right: &#39;\\]&#39;, display: true}
</span></span></span><span class="line"><span class="cl"><span class="s">            ]
</span></span></span><span class="line"><span class="cl"><span class="s">        });
</span></span></span><span class="line"><span class="cl"><span class="s">    });
</span></span></span><span class="line"><span class="cl"><span class="s">&#34;</span><span class="p">&gt;&lt;/</span><span class="nt">script</span><span class="p">&gt;</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Then append the following to <code>&lt;your_hugo_site&gt;/layouts/partials/extend_head.html</code>:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
</td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-html" data-lang="html"><span class="line"><span class="cl">{{ if or .Params.math .Site.Params.math }}
</span></span><span class="line"><span class="cl">{{ partial &#34;math.html&#34; . }}
</span></span><span class="line"><span class="cl">{{ end }}
</span></span></code></pre></td></tr></table>
</div>
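</div><p>This snippet means: if the current page's <code>math</code> parameter or the site-wide <code>math</code> parameter is true, the <code>math.html</code> partial written above is included in the head of that page.</p>
<p>We can set the <code>math</code> field to true only in the metadata of posts that actually need math rendering, so that the KaTeX files are loaded only there and unnecessary overhead is avoided.</p>
<p>In theory, posts containing math should now render correctly, and many tutorials stop here. But I ran into formulas that were not rendered correctly, as shown below:<br>
<img alt="Formulas in the red boxes fail to render" loading="lazy" src="https://pics.zhouxin.space/202406051638640.png?x-oss-process=image/quality,q_90/format,webp"></p>
<h2 id="与-obsidian-整合">Integrating with Obsidian</h2>
<p>The rendering failures in the image above have two causes:</p>
<ul>
<li>Markdown syntax conflicts with KaTeX syntax, including but not limited to character escaping and the conflicting meaning of underscores</li>
<li>Extra whitespace (spaces or blank lines) between the formula delimiters and the formula content</li>
</ul>
<p>The first problem can be solved by wrapping formulas in a <code>div</code> block, since Hugo does not re-escape the content inside a <code>div</code>. The second can be solved with a regex replacement.</p>
<p>In fact, the first problem is also something regex handles well: a single regex replacement can both wrap every formula block in the md file in a <code>div</code> and strip the extra whitespace. Concretely, the following md document:</p>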
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
</td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-md" data-lang="md"><span class="line"><span class="cl">Gradient descent keeps iterating against the gradient direction to find the $\theta$ that minimizes the objective function.
</span></span><span class="line"><span class="cl">$$
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">\theta :=\theta _0-\alpha \nabla f\left( \theta _0 \right)
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">$$
</span></span><span class="line"><span class="cl">Here $\alpha$ is called the learning rate or step size.
</span></span></code></pre></td></tr></table>
</div>
</div><p>needs to be replaced with:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
</td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-md" data-lang="md"><span class="line"><span class="cl">Gradient descent keeps iterating against the gradient direction to find the $\theta$ that minimizes the objective function.
</span></span><span class="line"><span class="cl">&lt;div&gt;$$
</span></span><span class="line"><span class="cl">\theta :=\theta _0-\alpha \nabla f\left( \theta _0 \right)
</span></span><span class="line"><span class="cl">$$&lt;/div&gt;
</span></span><span class="line"><span class="cl">Here $\alpha$ is called the learning rate or step size.
</span></span></code></pre></td></tr></table>
</div>
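</div><p>The corresponding pattern is <code>/\$\$(\s*)([\s\S]*?)(\s*)\$\$/gs</code> and the replacement string is <code>&lt;div&gt;$$$$\n$2\n$$$$&lt;/div&gt;</code> (in a JavaScript replacement string, <code>$$</code> stands for one literal <code>$</code>). Run the replacement with the github publisher plugin.</p>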
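<p>To sanity-check the pattern outside the plugin, the same replacement can be sketched with Python's <code>re</code> module. The helper name <code>wrap_display_math</code> and the sample string are made up for illustration. Unlike the JavaScript replacement string, Python treats <code>$</code> as a literal character and refers to group 2 as <code>\2</code>.</p>

```python
import re

# Same pattern as in the post. re.DOTALL mirrors the /s flag, and sub()
# replaces every match, which mirrors the /g flag.
PATTERN = re.compile(r"\$\$(\s*)([\s\S]*?)(\s*)\$\$", re.DOTALL)

def wrap_display_math(markdown):
    # Keep only group 2 (the formula without surrounding whitespace)
    # and wrap the block in a div.
    return PATTERN.sub(r"<div>$$\n\2\n$$</div>", markdown)

source = "text\n$$\n\n\\theta := \\theta_0\n\n$$\nmore"
print(wrap_display_math(source))
```

<p>The blank lines around the formula disappear because groups 1 and 3 capture the surrounding whitespace and are dropped from the replacement.</p>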
<h1 id="one-more-thing">One More Thing</h1>
<p>Two recommended sites for debugging KaTeX and regular expressions, respectively:</p>
<ul>
<li><a href="https://katex.org/">KaTeX – The fastest math typesetting library for the web</a></li>
<li><a href="https://regex101.com/">regex101: build, test, and debug regex</a></li>
</ul>
<h1 id="参考文档">References</h1>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p><a href="https://katex.org/docs/browser">Browser · KaTeX</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>Notes on CMU 10-414 Deep Learning Systems</title>
      <link>https://www.zhouxin.space/notes/notes-on-cmu-10-414-deep-learning-system/</link>
      <pubDate>Tue, 28 May 2024 12:24:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-cmu-10-414-deep-learning-system/</guid>
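      <description>&lt;h1 id=&#34;写在最前面&#34;&gt;Foreword&lt;/h1&gt;
&lt;p&gt;From 2024-04-28 to 2024-09-08, after a little over four months, I finally finished DLSys. Some takeaways from the course:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;The theory of automatic differentiation, plus related knowledge such as computational graphs picked up while implementing it&lt;/li&gt;
&lt;li&gt;A systematic study of several basic ML models and components&lt;/li&gt;
&lt;li&gt;Tensor strides and related topics&lt;/li&gt;
&lt;li&gt;Basic CUDA programming&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Where the course fell short of my expectations:&lt;/p&gt;</description>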
      <content:encoded><![CDATA[<h1 id="写在最前面">写在最前面</h1>
<p>从 2024-04-28 到 2024-09-08，历时四个多月，总算把 DLSys 学完了。这门课的一些收获：</p>
<ul>
<li>自动微分理论知识和在实践过程中衍生的包括计算图等知识</li>
<li>系统学习了 ML 中几个基本模型和组件</li>
<li>Tensor 的 strides 相关内容</li>
<li>基础 CUDA 编程</li>
</ul>
<p>个人认为这门课一些没达到我预期的地方：</p>
<ul>
<li>CUDA 编程的内容太浅</li>
<li>后续讲 CNN、RNN、Transformer 的部分没必要，可以继续深入 CUDA 或者压缩课时</li>
</ul>
<p>The core content of this course is Lectures 0 to 15, corresponding to homework hw0 to hw3; the later material can be skipped if you are short on time.</p>
<p>ps：全文章两万余字，Chrome 渲染图片时可能会很卡，建议使用 Microsoft Edge 浏览。</p>
<h1 id="lecture-1-introduction-and-logistics">Lecture 1: Introduction and Logistics</h1>
<h2 id="课程的目标">课程的目标</h2>
<p>本课程的目标是学习现代深度学习系统，了解包括自动微分、神经网络架构、优化以及 GPU 上的高效操作在内的技术的底层原理。作为实践，本课程将实现一个 needle（deep learning library）库，类似 PyTorch。</p>
<h2 id="为什么学习深度学习系统">为什么学习深度学习系统？</h2>
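<p>Why study this? The concept of deep learning has existed for a long time, but the field only took off once modern frameworks such as PyTorch and TensorFlow were released. Easy-to-use automatic differentiation libraries have been the biggest driver of its progress.</p>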
<p>除了使用这些库，我们为什么还要学习深度学习系统？</p>
<ul>
<li>
<p>为了构建深度学习系统<br>
如果想要从事深度学习系统的开发，那毫无疑问得先学习它。目前深度学习框架并没完全成熟，还有很多开发新功能，乃至新的框架的机会。</p>
</li>
<li>
<p>为了能够更高效地使用现有系统<br>
了解现有系统的内部实现，可以帮助我们写出更加高效的深度学习代码。如果想要提高自定义算子的效率，那必须先了解相关操作是如何实现的。</p>
</li>
<li>
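<p>Deep learning systems are interesting in their own right<br>
Although such a system looks complex, the principles behind its core algorithms are quite simple. About two thousand lines of code are enough to write a deep learning library.</p>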
</li>
</ul>
<h2 id="预备知识">预备知识</h2>
<ul>
<li>systems programming</li>
<li>线性代数</li>
<li>Other mathematics: calculus, probability, and simple proofs</li>
<li>Python 和 C++ 经验</li>
<li>机器学习的相关经验</li>
</ul>
<h1 id="lecture-2-ml-refresher--softmax-regression">Lecture 2: ML Refresher &amp; Softmax Regression</h1>
<h2 id="机器学习基础">机器学习基础</h2>
<p>深度学习是由数据驱动的，所谓数据驱动，这意味着当我们想要写一个用于识别手写数字的模型时，我们关注的不是某个数字形状上有什么特点，如何通过编程识别该特点，而是直接将数据集喂给模型，模型自动训练并识别数字类别。</p>
<p>深度学习模型由三部分组成：</p>
<ul>
<li>假说模型：模型的结构，包括一系列参数，其描述了模型从输入到输出的映射关系；</li>
<li>损失函数：指定了对模型的评价，损失函数值越小，说明该模型在指定任务上完成得更好；</li>
<li>优化方法：用于对模型中参数进行优化，使得损失函数最小的方法。</li>
</ul>
<h2 id="softmax-回归">Softmax 回归</h2>
<p>以经典的 softmax 回归模型为例，简单回顾一下 ML 模型。</p>
<p>Consider a $k$-class classification task with dataset $x^{(i)} \in R^n,\ y^{(i)} \in \{ 1,\ldots,k\},\ i = 1,\ldots,m$, where $n$ is the dimension of each input, $k$ is the number of label classes, and $m$ is the number of samples.</p>
<p>一个假说模型就是将一个 $n$ 维的输入映射到一个 $k$ 维的输出，即：$h: R^n \rightarrow R^k$。注意，模型并不会直接输出类别的序号，而是通过输出一个 $k$ 维向量 $h(x)$，其中第 $i$ 个元素 $h_i(x)$ 表示是第 $i$ 个类别的概率。</p>
<p>对于线性模型来说，使用 $\theta \in R^{n\times k}$ 这个模型中的参数，那么 $h_\theta(x) = \theta^T x$。</p>
<p>如果一次输入多个数据，那么输入数据就可以组织成一个矩阵，相比起多个向量操作，矩阵的操作通常效率更高，我们在代码实现中一般也是用矩阵操作。数据集可以表示为：</p>


<div>$$

X\in R^{m\times n}=\left[ \begin{array}{c}
	x^{(1)T}\\
	\vdots\\
	x^{\left( m \right) T}\\
\end{array} \right] ,  y\in \left\{ 1,...,k \right\} ^m=\left[ \begin{array}{c}
	y^{\left( 1 \right)}\\
	\vdots\\
	y^{\left( m \right)}\\
\end{array} \right]

$$</div>

<p>数据集的矩阵是一个个样本转置后堆叠 stack 起来的。那么输出可以表示为：</p>


<div>$$

h_{\theta}\left( X \right) =\left[ \begin{array}{c}
	h_{\theta}\left( x^{\left( 1 \right)} \right) ^T\\
	\vdots\\
	h_{\theta}\left( x^{\left( m \right)} \right) ^T\\
\end{array} \right] =\left[ \begin{array}{c}
	x^{\left( 1 \right) T}\theta\\
	\vdots\\
	x^{\left( m \right) T}\theta\\
\end{array} \right] =X\theta

$$</div>

<p>For the loss function, a naive choice is the 0/1 classification error $l_{err}$: the loss is 1 if the class with the highest predicted score differs from the true class, and 0 otherwise:</p>


<div>$$

l_{err}\left( h\left( x \right) , y \right) \,\,=\,\,\left\{ \begin{align*}
	0 \ &amp;\mathrm{if} \ \mathrm{argmax} _i\,\,h_i\left( x \right) =y\\
	1 \ &amp;\mathrm{otherwise}\\
\end{align*} \right.

$$</div>

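<p>Unfortunately, this intuitive function is not differentiable, which makes it hard to optimize the parameters. A better choice is the cross-entropy loss.</p>
<p>Before that, we first pass the output through a softmax function so that it behaves like a probability distribution, i.e. the class probabilities are positive and sum to 1:</p>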


<div>$$

z_i=p\left( \mathrm{label}=i \right) =\frac{\exp \left( h_i\left( x \right) \right)}{\sum_{j=1}^k{\exp \left( h_j\left( x \right) \right)}}

$$</div>

<p>那么交叉熵损失函数就可以定义为：</p>


<div>$$

l_{ce}\left( h\left( x \right) ,y \right) =-\log p\left( \mathrm{label}=y \right) =-h_y\left( x \right) &#43;\log \sum_{j=1}^k{\exp \left( h_j\left( x \right) \right)}

$$</div>

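<p>Note that the cross-entropy expression has been simplified algebraically, so we can skip computing the softmax explicitly and evaluate the final result directly. In addition, if the softmax probability of the true class is very small, taking its logarithm directly produces a value of very large magnitude and is numerically fragile; the simplified form avoids this.</p>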
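<p>The simplified form $-h_y + \log\sum_j \exp(h_j)$ can be sketched in a few lines of plain Python. This is only an illustration: the function name <code>cross_entropy_from_logits</code> is made up for this note, labels are 0-based here (the text uses $1,\ldots,k$), and subtracting the maximum logit is the usual log-sum-exp trick rather than something the lecture prescribes.</p>

```python
import math

def cross_entropy_from_logits(h, y):
    """Return -h[y] + log(sum_j exp(h[j])) for logits h and a 0-based label y."""
    # Subtracting the largest logit leaves the result unchanged but keeps
    # every exp() argument non-positive, so exp() cannot overflow.
    h_max = max(h)
    log_sum_exp = h_max + math.log(sum(math.exp(v - h_max) for v in h))
    return -h[y] + log_sum_exp

# Two equally likely classes: the loss is log(2).
loss_uniform = cross_entropy_from_logits([0.0, 0.0], 0)
# A confident but wrong prediction: computing softmax first would overflow in exp(1000).
loss_large = cross_entropy_from_logits([1000.0, 0.0], 1)
print(loss_uniform, loss_large)
```

<p>The softmax probabilities themselves are never formed, which is exactly the saving the simplification gives.</p>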
<p>Every deep learning problem can be reduced to the following optimization problem:

<div>$$

\mathop {\mathrm{minimize}} \limits_{\theta}\ \ \frac{1}{m}\sum_{i=1}^m{l(h_{\theta}(x^{(i)}),y^{(i)})}

$$</div>

我们使用梯度下降法对该问题进行优化。在此之前，首先介绍一下关于梯度。我们的优化目标可以看作一个关于$\theta \in R^{n\times k}$的函数$f$，那么其在$\theta_0$处的梯度可以表示为：


<div>$$

\nabla _{\theta}f\left( \theta _0 \right) \in R^{n\times k}=\left[ \begin{matrix}  
	\frac{\partial f\left( \theta _0 \right)}{\partial \theta _{11}}&amp;		\cdots&amp;		\frac{\partial f\left( \theta _0 \right)}{\partial \theta _{1k}}\\  
	\vdots&amp;		\ddots&amp;		\vdots\\  
	\frac{\partial f\left( \theta _0 \right)}{\partial \theta _{n1}}&amp;		\cdots&amp;		\frac{\partial f\left( \theta _0 \right)}{\partial \theta _{nk}}\\  
\end{matrix} \right]

$$</div>

其中，第$i$行第$j$个元素表示除$\theta_{ij}$之外的参数都被当作常数，对$\theta_{ij}$求偏导。</p>
<p>Gradient descent iterates along the negative gradient direction in order to find the $\theta$ that minimizes the objective.


<div>$$

\theta :=\theta _0-\alpha \nabla f\left( \theta _0 \right)

$$</div>

上式中，$\alpha$被称为学习率或者步长。</p>
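<p>In practice, modern deep learning does not use plain gradient descent, because the whole training set usually cannot be loaded and differentiated at once. Instead it uses stochastic gradient descent (SGD).</p>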
<p>首先将m个训练集样本划分一个个小batch，每个batch都有B条数据。那每一batch的数据表示为$X\in R^{B\times n}$，更新参数$\theta$的公式变为：


<div>$$

\theta :=\theta _0-\frac{\alpha}{B}\nabla f\left( \theta _0 \right)

$$</div>

我们的梯度变成了每个小batch对全体样本梯度的估计。</p>
<p>那如何计算梯度表达式呢？梯度矩阵中每个元素都是一个偏导数，我们就先从计算偏导数开始。假设$h$是个向量，我们来计算偏导数$\frac{\partial l_{ce}\left( h,y \right)}{\partial h_i}$：


<div>$$

\begin{align*}  
\frac{\partial l_{ce}\left( h,y \right)}{\partial h_i}&amp;=\frac{\partial}{\partial h_i}\left( -h_y&#43;\log \sum_{j=1}^k{\exp h_j} \right)  
\\  
&amp;=-1\left\{ i=y \right\} &#43;\frac{\exp \left( h_i \right)}{\sum_{j=1}^k{\exp h_j}}  
\\  
&amp;=-1\left\{ i=y \right\} &#43;\mathrm{softmax} \left( h \right)_i  
\\  
&amp;=z_i-\left( e_y \right)_i  
\end{align*}

$$</div>
</p>
<p>Collecting these partial derivatives over all $i$, the gradient $\nabla_h l_{ce}(h,y)$ can be written in vector form as:


<div>$$

\nabla_h l_{ce}(h,y) = z-e_y

$$</div>

Here $z$ denotes the softmax normalization of $h$, and $e_y$ is the unit (one-hot) vector whose $y$-th entry is 1.</p>
<p>事实上，我们要计算的梯度是关于$\theta$的，具体来说，表达式为$\nabla_\theta l_{ce}(\theta^Tx,y)$，其中，$\theta$是个矩阵。或许，可以使用链式法则进行求解，但是太麻烦了，这里还涉及矩阵对向量的求导。我们需要一种更加通用的求导方案。</p>
<p>有两个解决办法：</p>
<ul>
<li>正确且官方的做法：使用矩阵微分学、雅可比矩阵、克罗内克积和向量化等知识进行求解。</li>
<li>一个hacky、登不上台面、但大家都在用的方案：将所有的矩阵和向量当作标量，使用链式法则求解，并进行转置操作使得结果的size符合预期，最后检查数值上结果是否正确。</li>
</ul>
<p>按照第二个方法的逻辑，过程为：


<div>$$

\begin{align*}  
\frac{\partial}{\partial \theta}l_{ce}\left( \theta ^Tx,y \right) &amp;=\frac{\partial l_{ce}\left( \theta ^Tx,y \right)}{\partial \theta ^Tx}\cdot \frac{\partial \theta ^Tx}{\partial \theta}  
\\  
&amp;=\left[ z-e_y \right] _{k\times 1}\cdot x_{n\times 1}  
\\  
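&amp;=x\cdot \left[ z-e_y \right]^T  
\end{align*}

$$</div>

where $z=\text{softmax}(\theta^Tx)$. Note that the second-to-last step multiplies two column vectors, which is not a valid product. Since the result must be an $n\times k$ matrix, we reorder the factors and transpose the second one, giving the outer product $x(z-e_y)^T$.</p>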
<p>By the same recipe we can write the batched case, with $X\in R^{B\times n}$, logits $X\theta$, $Z=\text{softmax}(X\theta)$ applied row by row, and $E_y$ the matrix of stacked one-hot rows:


<div>$$

\begin{align*}  
\frac{\partial}{\partial \theta}l_{ce}\left( X\theta,y \right) &amp;=\frac{\partial l_{ce}\left( X\theta,y \right)}{\partial X\theta}\cdot \frac{\partial X\theta}{\partial \theta}  
\\  
&amp;=\left[ Z-E_y \right] _{B\times k}\cdot X_{B\times n}  
\\  
&amp;=X^T\cdot \left[ Z-E_y \right]  
\end{align*}

$$</div>
</p>
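<p>For a single sample, the gradient with respect to $\theta$ is the $n\times k$ outer product of $x$ with $z-e_y$. The sketch below computes it with plain Python lists; the helper names <code>softmax</code> and <code>softmax_regression_grad</code> are hypothetical and labels are 0-based.</p>

```python
import math

def softmax(h):
    """Numerically stable softmax of a list of logits."""
    h_max = max(h)
    exps = [math.exp(v - h_max) for v in h]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_regression_grad(theta, x, y):
    """Gradient of l_ce(theta^T x, y) w.r.t. theta, an n x k list of lists."""
    n, k = len(theta), len(theta[0])
    # h = theta^T x, a vector of k logits
    h = [sum(theta[i][j] * x[i] for i in range(n)) for j in range(k)]
    z = softmax(h)
    # z - e_y, where e_y is the one-hot vector of the true class
    diff = [z[j] - (1.0 if j == y else 0.0) for j in range(k)]
    # Outer product of x (length n) and z - e_y (length k): an n x k matrix
    return [[x[i] * diff[j] for j in range(k)] for i in range(n)]

theta = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]  # n = 2 features, k = 3 classes
grad = softmax_regression_grad(theta, x=[1.0, 2.0], y=0)
print(grad)
```

<p>With all-zero parameters the softmax is uniform, so every row of the gradient is a multiple of $(-2/3, 1/3, 1/3)$, and each row sums to zero because the entries of $z-e_y$ do.</p>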
<h1 id="lecture-3-manual-neural-networks">Lecture 3: Manual Neural Networks</h1>
<p>这节课，我们将人工实现全连接神经网络，之后的课程，将引入自动微分技术。</p>
<h2 id="从线性模型转变为非线性模型">从线性模型转变为非线性模型</h2>
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406071816708.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>如上图所示，线性模型本质上是将样本空间划分为线性的几个部分，这样的模型性能十分有限，因此很多不满足这样分布的实际问题就不能被解决。</p>
<p>一种解决方案是，在将样本输入到线性分类器前，先人工挑选出某些特征，即对$X$应用一个函数$\phi$，其将$X$映射到$\phi(X)$上，映射后的空间可以被线性划分。一方面，它确实是早期实践中行之有效的方案；另一方面，人工提取特征的泛化性能有限，受限于具体问题和研究人员的对问题的洞察程度。</p>
<p>如果我们使用线性网络提取特征，并直接接上一个线性分类头，这两个线性层等效为一个线性层，并不能做到非线性化的要求（基础知识，此处不再解释）。</p>
<p>Therefore, after extracting features with a linear layer, we need to apply a nonlinear function $\sigma$, i.e. $\phi(x) = \sigma (W^T x)$.</p>
<h2 id="神经网络">神经网络</h2>
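<p>The model above, with a nonlinear function applied, can be seen as the simplest kind of neural network. A neural network refers to a particular class of hypothesis models in machine learning: it consists of multiple layers, each with a large number of differentiable parameters.</p>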
<p>神经网络最初的确起源于模拟人类神经元这一动机，但随着其发展，越来越多的神经网络模型出现，与人类大脑神经网络越来越不相关。</p>
<p>以双层神经网络为例，其形式化表示为$h_\theta(x) = W_2^T \sigma(W_1^T x)$，所有可学习的参数使用$\theta$表示。以batch的矩阵形式表示为：


<div>$$

h_\theta(X) = \sigma(XW_1)W_2

$$</div>

接下来给出L层多层感知机（a.k.a. MLP、前馈神经网络、全连接层）的形式化表达：


<div>$$

\left\{\begin{array}{l}  
Z_{i&#43;1} = \sigma_i(Z_iW_i), i=1,...,L  \\  
Z_1 = X\\  
h_\theta(X) =Z_{L&#43;1}\\  
[Z_i\in R^{m\times n_i}, W_i \in R^{n_i\times n_{i&#43;1}}]\\  
\sigma_i:R\rightarrow R
\end{array} \right.

$$</div>

每一层的输入为$Z_i$，输出为$Z_{i+1}$。</p>
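<p>Why use deep networks rather than wide ones? There is no perfect explanation, but the best and most pragmatic one is empirical: with a fixed parameter budget, deep networks outperform wide ones.</p>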
<h2 id="反向传播梯度计算">反向传播（梯度计算）</h2>
<p>与Lecture 2一致，使用交叉熵作为损失函数，使用SGD作为优化算法，唯一的区别是，这次要对MLP网络求解梯度。</p>
<p>对于两层神经网络$h_\theta(X) = \sigma(XW_1)W_2$，待求的梯度表达式为：


<div>$$

\nabla_{\{W_1, W_2\}}l_{ce}(\sigma(XW_1)W_2,y)

$$</div>

对于$W_2$的梯度，其与Lecture 2的计算类似：


<div>$$

\begin{align*}  
\frac{\partial l_{ce}(\sigma(XW_1)W_2,y)}{\partial W_2}&amp;=\frac{\partial l_{ce}(\sigma(XW_1)W_2,y)}{\partial \sigma(XW_1)W_2} \cdot \frac{\partial\sigma(XW_1)W_2}{\partial W_2}\\  
&amp;=(S-I_y)_{m\times k}\cdot \sigma(XW_1)_{m\times d}\\  
&amp;=\sigma(XW_1)^T\cdot (S-I_y)\\  
&amp;[S=\text{softmax}(\sigma(XW_1)W_2)]  
\end{align*}

$$</div>
</p>
<p>对于$W_1$的梯度，其需要多次应用链式法则，但并不难计算：


<div>$$

\begin{align*}  
\frac{\partial l_{ce}(\sigma(XW_1)W_2,y)}{\partial W_1}&amp;=\frac{\partial l_{ce}(\sigma(XW_1)W_2,y)}{\partial \sigma(XW_1)W_2} \cdot \frac{\partial\sigma(XW_1)W_2}{\partial \sigma(XW_1)}\cdot \frac{\partial \sigma(XW_1)}{\partial XW_1}\cdot\frac{\partial XW_1}{\partial W_1}\\  
&amp;=(S-I_y)_{m\times k}\cdot [W_2]_{d\times k}\cdot \sigma\prime(XW_1)_{m\times d}\cdot X_{m\times n}\\  
&amp;=X^T\cdot [\sigma\prime(XW_1)\odot((S-I_y)\cdot W_2^T)]\\  
&amp;[S=\text{softmax}(\sigma(XW_1)W_2)]  
\end{align*}

$$</div>

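In the formula above, $\odot$ denotes elementwise multiplication. It appears because $\sigma$ is applied elementwise: each entry of $\sigma(XW_1)$ depends only on the matching entry of $XW_1$, so the chain rule through $\sigma$ reduces to an elementwise product with $\sigma\prime(XW_1)$.</p>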
<p>接下来将其推广到一般情况，即$L$层的MLP中对$W_i$求导：


<div>$$

\begin{align*}  
\frac{\partial l(Z_{L&#43;1},y)}{\partial W_i} &amp;=\frac{\partial l}{\partial Z_{L&#43;1}}\cdot \frac{\partial Z_{L&#43;1}}{\partial Z_{L}}\cdot...\cdot \frac{\partial Z_{i&#43;2}}{\partial Z_{i&#43;1}}\cdot\frac{\partial Z_{i&#43;1}}{\partial W_{i}}\\  
&amp;=G_{i&#43;1}\cdot\frac{\partial Z_{i&#43;1}}{\partial W_{i}}=\frac{\partial l}{\partial Z_{i&#43;1}}\cdot \frac{\partial Z_{i&#43;1}}{\partial W_i}\\
\end{align*}

$$</div>
</p>
<p>由上述公式，我们可以得到一个反向迭代计算的$G_i$，即：


<div>$$

\begin{align*}  
G_i &amp;= G_{i&#43;1}\cdot \frac{\partial Z_{i&#43;1}}{\partial Z_i} \\  
&amp;=G_{i&#43;1}\cdot \frac{\partial \sigma(Z_iW_i)}{\partial Z_iW_i}\cdot\frac{\partial Z_iW_i}{\partial Z_i}\\  
&amp;=G_{i&#43;1}\cdot \sigma\prime(Z_iW_i)\cdot W_i\\  
\end{align*}

$$</div>
</p>
<p>上面的计算都是将矩阵当作标量进行的，接下来我们考虑其维度。已知，$Z_i \in R^{m\times n_i}$是第$i$层的输入，$G_i = \frac{\partial l}{\partial Z_{i}}$，其维度如何呢？$G_i$每个元素表示损失函数$l$对第$i$层输入的每一项求偏导，也可以记作是$l$对$Z_i$求梯度，即$\nabla_{Z_i} l$，其维度显然是$m\times n_i$，继续计算前文$G_i$：


<div>$$

\begin{align*}  
G_i &amp;=[G_{i&#43;1}]_{m\times n_{i&#43;1}}\cdot \sigma\prime(Z_iW_i)_{m\times n_{i&#43;1}}\cdot [W_i]_{n_i\times n_{i&#43;1}}\\  
&amp;= [G_{i&#43;1}\odot \sigma\prime(Z_iW_i)]W_i^T  
\end{align*}

$$</div>
</p>
<p>有了$G_i$，就可以继续计算$l$对$W_i$的偏导数了：


<div>$$

\begin{align*}  
\frac{\partial l(Z_{L&#43;1},y)}{\partial W_i} &amp;=G_{i&#43;1}\cdot\frac{\partial Z_{i&#43;1}}{\partial W_{i}} \\  
&amp;=G_{i&#43;1}\cdot \frac{\partial\sigma(Z_iW_i)}{\partial Z_iW_i}\cdot\frac{\partial Z_iW_i}{\partial W_i}\\  
&amp;=[G_{i&#43;1}]_{m\times n_{i&#43;1}}\cdot \sigma\prime(Z_iW_i)_{m\times n_{i&#43;1}}\cdot [Z_i]_{m\times n_i}\\  
&amp;=Z_i^T\cdot[G_{i&#43;1}\odot\sigma\prime(Z_iW_i)]  
\end{align*}

$$</div>
</p>
<p>至此，每个小组件都已制造完毕，让我们来把它装起来吧！</p>
<ul>
<li>前向传播
<ul>
<li>初始化：$Z_1 = X$</li>
<li>迭代：$Z_{i+1} = \sigma(Z_iW_i)$ 直至$i=L$（注意，最后一层没有非线性部分，此处没有展示出来）</li>
</ul>
</li>
<li>反向传播
<ul>
<li>初始化：$G_{L+1} = S-I_y$</li>
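<li>Iterate: $G_i=[G_{i+1}\odot \sigma\prime(Z_iW_i)]W_i^T$ and $\nabla_{W_i} l=Z_i^T[G_{i+1}\odot \sigma\prime(Z_iW_i)]$, down to $i=1$
Note that backpropagation needs the intermediate results $Z_i$ of the forward pass. Computing gradients efficiently therefore costs memory, i.e. it trades space for time.</li>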
</ul>
</li>
</ul>
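<p>The forward and backward recursions above can be sketched in plain Python and checked against a central finite difference. Several choices here are assumptions made for illustration only: ReLU as the hidden activation, an identity last layer, a loss summed (not averaged) over the batch so that $G_{L+1}=S-I_y$ holds exactly, 0-based labels, and all helper names.</p>

```python
import math

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def hadamard(A, B):
    return [[a * b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def softmax_rows(H):
    out = []
    for row in H:
        m = max(row)
        e = [math.exp(v - m) for v in row]
        s = sum(e)
        out.append([v / s for v in e])
    return out

def forward(X, weights):
    """Return the layer inputs Z_i and the pre-activations Z_i W_i."""
    Z, pre = [X], []
    for i, W in enumerate(weights):
        A = matmul(Z[-1], W)
        pre.append(A)
        last = (i == len(weights) - 1)  # the last layer has no nonlinearity
        Z.append(A if last else [[max(0.0, v) for v in row] for row in A])
    return Z, pre

def loss(X, y, weights):
    """Cross-entropy summed over the batch."""
    Z, _ = forward(X, weights)
    S = softmax_rows(Z[-1])
    return -sum(math.log(S[r][y[r]]) for r in range(len(y)))

def backward(X, y, weights):
    """Return one gradient matrix per weight matrix."""
    Z, pre = forward(X, weights)
    S = softmax_rows(Z[-1])
    # G_{L+1} = S - I_y
    G = [[S[r][c] - (1.0 if c == y[r] else 0.0) for c in range(len(S[r]))]
         for r in range(len(y))]
    grads = [None] * len(weights)
    for i in reversed(range(len(weights))):
        last = (i == len(weights) - 1)
        dsigma = [[1.0 if (last or v > 0) else 0.0 for v in row] for row in pre[i]]
        D = hadamard(G, dsigma)                 # G_{i+1} elementwise sigma'(Z_i W_i)
        grads[i] = matmul(transpose(Z[i]), D)   # gradient w.r.t. W_i
        G = matmul(D, transpose(weights[i]))    # G_i, passed to the previous layer
    return grads

def numeric_grad_entry(X, y, weights, layer, r, c, eps=1e-6):
    """Central finite difference for a single weight entry."""
    plus = [[row[:] for row in W] for W in weights]
    minus = [[row[:] for row in W] for W in weights]
    plus[layer][r][c] += eps
    minus[layer][r][c] -= eps
    return (loss(X, y, plus) - loss(X, y, minus)) / (2 * eps)

X = [[1.0, 2.0], [0.5, -1.0]]
y = [0, 1]
weights = [[[0.1, -0.2, 0.3], [0.4, 0.5, -0.6]],
           [[0.2, -0.1], [0.3, 0.4], [-0.5, 0.6]]]
grads = backward(X, y, weights)
print(grads[0][0][0], numeric_grad_entry(X, y, weights, 0, 0, 0))
```

<p>The forward pass keeps every $Z_i$ and pre-activation in memory so the backward pass can reuse them, which is the space-for-time trade mentioned above.</p>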
<blockquote>
<p>许多课程，讲到这里就结束了，但对我们这门课来说，才刚刚开始&hellip;</p>
</blockquote>
<h1 id="lecture-4-automatic-differentiation">Lecture 4: Automatic Differentiation</h1>
<h2 id="基本工具">基本工具</h2>
<ul>
<li>计算图
计算图是自动微分中常用的一种工具。计算图是一张有向无环图，每个节点表示（中间结果）值，每条边表示输入输出变量。例如，$y=f(x_1, x_2) = \ln(x_1)+x_1x_2-\sin x_2$对应的计算图为：
<img loading="lazy" src="https://pics.zhouxin.space/202406071612073.webp?x-oss-process=image/quality,q_90/format,webp">
按照拓扑序列遍历这张图，就可以得到对应表达式的值。</li>
</ul>
<h2 id="对自动微分方法的简单介绍">对自动微分方法的简单介绍</h2>
<p>Computing gradients is at the core of deep learning. Several ways of computing them are introduced here:</p>
<ul>
<li>偏导数定义</li>
</ul>
<p>梯度是由一个个偏导数组成的，可以直接根据偏导数的定义来计算梯度：


<div>$$

\frac{\partial f(\theta)}{\partial \theta_i} = \lim_{\epsilon \to 0}\frac{f(\theta &#43; \epsilon e_i) - f(\theta)}{\epsilon}

$$</div>

其中，$e_i$是表示第$i$个方向上的单位向量。</p>
<ul>
<li>数值求解
根据上述定义，我们可以选取一个很小的量代入$\epsilon$，得到数值计算偏导的方法：


<div>$$

\frac{\partial f(\theta)}{\partial \theta_i} = \frac{f(\theta &#43; \epsilon e_i) - f(\theta - \epsilon e_i)}{2\epsilon} &#43; o(\epsilon^2)

$$</div>

这里并不是直接使用第一项的公式，即分子不是$f(\theta + \epsilon e_i) - f(\theta)$，并且误差项是$\epsilon^2$，这是由于泰勒展开：


<div>$$

\begin{align*}  
f(\theta&#43;\delta) = f(\theta)&#43;f^\prime (\theta)\delta&#43;\frac{1}{2}f^{\prime \prime}(\theta)\delta^2&#43;o(\delta^3)\\  
f(\theta-\delta) = f(\theta)-f^\prime (\theta)\delta&#43;\frac{1}{2}f^{\prime \prime}(\theta)\delta^2&#43;o(\delta^3)  
\end{align*}

$$</div>

上述两式作差，即可得到数值计算$f^\prime(\theta)$的方法。</li>
</ul>
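<p>The drawbacks of this method are that it is only approximate and that it is inefficient: each partial derivative costs two evaluations of $f$, so a full gradient over $n$ parameters costs $2n$ evaluations. It is therefore mostly used to verify implementations of other methods. Concretely, one checks whether the following equality holds: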


<div>$$

\delta^T \nabla_\theta f(\theta) = \frac{f(\theta &#43; \epsilon \delta) - f(\theta - \epsilon \delta)}{2 \epsilon} &#43; o(\epsilon^2)

$$</div>

其中$\delta$是单位球上的某个向量，$\nabla_\theta f(\theta)$是使用其它方法计算得到的梯度。等式左边是其它方法计算的梯度在$\delta$上的投影，右侧是使用数值求解得到的梯度值，验证该等式是否成立就可以判断左侧梯度是否计算错误。</p>
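<p>A minimal sketch of this check, using the example function $f(x_1,x_2)=\ln(x_1)+x_1x_2-\sin x_2$ from the computational-graph section. The hand-written <code>analytic_grad</code> stands in for "the gradient computed by another method"; the function names, the direction $\delta=(0.6, 0.8)$ and the step $\epsilon=10^{-5}$ are arbitrary choices for illustration.</p>

```python
import math

def f(theta):
    x1, x2 = theta
    return math.log(x1) + x1 * x2 - math.sin(x2)

def analytic_grad(theta):
    """Gradient derived by hand; this is the quantity being verified."""
    x1, x2 = theta
    return [1.0 / x1 + x2, x1 - math.cos(x2)]

def numerical_directional_derivative(func, theta, delta, eps=1e-5):
    """Central difference of func along the direction delta."""
    plus = [t + eps * d for t, d in zip(theta, delta)]
    minus = [t - eps * d for t, d in zip(theta, delta)]
    return (func(plus) - func(minus)) / (2.0 * eps)

theta = [2.0, 5.0]
delta = [0.6, 0.8]  # a unit vector
lhs = sum(d * g for d, g in zip(delta, analytic_grad(theta)))   # delta^T grad f
rhs = numerical_directional_derivative(f, theta, delta)
print(lhs, rhs)
```

<p>A single random direction needs only two evaluations of $f$, which is why the check projects onto $\delta$ instead of differencing every coordinate.</p>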
<ul>
<li>
<p>符号微分
符号微分，就是根据微分的计算规则使用符号手动计算微分。部分规则为：


<div>$$

\begin{align*}  
&amp;\frac{\partial (f(\theta) &#43; g(\theta))}{\partial \theta} = \frac{\partial f(\theta)}{\partial \theta} &#43; \frac{\partial g(\theta)}{\partial \theta}\\  
&amp;\frac{\partial (f(\theta) g(\theta))}{\partial \theta} = g(\theta) \frac{\partial f(\theta)}{\partial \theta} &#43; f(\theta) \frac{\partial g(\theta)}{\partial \theta}\\  
&amp;\frac{\partial f(g(\theta))}{\partial\theta}=\frac{\partial f(g(\theta))}{\partial g(\theta)}\frac{\partial g(\theta)}{\partial\theta}  
\end{align*}

$$</div>

根据该公式，可以计算得到$f(\theta) = \prod_{i=1}^{n} \theta_i$的梯度表达式为：$\frac{\partial f(\theta)}{\partial \theta_k} = \prod_{j \neq k}^{n} \theta_j$。如果我们根据该公式来计算梯度，会发现需要计算$n(n-2)$次乘法才能得到结果。这是因为在符号运算的过程中，我们忽略了可以反复利用的中间结果。</p>
</li>
<li>
<p>正向模式自动微分 forward mode automatic differentiation
沿着计算图的拓扑序列，同样可以计算出输出关于输入的导数，还是以$y=f(x_1, x_2) = \ln(x_1)+x_1x_2-\sin x_2$为例，其计算图为：
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406071612328.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
</ul>
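<p>The full computation of $\partial y/\partial x_1$ is shown below, where $\dot v_i = \partial v_i/\partial x_1$ and each step applies the derivative rule of the corresponding elementary function: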


<div>$$

\begin{align*}  
&amp;v_1 = x_1 = 2\\  
&amp;v_2 = x_2 = 5\\  
&amp;\dot v_{1} =1 \\  
&amp;\dot v_{2} =0 \\  
&amp;\dot{v}_{3} =\dot v_{1}/v_{1}=0.5 \\  
&amp;\dot{v}_{4} =\dot{v}_{1}v_{2}&#43;\dot v_{2}v_{1}=1\times5&#43;0\times2=5 \\  
&amp;\dot v_{5} =\dot{v_{2}}\cos v_{2}=0\times\cos5=0 \\  
&amp;\dot{v}_{6} =\dot v_{3}&#43;\dot v_{4}=0.5&#43;5=5.5 \\  
&amp;\dot{v}_{7} =\dot{v_{6}}-\dot{v_{5}}=5.5-0=5.5  
\end{align*}

$$</div>
</p>
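<p>For $f:\mathbb{R}^n \to \mathbb{R}^k$, forward-mode AD needs $n$ forward passes, one per input, to obtain the derivatives with respect to every input. It therefore suits small $n$ and large $k$, whereas deep learning usually has large $n$ and small $k$.</p>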
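<p>Forward mode can be sketched with dual numbers: every value carries its derivative with respect to one chosen input. The <code>Dual</code> class and the helpers <code>dlog</code> and <code>dsin</code> below are illustrative names, not part of needle.</p>

```python
import math

class Dual:
    """A value together with its derivative w.r.t. one chosen input."""
    def __init__(self, value, dot=0.0):
        self.value = value
        self.dot = dot

    def __add__(self, other):
        return Dual(self.value + other.value, self.dot + other.dot)

    def __sub__(self, other):
        return Dual(self.value - other.value, self.dot - other.dot)

    def __mul__(self, other):
        # Product rule
        return Dual(self.value * other.value,
                    self.dot * other.value + self.value * other.dot)

def dlog(a):
    return Dual(math.log(a.value), a.dot / a.value)

def dsin(a):
    return Dual(math.sin(a.value), a.dot * math.cos(a.value))

def f(x1, x2):
    return dlog(x1) + x1 * x2 - dsin(x2)

# One forward pass per input: seed the input of interest with dot = 1.
dy_dx1 = f(Dual(2.0, 1.0), Dual(5.0, 0.0)).dot
dy_dx2 = f(Dual(2.0, 0.0), Dual(5.0, 1.0)).dot
print(dy_dx1, dy_dx2)
```

<p>Two inputs require two passes here, which illustrates why the cost scales with $n$.</p>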
<ul>
<li>反向模式自动微分
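Define the adjoint $\overline{v_i}=\frac{\partial y}{\partial v_i}$, the partial derivative of the output $y$ with respect to the intermediate value $v_i$.
The full computation is shown below. Note the computation of $\overline{v_2}$: $v_2$ feeds two nodes in the computational graph, so its adjoint is the sum of two parts: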


<div>$$

\begin{align*}  
&amp;\overline{v_{7}}=\frac{\partial y}{\partial v_{7}}=1\\  
&amp;\overline{v_{6}}=\overline{v_{7}}\frac{\partial v_{7}}{\partial v_{6}}=\overline{v_{7}}\times1=1\\  
&amp;\overline{v_{5}}=\overline{v_{7}}\frac{\partial v_{7}}{\partial v_{5}}=\overline{v_{7}}\times(-1)=-1\\  
&amp;\overline{v_{4}}=\overline{v_{6}}\frac{\partial v_{6}}{\partial v_{4}}=\overline{v_{6}}\times1=1\\  
&amp;\overline{v_{3}}=\overline{v_{6}}\frac{\partial v_{6}}{\partial v_{3}}=\overline{v_{6}}\times1=1\\  
&amp;\overline{v_{2}}=\overline{v_{5}}\frac{\partial v_{5}}{\partial v_{2}}&#43;\overline{v_{4}}\frac{\partial v_{4}}{\partial v_{2}}=\overline{v_{5}}\times\cos v_{2}&#43;\overline{v_{4}}\times v_{1}\\  
&amp;\overline{v_{1}}=\overline{v_{4}} \frac{\partial v_{4}}{\partial v_{1}}&#43;\overline{v_{3}} \frac{\partial v_{3}}{\partial v_{1}}=\overline{v_{4}}\times v_{2}&#43; \overline{v_{3}} \frac{1}{v_{1}}=5&#43;\frac{1}{2}=5.5
\end{align*}

$$</div>
</li>
</ul>
<p>接下来我们讨论一下为什么前文中$\overline{v_2}$由两部分组成。考虑如下一个计算图：
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406071612078.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>$y$可以被视作关于$v_2$和$v_3$的函数，即$y = f(v_2, v_3)$，那么：


<div>$$

\overline{v_{1}}=\frac{\partial y}{\partial v_{1}}=\frac{\partial f(v_{2},v_{3})}{\partial v_{2}}\frac{\partial v_{2}}{\partial v_{1}}&#43;\frac{\partial f(v_{2},v_{3})}{\partial v_{3}} \frac{\partial v_{3}}{\partial v_{1}}=\overline{v_{2}} \frac{\partial v_{2}}{\partial v_{1}}&#43;\overline{v_{3}} \frac{\partial v_{3}}{\partial v_{1}}

$$</div>

因此，定义partial adjoint $\overline{v_{i\to j}} = \overline{v_j} \frac{\partial v_j}{\partial v_i}$，那么$\overline{v_i}$可以表示为：


<div>$$

\overline{v_i}=\sum_{j\in next(i)}\overline{v_{i\rightarrow j}}

$$</div>
</p>
<h2 id="反向模式微分算法的实现">反向模式微分算法的实现</h2>
<p>基于以上分析，可以写出如下的实现反向模式微分算法的伪代码：
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406071612188.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>其中<code>node_to_grad</code>是一个字典，保存着每个节点的partial adjoint值。由于是按照逆拓扑序列遍历的节点，因此可以保证当遍历到$i$时，所有以$i$为输入的节点（k节点所在的集合）都已被遍历完毕，即$\overline{v_k}$已经计算出来。</p>
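<p>Since the pseudocode is only available as an image, here is a scalar sketch of the same idea. It is a simplification under stated assumptions: each node stores numeric local derivatives instead of building new graph nodes for the adjoints, and <code>Node</code>, <code>topo_order</code> and <code>gradients</code> are hypothetical names. Only <code>node_to_grad</code> follows the text.</p>

```python
import math

class Node:
    def __init__(self, value, inputs=(), local_grads=()):
        self.value = value
        self.inputs = inputs            # nodes this node was computed from
        self.local_grads = local_grads  # d(self)/d(input), one per input

def add(a, b): return Node(a.value + b.value, (a, b), (1.0, 1.0))
def sub(a, b): return Node(a.value - b.value, (a, b), (1.0, -1.0))
def mul(a, b): return Node(a.value * b.value, (a, b), (b.value, a.value))
def vlog(a): return Node(math.log(a.value), (a,), (1.0 / a.value,))
def vsin(a): return Node(math.sin(a.value), (a,), (math.cos(a.value),))

def topo_order(output):
    """Post-order DFS: every node appears after all of its inputs."""
    order, seen = [], set()
    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for inp in node.inputs:
            visit(inp)
        order.append(node)
    visit(output)
    return order

def gradients(output):
    # node_to_grad maps a node to the list of its partial adjoints
    node_to_grad = {id(output): [1.0]}
    adjoint = {}
    for node in reversed(topo_order(output)):
        # All consumers of this node were visited already, so the list is complete.
        adjoint[id(node)] = sum(node_to_grad[id(node)])
        for inp, local in zip(node.inputs, node.local_grads):
            node_to_grad.setdefault(id(inp), []).append(adjoint[id(node)] * local)
    return adjoint

v1, v2 = Node(2.0), Node(5.0)
y = sub(add(vlog(v1), mul(v1, v2)), vsin(v2))
adj = gradients(y)
dy_dv1, dy_dv2 = adj[id(v1)], adj[id(v2)]
print(dy_dv1, dy_dv2)
```

<p>A single backward sweep yields the derivatives with respect to both inputs, in contrast to the one-pass-per-input cost of forward mode.</p>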
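<p>Which data structure should hold the partial adjoints? A common idea is an adjacency matrix, but most of its entries would not exist, which wastes a lot of space. Instead, we can extend the original computational graph to record how the partial adjoints and adjoints are computed.</p>
<p>In the figure below, the black part is the computational graph of the original expression, and the red part is the computational graph of the adjoints and partial adjoints: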
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406071611419.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>使用计算图，除了能够节省内存外，还能清楚的看到正向计算的中间结果和反向计算之间的依赖关系，进而优化计算。</p>
<h2 id="反向模式ad和反向传播的区别">反向模式ad和反向传播的区别</h2>
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406071817738.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>反向传播：</p>
<ul>
<li>在反向计算过程中使用与前向传播完全相同的计算图</li>
<li>应用于第一代深度学习框架</li>
</ul>
<p>反向AD：</p>
<ul>
<li>为adjoint在计算图中创建独立的节点</li>
<li>被应用于现代深度学习框架</li>
</ul>
<p>现代普遍应用反向AD的原因：</p>
<ul>
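<li>Some loss functions are themselves functions of the gradient, which requires the gradient of a gradient. Backpropagation cannot handle this case, whereas reverse-mode AD only needs to add nodes to the graph and differentiate again;</li>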
<li>反向AD优化空间更大。</li>
</ul>
<h2 id="考虑tensor的反向模式ad">考虑Tensor的反向模式AD</h2>
<p>前面都是在假设中间变量是标量的基础上讨论的，接下来我们将其推广到Tensor上。</p>
<p>首先推广adjoint，定义对于一个Tensor$Z$，其adjoint$\overline{Z}$为：


<div>$$

\overline{Z}=\begin{bmatrix}\frac{\partial y}{\partial Z_{1,1}}&amp;...&amp;\frac{\partial y}{\partial Z_{1,n}}\\...&amp;...&amp;...\\\frac{\partial y}{\partial Z_{m,1}}&amp;...&amp;\frac{\partial y}{\partial Z_{m,n}}\end{bmatrix}

$$</div>

鉴于


<div>$$

\begin{align*}Z_{ij}&amp;=\sum_kX_{ik}W_{kj}\\v&amp;=f(Z)\end{align*}

$$</div>

那么在计算$\overline{X_{i,k}}$时，需要将所有计算图上以$X_{i,k}$为输入的节点都找出来，即$Z$的第$i$行的每个元素。因此$\overline{X_{i,k}}$的计算公式为：


<div>$$

\overline{X_{i,k}}=\sum_{j}\frac{\partial Z_{i,j}}{\partial X_{i,k}}\bar{Z}_{i,j}=\sum_{j}W_{k,j}\bar{Z}_{i,j}

$$</div>

上述公式记为矩阵形式为：


<div>$$

\overline X = \overline Z W^T

$$</div>
</p>
<h1 id="lecture-5-automatic-differentiation-implementation">Lecture 5: Automatic Differentiation Implementation</h1>
<p>这讲主要介绍我们hw中要实现的needle的总体框架，项目中已给出了约1000行代码。</p>
<h2 id="autogradpy">autograd.py</h2>
<p>autograd保存与实现自动微分相关的代码。</p>
<p><code>Value</code>类对应计算图上的节点，其数据成员包括：</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Value</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;A value in the computational graph.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c1"># trace of computational graph</span>
</span></span><span class="line"><span class="cl">    <span class="n">op</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Op</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="n">inputs</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s2">&#34;Value&#34;</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># The following fields are cached fields for</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># dynamic computation</span>
</span></span><span class="line"><span class="cl">    <span class="n">cached_data</span><span class="p">:</span> <span class="n">NDArray</span>
</span></span><span class="line"><span class="cl">    <span class="n">requires_grad</span><span class="p">:</span> <span class="nb">bool</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p><code>op</code>用于保存该节点的运算符，<code>inputs</code>保存该运算符的操作数，<code>cached_data</code>保存该节点的数值，其数据结构因平台不同而区别。</p>
<h2 id="ops">ops</h2>
<p>本节主要介绍needle库的代码结构，笔记相当草率，建议看原视频。</p>
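<p>The ops folder (2023 version) or op.py (2022 version) holds the implementations of the operators.
The <code>Op</code> class specifies two interfaces that must be implemented:</p>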
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Op</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;Operator definition.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">NDArray</span><span class="p">]):</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;&#34;&#34;Calculate forward pass of operator.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">        Parameters
</span></span></span><span class="line"><span class="cl"><span class="s2">        ----------
</span></span></span><span class="line"><span class="cl"><span class="s2">        input: np.ndarray
</span></span></span><span class="line"><span class="cl"><span class="s2">            A list of input arrays to the function
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">        Returns
</span></span></span><span class="line"><span class="cl"><span class="s2">        -------
</span></span></span><span class="line"><span class="cl"><span class="s2">        output: nd.array
</span></span></span><span class="line"><span class="cl"><span class="s2">            Array output of the operation
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">        &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">gradient</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">:</span> <span class="s2">&#34;Value&#34;</span><span class="p">,</span> <span class="n">node</span><span class="p">:</span> <span class="s2">&#34;Value&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="s2">&#34;Value&#34;</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="s2">&#34;Value&#34;</span><span class="p">]]:</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;&#34;&#34;Compute partial adjoint for each input value for a given output adjoint.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">        Parameters
</span></span></span><span class="line"><span class="cl"><span class="s2">        ----------
</span></span></span><span class="line"><span class="cl"><span class="s2">        out_grad: Value
</span></span></span><span class="line"><span class="cl"><span class="s2">            The adjoint wrt to the output value.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">        node: Value
</span></span></span><span class="line"><span class="cl"><span class="s2">            The value node of forward evaluation.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">        Returns
</span></span></span><span class="line"><span class="cl"><span class="s2">        -------
</span></span></span><span class="line"><span class="cl"><span class="s2">        input_grads: Value or Tuple[Value]
</span></span></span><span class="line"><span class="cl"><span class="s2">            A list containing partial gradient adjoints to be propagated to
</span></span></span><span class="line"><span class="cl"><span class="s2">            each of the input node.
</span></span></span><span class="line"><span class="cl"><span class="s2">        &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p><code>compute</code>接口用于描述该运算符实施的运算，<code>gradient</code>描述该运算符对应的梯度计算方式。</p>
<h1 id="lecture-6-fully-connected-network-optimization-initialization">Lecture 6: Fully connected network, optimization, initialization</h1>
<h2 id="全连接网络">全连接网络</h2>
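<p>The fully connected networks discussed so far had no bias term (to keep manual differentiation simple). This section introduces the real MLP, which is defined iteratively: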


<div>$$

\begin{align*}  
&amp;z_{i&#43;1} = \sigma_i(W_i^Tz_i&#43;b_i), \ \ \ i=1,...,L\\  
&amp;h_\theta(x) = z_{L&#43;1}\\  
&amp;z_1 = x  
\end{align*}

$$</div>

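In this model the set of trainable parameters is $\theta = \{W_{1:L}, b_{1:L}\}$, and $\sigma_i(x)$ is a nonlinear activation function. As a special case, the last layer has no activation, i.e. $\sigma_L(x)= x$.</p>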
<p>迭代的表达式写成矩阵形式为：


<div>$$

\begin{align*}  
Z_{i&#43;1} = \sigma_i(Z_iW_i&#43;1b_i^T)  
\end{align*}

$$</div>

where $1$ denotes an all-ones column vector, used to broadcast the row vector $b_i^T$ to the shape of the matrix $Z_iW_i$.</p>
<p>在实际实现过程中，我们不用浪费空间去构造这样一个全1列向量，而是直接使用广播算子。在NumPy有许多自动的广播操作，但是在我们实现的needle库中，这一操作更加显式，例如对于$(n\times 1) \to (m \times n)$，要执行的操作为<code>A.reshape((1, n)).broadcast_to((m, n))</code>。</p>
<h2 id="优化">优化</h2>
<p>对于有监督的深度学习任务，一般的优化目标为：


<div>$$

\mathop{\text{minimize}}_{\theta} \ \ f(\theta) = \frac{1}{m}\sum_{i=1}^m{l(h_\theta(x^{(i)}),y^{(i)})}

$$</div>

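Several commonly used optimization algorithms are introduced next.</p>
<ul>
<li>Gradient descent
Gradient descent was covered in earlier lectures, so its update rule is given directly: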


<div>$$

\theta_{t&#43;1} = \theta_t - \alpha \nabla_\theta f(\theta_t)

$$</div>

其中，$t$表示迭代次数。</li>
</ul>
<p>学习率这一参数对于该方法格外重要，不同的学习率的表现相差很大很大：
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406161006752.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>上图展示了大学习率和小学习率的迭代过程，如果目标函数再复杂一点，那么确定合适的学习率就会变得更加复杂。接下来将介绍一些不同的方法，它们各有其收敛行为。</p>
<p>对于梯度下降法的改进，有两种方案：梯度计算的变种和随机的变种。首先介绍第一类。</p>
<ul>
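<li>Newton&rsquo;s method
Newton&rsquo;s method approximates a high-dimensional function with a quadratic surface, so it converges markedly faster than gradient descent, which uses only a first-order approximation. Its update rule is: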


<div>$$

\theta_{t&#43;1} = \theta_t - \alpha(\nabla_\theta^2f(\theta_t))^{-1}\nabla_\theta f(\theta_t)

$$</div>

其中，$(\nabla_\theta^2f(\theta_t))^{-1}$是<em>Hessian</em>矩阵的逆矩阵。<em>Hessian</em>矩阵每个元素都是二阶导数，其具体定义为：


<div>$$

\nabla_\theta^2f(\theta_t) = H=\begin{bmatrix}\frac{\partial^2f}{\partial x_1^2}&amp;\frac{\partial^2f}{\partial x_1\partial x_2}&amp;\cdots&amp;\frac{\partial^2f}{\partial x_1\partial x_n}\\\frac{\partial^2f}{\partial x_2\partial x_1}&amp;\frac{\partial^2f}{\partial x_2^2}&amp;\cdots&amp;\frac{\partial^2f}{\partial x_2\partial x_n}\\\vdots&amp;\vdots&amp;\ddots&amp;\vdots\\\frac{\partial^2f}{\partial x_n\partial x_1}&amp;\frac{\partial^2f}{\partial x_n\partial x_2}&amp;\cdots&amp;\frac{\partial^2f}{\partial x_n^2}\end{bmatrix}

$$</div>

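For a quadratic function, the Newton direction points straight at the optimum, so a single step with $\alpha = 1$ reaches it.</li>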
</ul>
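<p>A small sketch of the claim about quadratics, assuming the objective $f(\theta)=\frac{1}{2}\theta^TA\theta-b^T\theta$ with a hand-picked $2\times 2$ positive definite $A$. Here the gradient is $A\theta-b$ and the Hessian is $A$. The function names and the numbers are made up for illustration.</p>

```python
def newton_step(theta, grad, hessian, alpha=1.0):
    """One Newton update for a two-parameter problem (explicit 2x2 inverse)."""
    (a, b), (c, d) = hessian
    det = a * d - b * c
    inv = [[d / det, -b / det], [-c / det, a / det]]
    step = [inv[0][0] * grad[0] + inv[0][1] * grad[1],
            inv[1][0] * grad[0] + inv[1][1] * grad[1]]
    return [theta[0] - alpha * step[0], theta[1] - alpha * step[1]]

A = [[3.0, 1.0], [1.0, 2.0]]   # Hessian of the quadratic
b = [1.0, 1.0]

def grad_f(theta):
    """Gradient A theta - b of f(theta) = 0.5 theta^T A theta - b^T theta."""
    return [A[0][0] * theta[0] + A[0][1] * theta[1] - b[0],
            A[1][0] * theta[0] + A[1][1] * theta[1] - b[1]]

theta0 = [10.0, -7.0]                        # arbitrary starting point
theta1 = newton_step(theta0, grad_f(theta0), A)
print(theta1, grad_f(theta1))
```

<p>From any starting point, one step lands on the minimizer $A^{-1}b=(0.2, 0.4)$ up to rounding, where the gradient vanishes.</p>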
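<p>This method is widely used in classical convex optimization but rarely in deep learning, for two main reasons: 1) the Hessian is $n\times n$, so even a moderately large number of parameters makes it prohibitively expensive to compute; 2) for non-convex optimization, it is debatable whether second-order methods are more effective.</p>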
<ul>
<li>动量梯度下降法 Momentum
在普通梯度下降法中，如果学习率太大，就会出现来回横跳的情况，如果对前几次梯度取平均，则可能改善这一情况。</li>
</ul>
<p>动量法正是对梯度取指数移动平均<sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup>的方案，具体来说有：


<div>$$

\begin{align*}  
&amp;u_{t&#43;1} = \beta u_t &#43;(1-\beta)\nabla_\theta f(\theta_t)\\  
&amp;\theta_{t&#43;1} = \theta_t - \alpha u_{t&#43;1}  
\end{align*}

$$</div>

该方法可视化过程如下图所示，在较大学习率的情况下，其相比梯度下降法优化曲线更为平滑。
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406161114979.png?x-oss-process=image/quality,q_90/format,webp"></p>
<ul>
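<li>Unbiased momentum
The previous method has a small flaw. If $u_0$ is initialized to 0, the first update uses only $(1-\beta)$ times the gradient, so early convergence is somewhat slower; the effect fades as the iterations proceed.</li>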
</ul>
<p>为了修正其影响，我们可以在参数更新过程中对动量进行缩放，具体来说：


<div>$$

\theta_{t&#43;1} = \theta_{t} - \frac{\alpha u_{t&#43;1}}{1-\beta^{t&#43;1}}

$$</div>

如下图所示，修正以后其前期的更新速度要快了不少。
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406161128045.png?x-oss-process=image/quality,q_90/format,webp"></p>
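<p>Momentum with and without the bias correction can be sketched on a one-dimensional toy objective $f(\theta)=\frac{1}{2}(\theta-3)^2$. The objective, the hyperparameters and the function name <code>momentum_descent</code> are assumptions chosen only for illustration.</p>

```python
def momentum_descent(grad_fn, theta, lr=0.1, beta=0.9, steps=300, unbias=True):
    """Gradient descent with an exponential moving average of the gradient."""
    u = 0.0
    for t in range(steps):
        u = beta * u + (1.0 - beta) * grad_fn(theta)
        # Dividing by 1 - beta^(t+1) undoes the shrinkage caused by u_0 = 0.
        scale = 1.0 - beta ** (t + 1) if unbias else 1.0
        theta = theta - lr * u / scale
    return theta

def grad(theta):
    return theta - 3.0   # gradient of 0.5 * (theta - 3)^2

first_unbiased = momentum_descent(grad, 0.0, steps=1, unbias=True)
first_biased = momentum_descent(grad, 0.0, steps=1, unbias=False)
final_theta = momentum_descent(grad, 0.0)
print(first_unbiased, first_biased, final_theta)
```

<p>The first corrected step moves about ten times further than the uncorrected one, since $1/(1-\beta)=10$ for $\beta=0.9$, while both variants approach the same optimum in the long run.</p>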
<ul>
<li>Nesterov momentum
Nesterov是梯度下降中一个非常有效的“trick”，其在传统momentum的基础上，将计算当前位置的梯度改为计算下一步位置的梯度。即：


<div>$$

u_{t&#43;1} = \beta u_t &#43;(1-\beta)\nabla_\theta f(\theta_t - \alpha u_t)

$$</div>

关于其为啥有效，看到了两篇文章。第一篇<sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup>通过推导认为该方案对二阶导数进行了近似，因此其收敛速度更快；第二篇<sup id="fnref:3"><a href="#fn:3" class="footnote-ref" role="doc-noteref">3</a></sup>认为其能够更好地感知未来位置的梯度，在未来梯度很大时放慢步子。</li>
</ul>
<p>不看广告看疗效，对比普通Momentum，该方法的收敛速度要快得多。据说其也更适合一个深度网络。
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406161306619.png?x-oss-process=image/quality,q_90/format,webp"></p>
<ul>
<li>Adam
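Adam is an adaptive gradient descent algorithm. The gradients of different parameters can differ greatly in magnitude. Adam addresses this with a per-parameter scaling factor that enlarges the step for parameters whose gradients are small: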


<div>$$

\begin{align*}  
&amp;u_{t&#43;1} = \beta_1 u_t &#43; (1-\beta_1)\nabla_\theta f(\theta_t)\\  
&amp;v_{t&#43;1} = \beta_2 v_t &#43; (1-\beta_2)(\nabla_\theta f(\theta_t))^2  &amp;\text{the square is elementwise}\\  
&amp;\theta_{t&#43;1} = \theta_t - \frac{\alpha u_{t&#43;1}}{\sqrt{v_{t&#43;1}}&#43;\epsilon} &amp; \text{all operations are elementwise}\\  
\end{align*}

$$</div>

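Adam is widely used in practice. It may not be the best optimizer for a particular task (see the figure below), but on most tasks it performs well enough to serve as a baseline.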
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406161602224.png?x-oss-process=image/quality,q_90/format,webp"></li>
</ul>
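<p>A minimal sketch of one Adam step, following the three equations above literally (so without the bias correction that many library implementations add). The function name <code>adam_step</code>, the list-based representation, and the default hyperparameters are assumptions made for this illustration.</p>

```python
import math


def adam_step(theta, g, u, v, alpha=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on lists of floats; every operation is elementwise."""
    u = [beta1 * ui + (1.0 - beta1) * gi for ui, gi in zip(u, g)]
    v = [beta2 * vi + (1.0 - beta2) * gi * gi for vi, gi in zip(v, g)]
    theta = [t - alpha * ui / (math.sqrt(vi) + eps) for t, ui, vi in zip(theta, u, v)]
    return theta, u, v


# Two parameters whose gradients differ by six orders of magnitude.
theta, u, v = adam_step([0.0, 0.0], [1e-3, 1e3], [0.0, 0.0], [0.0, 0.0])
```

<p>Because each coordinate is divided by the square root of its own running second moment, both parameters move by almost the same amount despite the very different gradient magnitudes.</p>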
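<p>Next come the stochastic variants. A stochastic variant introduces a random variable (noise) into the optimization process, for example by updating the parameters with a subset of the dataset at each step.</p>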
<ul>
<li>Stochastic gradient descent
Stochastic gradient descent does exactly that: each update uses a subset (minibatch) $B$ of the dataset:


<div>$$

\theta_{t&#43;1} = \theta_t - \frac{\alpha}{|B|}\sum_{i\in B}\nabla_\theta l(h_\theta(x^{(i)}),y^{(i)})

$$</div>
</li>
</ul>
<p>SGD appears to need many more iterations than gradient descent, but each iteration is far cheaper.
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406161624584.png?x-oss-process=image/quality,q_90/format,webp"></p>
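<p>Although visualizing training on convex problems gives a very intuitive feel, keep in mind that deep learning is neither convex optimization nor a quadratic function, so these optimization methods may behave completely differently on deep networks than they do on convex problems.</p>
<h2 id="初始化">Initialization</h2>
<p>How should the initial values of the parameters be chosen? That is a good question.</p>
<p>In convex optimization, all parameters are often initialized to 0. If we did the same in a neural network, the output of every layer would be 0 and so would every gradient🙁. All zeros is a fixed point of the model, and the model would never be updated.</p>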
<ul>
<li>
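<p>The choice of initialization has a large effect on the gradients
A natural idea is to initialize the parameters randomly, for example from a multivariate normal distribution. However, the choice of that distribution's parameters can affect the gradients considerably, as shown below:
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406161659652.png?x-oss-process=image/quality,q_90/format,webp">
As depth increases, if the norm of the activations changes too drastically from layer to layer, gradients explode or vanish. Gradient magnitudes that are too large or too small cause the same problems.</p>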
</li>
<li>
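<p>Weights may change very little during training
A common misconception is that, whatever initial values are chosen, the parameters eventually converge to roughly the same region. This is not the case: over the whole of training, the weights move relatively little from where they started.</p>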
</li>
<li>
<p>Why a variance of 2/n, as used above, is a suitable choice for initialization
Here I directly quote GPT's explanation of this slide.</p>
</li>
</ul>
<blockquote>
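<p>Consider independent random variables $x\sim N(0,1)$ and $w\sim N(0,\frac{1}{n})$, where $x$ is an input and $w$ is a weight.</p>
<h4 id="期望和方差">Expectation and variance</h4>
<ul>
<li>$E[x\cdot w_i]=0$</li>
<li>$\mathrm{Var}[x\cdot w_i]=\frac{1}{n}$</li>
</ul>
<p>Therefore, for $w^Tx$, which is a sum of $n$ such terms:</p>
<ul>
<li>$E[w^Tx]=0$</li>
<li>$\mathrm{Var}[w^Tx]=1$ (by the central limit theorem, $w^Tx$ is approximately $N(0,1)$)</li>
</ul>
<h3 id="激活值的方差">Variance of the activations</h3>
<p>With a linear activation, if $z_i\sim N(0,I)$ and $W_i\sim N(0,\frac{1}{n}I)$, then:</p>
<p>$z_{i&#43;1}=W_iz_i\sim N(0,I)$</p>
<h3 id="relu-非线性">ReLU nonlinearity</h3>
<p>With a ReLU activation, roughly half of the components of $z_i$ are set to zero, so to reach the same final variance the variance of $W_i$ must be doubled. Hence:</p>
<p>$W_i\sim N(0,\frac{2}{n}I)$</p>
<p>This is the so-called Kaiming normal initialization (He initialization), which is particularly suited to ReLU activations.</p>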
</blockquote>
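<p>The variance argument can be checked with a small simulation. This is a rough sketch using only the standard library. The helper <code>mean_square_after</code>, the width of 128, and the depth of 8 are assumptions chosen for the illustration, and it tracks the mean of $z^2$ rather than a formal variance.</p>

```python
import random


def mean_square_after(depth, n, weight_var_scale, relu, seed=0):
    """Push z ~ N(0, I) through `depth` random n x n layers and return mean(z**2).

    Each weight is drawn from N(0, weight_var_scale / n).
    """
    rng = random.Random(seed)
    z = [rng.gauss(0.0, 1.0) for _ in range(n)]
    std = (weight_var_scale / n) ** 0.5
    for _ in range(depth):
        new_z = []
        for _ in range(n):
            s = sum(rng.gauss(0.0, std) * zj for zj in z)
            new_z.append(max(s, 0.0) if relu else s)
        z = new_z
    return sum(zj * zj for zj in z) / n


kaiming = mean_square_after(depth=8, n=128, weight_var_scale=2.0, relu=True)
naive = mean_square_after(depth=8, n=128, weight_var_scale=1.0, relu=True)
```

<p>With variance $2/n$ the activation scale stays of order one after eight ReLU layers, whereas with variance $1/n$ it is smaller by a factor of about $2^8$, i.e. it halves at every layer.</p>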
<h1 id="lecture-7-neural-network-library-abstractions">Lecture 7: Neural Network Library Abstractions</h1>
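<p>This lecture is mainly about using our needle library to implement some simple deep learning models and build some small components.</p>
<h2 id="程序抽象">Programming abstractions</h2>
<p>Mature modern deep learning libraries provide APIs that, from today's vantage point, all look just right. Thinking about why the interfaces were designed this way helps us understand the logic behind the programming abstractions of a deep learning library.</p>
<p>We start by analyzing a few classic deep learning frameworks: Caffe, TensorFlow, and PyTorch.</p>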
<ul>
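<li>Caffe 1.0 (2014)
Caffe uses the concept of a Layer to represent the individual small modules of a neural network. By composing and swapping Layers, a network can be built and modified quickly, and trained with the same code.</li>
</ul>
<p>The Layer class provides two interfaces, <code>forward</code> and <code>backward</code>:</p>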
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Layer</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">bottom</span><span class="p">,</span> <span class="n">top</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">		<span class="k">pass</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">	<span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">top</span><span class="p">,</span> <span class="n">propagate_down</span><span class="p">,</span> <span class="n">bottom</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">		<span class="k">pass</span>
</span></span></code></pre></td></tr></table>
</div>
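</div><p><code>forward</code> propagates the data from bottom forward and stores the result in top. In the <code>backward</code> interface, top holds the gradient coming from the output, <code>propagate_down</code> indicates whether a gradient should be computed for each bottom input, and bottom is where the computed gradient is stored.</p>
<p>In Caffe, gradients are computed "in place" rather than by adding extra nodes to a computational graph. For a first-generation deep learning framework, computing gradients directly is a simple but intuitive idea.</p>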
<ul>
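<li>TensorFlow 1.0 (2015)
As a second-generation deep learning framework, it introduced the concept of a computational graph. With a computational graph, only the forward computation needs to be defined. When gradients are needed, the graph is simply extended. A short example:</li>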
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">numpy</span>
</span></span><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">v1</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>  <span class="c1"># TensorFlow 1.x API</span>
</span></span><span class="line"><span class="cl"><span class="n">v2</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">v1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">v3</span> <span class="o">=</span> <span class="n">v2</span> <span class="o">+</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl"><span class="n">v4</span> <span class="o">=</span> <span class="n">v2</span> <span class="o">*</span> <span class="n">v3</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">sess</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">Session</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="n">value4</span> <span class="o">=</span> <span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">v4</span><span class="p">,</span> <span class="n">feed_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">v1</span><span class="p">:</span> <span class="n">numpy</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">1</span><span class="p">])})</span>
</span></span></code></pre></td></tr></table>
</div>
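</div><p>In the code above, <code>v1~4</code> are merely placeholders used to build the computational graph; they have no values until an input is fed in. The session is then used to obtain the output value for a given input.</p>
<p>This is called declarative programming: the computational graph is not executed when it is defined, but only when a session runs it. The advantages are that graph definition is separated from execution, which keeps the code readable; that the whole graph is known before it runs, so it can be optimized accordingly; and that sessions make distributed execution easier to implement.</p>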
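<p>To make the "define first, run later" idea concrete, here is a toy declarative graph in plain Python. It is only a sketch of the concept: the <code>Node</code>, <code>Constant</code>, <code>exp</code> and <code>run</code> names are hypothetical and do not correspond to the TensorFlow API.</p>

```python
import math


class Node:
    """A symbolic node: it records how to compute a value but holds no value."""

    def __init__(self, op=None, inputs=()):
        self.op = op          # None marks a placeholder
        self.inputs = inputs

    def __add__(self, other):
        return Node(lambda a, b: a + b, (self, _wrap(other)))

    def __mul__(self, other):
        return Node(lambda a, b: a * b, (self, _wrap(other)))


class Constant(Node):
    """A node with a fixed value, used for Python numbers in expressions."""

    def __init__(self, value):
        super().__init__()
        self.value = value


def _wrap(x):
    return x if isinstance(x, Node) else Constant(x)


def exp(node):
    return Node(math.exp, (node,))


def run(node, feed_dict):
    """Evaluate `node` once the placeholders in feed_dict have values."""
    if node in feed_dict:
        return feed_dict[node]
    if isinstance(node, Constant):
        return node.value
    if node.op is None:
        raise KeyError("placeholder was not given a value")
    return node.op(*[run(i, feed_dict) for i in node.inputs])


v1 = Node()      # placeholder, no value yet
v2 = exp(v1)
v3 = v2 + 1
v4 = v2 * v3     # still only a description of the computation
value4 = run(v4, {v1: 1.0})
```

<p>Nothing is computed while <code>v2</code> to <code>v4</code> are being defined. The arithmetic only happens inside <code>run</code>, which plays the role that the session plays in the example above.</p>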
<ul>
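<li>PyTorch (needle)
PyTorch uses imperative programming. In contrast to declarative programming, values are computed eagerly, at the moment the computational graph is being built.</li>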
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">needle</span> <span class="k">as</span> <span class="nn">ndl</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">v1</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">])</span>
</span></span><span class="line"><span class="cl"><span class="n">v2</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">v1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">v3</span> <span class="o">=</span> <span class="n">v2</span> <span class="o">+</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl"><span class="n">v4</span> <span class="o">=</span> <span class="n">v2</span> <span class="o">*</span> <span class="n">v3</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Imperative programming combines naturally with Python's native control flow, for example:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">if</span> <span class="n">v4</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mf">0.5</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">	<span class="n">v5</span> <span class="o">=</span> <span class="n">v4</span> <span class="o">*</span> <span class="mi">2</span>
</span></span><span class="line"><span class="cl"><span class="k">else</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">	<span class="n">v5</span> <span class="o">=</span> <span class="n">v4</span>
</span></span></code></pre></td></tr></table>
</div>
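</div><p>TF 1.0 tends to be more efficient and suits inference and deployment, while PyTorch 1.0 is better suited to development and debugging.</p>
<h2 id="高级模块化库组件">High-level modular library components</h2>
<p>How do we use a deep learning library to actually do deep learning? In hw1 we built the model and the training loop out of individual low-level operators, but developing that way is too inefficient. Deep learning is itself highly modular: it consists of a model, a loss function, and an optimization method. Moreover, the model itself is highly modular. When implementing a deep learning library, we therefore have to design the interfaces carefully so that they support this modularity.</p>
<p>PyTorch has a class called <code>nn.Module</code> that corresponds to the small submodules of a model. Its defining trait is that it takes Tensors as input and produces Tensors as output. A loss function has the same property, so it can also be viewed as a module.</p>
<p>An optimizer takes a model as input and updates that model's parameters according to some rule.</p>
<p>To prevent overfitting, some models also have a regularization term, which can be implemented in two ways:</p>
<ul>
<li>as part of the loss function</li>
<li>directly integrated into the optimizer</li>
</ul>
<p>Parameter initialization is equally important. It is usually specified when the <code>nn.Module</code> is constructed.</p>
<p>Data loading is another important module. Data is also frequently preprocessed and augmented during loading.</p>
<p>The data flow between the components is shown below: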
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202406200916559.png?x-oss-process=image/quality,q_90/format,webp"></p>
<h1 id="lecture-8-neural-network-implementation">Lecture 8: Neural Network Implementation</h1>
<h2 id="修改tensor的data域">Modifying a Tensor's data field</h2>
<p>When implementing SGD there are many batches, so the learnable parameters are typically updated inside a loop:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">iterations</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">	<span class="n">w</span> <span class="o">-=</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">grad</span>
</span></span></code></pre></td></tr></table>
</div>
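</div><p>As in the pitfall I ran into in <a href="/notes/notes-on-cmu-10-414-assignments/#sgd-for-a-two-layer-neural-network">CMU 10-414 Assignments 实验笔记 &gt; SGD for a two-layer neural network</a>, updating parameters directly with Tensor operators adds a new node w to the computational graph on every update. Each such node carries an Op and inputs, so the graph keeps growing and backpropagation slows down badly.</p>
<p>To avoid leaving a differentiable node on the graph at every update, needle provides <code>Tensor.data</code>, which creates a node that shares the same underlying data as the <code>Tensor</code> but has no Op or inputs and does not need to be differentiated.</p>
<p>With <code>Tensor.data</code>, parameters can therefore be updated as usual without interfering with backpropagation through the graph:</p>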
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">w</span><span class="o">.</span><span class="n">data</span> <span class="o">-=</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">grad</span><span class="o">.</span><span class="n">data</span>
</span></span></code></pre></td></tr></table>
</div>
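</div><h2 id="数值稳定性">Numerical stability</h2>
<p>Every number occupies a finite amount of memory, so both the range and the precision of stored values are limited. Overflow and loss of precision are unavoidable during computation, so numerical stability must be considered when implementing operators.</p>
<p>For example, because the softmax formula involves exponentials, values can easily overflow. One fix is to subtract the maximum of the input from every element before applying softmax, which prevents overflow: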


<div>$$

z_i = \text{softmax}(x)_i = \frac{\exp(x_i -c)}{\sum_k {\exp(x_k-c)}}

$$</div>

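where $c = \max(x)$. The result is unchanged because the common factor $\exp(-c)$ cancels between the numerator and the denominator.</p>
<p>Other operators need similar stability considerations.</p>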
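<p>A minimal sketch of the max-subtraction trick in plain Python (the function name <code>softmax</code> and the list-based interface are choices made for this example, not needle's implementation):</p>

```python
import math


def softmax(x):
    """Numerically stable softmax over a list of floats."""
    c = max(x)                               # shift so the largest exponent is 0
    exps = [math.exp(xi - c) for xi in x]    # every term is now in (0, 1]
    total = sum(exps)
    return [e / total for e in exps]


large = softmax([1000.0, 1001.0, 1002.0])   # math.exp(1000.0) alone would overflow
small = softmax([0.0, 1.0, 2.0])            # same differences, so same result
```

<p>Because only the differences between the inputs matter, shifting all of them by the maximum leaves the output unchanged while keeping every exponent non-positive.</p>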
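<h2 id="parameter-类">The Parameter class</h2>
<p>The <code>Parameter</code> class represents learnable parameters and is a subclass of <code>Tensor</code>. It does not need to introduce any new behavior or interface beyond <code>Tensor</code>, so its implementation is trivial:</p>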
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Parameter</span><span class="p">(</span><span class="n">ndl</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;parameter&#34;&#34;&#34;</span>
</span></span></code></pre></td></tr></table>
</div>
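</div><h2 id="module-类">The Module class</h2>
<p>The <code>Module</code> class represents the individual submodules of a neural network. It has the following interface:</p>
<ul>
<li><code>parameters</code>: returns all learnable parameters in the module</li>
<li><code>__call__</code>: runs the forward pass
The implementation defines a helper function <code>_get_params</code> that extracts all learnable parameters from a module.</li>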
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">_get_params</span><span class="p">(</span><span class="n">value</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">Parameter</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="p">[</span><span class="n">value</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">params</span> <span class="o">=</span> <span class="p">[]</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">value</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
</span></span><span class="line"><span class="cl">            <span class="n">params</span> <span class="o">+=</span> <span class="n">_get_params</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">params</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">value</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="p">[]</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Module</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">parameters</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">_get_params</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
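</div><h3 id="optimizer-类">The Optimizer class</h3>
<p>The <code>Optimizer</code> class optimizes the learnable parameters of a model. It has two key methods:</p>
<ul>
<li><code>reset_grad</code>: resets the grad field of the model's learnable parameters</li>
<li><code>step</code>: updates the parameter values
<code>reset_grad</code> is simple to implement, while <code>step</code> depends on the specific optimization algorithm:</li>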
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Optimizer</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">params</span> <span class="o">=</span> <span class="n">params</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">reset_grad</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">p</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="kc">None</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="lecture-9-normalization-and-regularization">Lecture 9: Normalization and Regularization</h1>
<h2 id="normalization">Normalization</h2>
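<p>As mentioned in earlier lectures, the choice of initial parameter values matters a great deal for training: poorly chosen initial values lead to vanishing or exploding gradients💥. More importantly, after training finishes, the magnitudes of the gradients and parameters remain close to what they were at initialization, which further underlines the importance of initialization.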
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202407051309612.png?x-oss-process=image/quality,q_90/format,webp"></p>
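<p>To fix this problem, layer normalization was introduced. The idea is to standardize the output of each activation layer, i.e. subtract the mean and divide by the standard deviation: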


<div>$$

\begin{align*}  
\hat{z}_{i&#43;1} &amp;= \sigma_i (W_i^Tz_i&#43;b_i)\\  
z_{i&#43;1} &amp;=\frac{\hat{z}_{i&#43;1} - E(\hat{z}_{i&#43;1})}{\sqrt{Var(\hat{z}_{i&#43;1})&#43;\epsilon}}  
\end{align*}

$$</div>

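This technique is now widely used, although in practice applying layer norm can make it harder for a model to converge to a very small loss.</p>
<p>Another technique is batch norm. Layer norm normalizes each sample (each row of z), whereas batch norm normalizes each column (each feature across the batch). As a result, every sample in a batch influences the output for every other sample in that batch. At inference time, batch norm should therefore use normalization statistics estimated over the training set rather than the statistics of the batch being fed in.</p>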
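<p>The row-versus-column distinction can be sketched on a small batch stored as a list of rows. This is a simplified illustration: the function names are chosen for this example, and the learnable scale and shift as well as the running statistics used at inference time are deliberately left out.</p>

```python
def normalize(values, eps=1e-5):
    """Subtract the mean and divide by sqrt(variance + eps)."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


def layer_norm(z, eps=1e-5):
    """Normalize each row (one sample) of z independently."""
    return [normalize(row, eps) for row in z]


def batch_norm(z, eps=1e-5):
    """Normalize each column (one feature) of z across the batch."""
    cols = [normalize(list(col), eps) for col in zip(*z)]
    return [list(row) for row in zip(*cols)]


batch_a = [[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]]
batch_b = [[1.0, 2.0, 3.0], [10.0, 0.0, 5.0]]  # same first sample, different second
```

<p>Changing the second sample leaves the layer norm output of the first sample untouched, but changes its batch norm output, which is exactly the cross-sample dependence described above.</p>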
<h2 id="regularization">Regularization</h2>
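<p>Regularization combats overfitting, where a model performs very well on the training set but generalizes poorly to the test set. Regularization is the process of limiting model complexity, and it can be divided into explicit and implicit regularization.</p>
<p>Implicit regularization refers to the way existing algorithms or architectures naturally restrict the function class without an explicit regularization term. It arises in the following ways:</p>
<ul>
<li><strong>Inherent properties of the algorithm</strong>: optimization algorithms such as stochastic gradient descent (SGD) carry a regularizing effect of their own. We are not optimizing over all possible neural networks; we are optimizing, via SGD, over the networks reachable from a particular weight initialization. This process by itself limits model complexity.</li>
<li><strong>Architecture design</strong>: some network architectures are regularizing by design. For example, the weight sharing and local connectivity of convolutional neural networks (CNNs) naturally reduce the number of parameters and thus the model complexity.</li>
</ul>
<p>Explicit regularization means explicitly modifying the model or the training procedure so that it avoids overfitting the training set.</p>
<p>The most common regularizer applied to parameters is l2 regularization, a.k.a. weight decay. The traditional view is that the magnitude of a model's parameters indicates, to some extent, the model's complexity, so an l2 penalty is added to the objective to keep the parameters small. In general, a machine learning optimization problem with l2 regularization can be written as:</p>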
<p>

<div>$$

\mathrm{minimize} \quad \frac{1}{m}\sum_{i=1}^m{l(h_{w_{1:L}}(x^{(i)}), y^{(i)})}&#43;\frac{\lambda}{2}\sum_{i=1}^L{||w_i||_F^2}

$$</div>

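where $||w_i||_F$ is the Frobenius norm, i.e. the square root of the sum of the squares of all entries of the matrix.</p>
<p>Thanks to the coefficient $1/2$, the derivative of the regularization term with respect to $w_i$ is exactly $\lambda w_i$. The gradient update accordingly becomes: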


<div>$$

w_i :=(1-\alpha \lambda)w_i-\alpha \nabla_{w_i} \frac{1}{m}\sum_{j=1}^m{l(h_{w_{1:L}}(x^{(j)}), y^{(j)})}

$$</div>
</p>
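<p>Note that with l2 regularization, every iteration shrinks the parameters by a factor of $1-\alpha \lambda$ before applying the gradient step. Many implementations do not treat l2 regularization as part of the loss function but as part of the optimizer, shrinking the parameters directly. This approach is called weight decay. For the plain gradient descent update above, the two are equivalent.</p>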
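<p>The equivalence for a plain gradient step can be verified numerically. In the sketch below the two function names and the sample numbers are assumptions for illustration; <code>grad_loss</code> stands for the gradient of the unregularized loss.</p>

```python
def step_l2_in_loss(w, grad_loss, alpha, lam):
    """Gradient step on loss + (lam / 2) * ||w||^2: the penalty adds lam * w to the gradient."""
    return [wi - alpha * (gi + lam * wi) for wi, gi in zip(w, grad_loss)]


def step_weight_decay(w, grad_loss, alpha, lam):
    """Shrink the weights by (1 - alpha * lam), then take the plain gradient step."""
    return [(1.0 - alpha * lam) * wi - alpha * gi for wi, gi in zip(w, grad_loss)]


w = [0.5, -1.5, 2.0]
g = [0.1, 0.2, -0.3]
in_loss = step_l2_in_loss(w, g, alpha=0.1, lam=0.01)
decayed = step_weight_decay(w, g, alpha=0.1, lam=0.01)
```

<p>The two results agree up to floating-point rounding. For optimizers that rescale the gradient, such as Adam, the two formulations are in general no longer identical.</p>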
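<p>Another regularization method is dropout. The idea is to randomly set some activation outputs to 0 during training and scale up the remaining outputs so that the expected value of the layer's output is unchanged. Formally: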


<div>$$

\begin{align*}  
\hat{z}_{i&#43;1} &amp;= \sigma_i(W^T_i z_i&#43;b_i)\\  
(z_{i&#43;1})_j &amp;=  
\begin{cases}  
((\hat{z}_{i&#43;1} )_j)/(1-p) \quad &amp;\text{with probability } 1-p\\  
0 &amp;\text{with probability } p  
\end{cases}  
\end{align*}

$$</div>

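At inference time, dropout is not applied.</p>
<p>Intuitively, dropout improves the model's ability to run inference when part of its activations are missing, but that ability is clearly of little practical use in itself. Another interpretation is that dropout increases the randomness of training, much like SGD.</p>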
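<p>A minimal sketch of the dropout rule above on a list of activations. The function signature, including the <code>training</code> flag and the injectable <code>rng</code>, is an assumption made for this example.</p>

```python
import random


def dropout(z, p, training=True, rng=random):
    """Zero each activation with probability p and scale the survivors by 1 / (1 - p)."""
    if not training or p == 0.0:
        return list(z)  # inference: activations pass through unchanged
    return [0.0 if rng.random() < p else zj / (1.0 - p) for zj in z]


activations = [1.0] * 100000
dropped = dropout(activations, p=0.3, rng=random.Random(0))
mean_after = sum(dropped) / len(dropped)
```

<p>About 30% of the entries become 0 and the rest become $1/(1-p)$, so the mean of the layer output stays close to its original value of 1.</p>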
<h1 id="lecture-10-convolutional-networks">Lecture 10: Convolutional Networks</h1>
<h2 id="convolutional-operators-in-deep-networks">Convolutional operators in deep networks</h2>
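<p>In hw2 we flattened each image into a vector before feeding it to the network. That works for small images, but for large ones, e.g. 256×256, the input becomes enormous and the network grows with it. This brute-force approach also makes it hard to capture the intrinsic structure of images: for example, translating an image changes the flattened input drastically.</p>
<p>Convolutional networks are motivated by two ideas:</p>
<ul>
<li>Activations between layers occur in a local manner, and hidden-layer outputs are themselves treated as images</li>
<li>Weights are shared across all spatial locations</li>
</ul>
<p>Convolutional networks have two advantages:</p>
<ul>
<li>They use few parameters. The parameter count is determined by the size of the convolution kernels and is independent of the input shape;</li>
<li>They capture the intrinsic invariances of images well.</li>
</ul>
<p>The figure below illustrates the computation: the kernel slides over the original image and produces a new image.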
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202407250959153.png?x-oss-process=image/quality,q_90/format,webp"></p>
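<p>In deep learning, inputs and hidden layers are rarely a single 2D matrix. In general they consist of multiple channels. For example, a color image has three RGB channels, and intermediate hidden layers usually have a fairly large number of channels, as shown below:
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202407251015471.png?x-oss-process=image/quality,q_90/format,webp">
Let the input of a convolutional layer be $x\in \mathbb{R}^{h\times w \times c_{in}}$ and its output $z\in \mathbb{R}^{h\times w \times c_{out}}$. As the figure shows, each output channel at a given location is determined jointly by all input channels in the same local region, so the kernel is $W\in \mathbb{R}^{c_{in}\times c_{out}\times k \times k}$, and the convolution can be written formally as: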


<div>$$

z[:,:,s] = \sum_{r=1}^{c_{in}}x[:,:,r] * W[r,s,:,:]

$$</div>

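Another, more intuitive way to understand multi-channel convolution is to treat the channels at each spatial location as a vector. In the figure below, every cell of $x$ is a vector and every cell of $W$ is a $c_{out} \times c_{in}$ matrix. The output $z$ at a location is obtained by multiplying the corresponding cells of $W$ and $x$ as matrix-vector products and summing the results.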
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202407251027480.png?x-oss-process=image/quality,q_90/format,webp"></p>
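<p>The formula above can be written as a deliberately naive loop nest. This sketch makes several assumptions that are not in the notes: it uses nested Python lists, applies no padding or stride (so the output is smaller than the input by $k-1$), and follows the deep learning convention of not flipping the kernel.</p>

```python
def conv2d(x, W):
    """Multi-channel convolution without padding or stride.

    x: nested lists of shape h x w x c_in
    W: nested lists of shape c_in x c_out x k x k
    Returns nested lists of shape (h - k + 1) x (w - k + 1) x c_out.
    """
    h, w, c_in = len(x), len(x[0]), len(x[0][0])
    c_out, k = len(W[0]), len(W[0][0])
    out = [[[0.0] * c_out for _ in range(w - k + 1)] for _ in range(h - k + 1)]
    for i in range(h - k + 1):
        for j in range(w - k + 1):
            for s in range(c_out):
                total = 0.0
                for r in range(c_in):          # sum over input channels
                    for a in range(k):         # and over the k x k window
                        for b in range(k):
                            total += x[i + a][j + b][r] * W[r][s][a][b]
                out[i][j][s] = total
    return out


# 3 x 3 image with two channels holding the values 1 and 2 everywhere.
image = [[[1.0, 2.0] for _ in range(3)] for _ in range(3)]
# One output channel, 2 x 2 kernel of ones for both input channels.
ones_kernel = [[[[1.0] * 2 for _ in range(2)]] for _ in range(2)]
summed = conv2d(image, ones_kernel)
```

<p>With a 1×1 kernel the inner loops reduce to a single matrix-vector product per pixel, which matches the "vector per location" view described above.</p>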
<h2 id="elements-of-practical-convolutions">Elements of practical convolutions</h2>
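<p>Practical convolutions usually involve a few additional techniques.</p>
<ul>
<li>
<p>Padding
A plain convolution shrinks the output height and width by $k-1$. Padding $(k-1)/2$ zeros on every side keeps the output shape equal to the input shape. To avoid the awkward case of uneven padding on the two sides, the kernel size is usually chosen to be odd.</p>
</li>
<li>
<p>Strided Convolutions / Pooling
A padded convolution leaves the image shape unchanged, but in practice we often want to downsample the image. There are two solutions:</p>
</li>
</ul>
<ol>
<li>Use max/average pooling to aggregate information. For example, pooling with a 2×2 kernel and a stride of 2 halves both the height and the width of the image;</li>
<li>Slide the convolution kernel with a stride greater than 1.</li>
</ol>
<ul>
<li>
<p>Grouped Convolutions
When the numbers of input and output channels are large, the kernel can still have a very large number of parameters. One solution is grouped convolution: the input channels are split into groups and each group is convolved independently, as shown below. With G groups, the parameter count drops to 1/G of the original.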
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202407251311275.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
<li>
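<p>Dilations
In a standard convolution the receptive field is the same size as the kernel. Dilated convolution inserts gaps between the positions the kernel samples, which enlarges the receptive field without adding parameters. The figure below illustrates this well.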
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202407251316286.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
</ul>
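<p>The effect of these options on shapes and parameter counts can be summarized in two small helpers. Both function names are hypothetical. The size formula is the usual one for a single spatial dimension with symmetric zero padding, and the parameter count ignores biases.</p>

```python
def conv_output_size(n, k, padding=0, stride=1, dilation=1):
    """Output length along one spatial dimension of a convolution or pooling op."""
    effective_k = dilation * (k - 1) + 1   # span covered by a dilated kernel
    return (n + 2 * padding - effective_k) // stride + 1


def conv_param_count(c_in, c_out, k, groups=1):
    """Number of kernel weights; each group only sees c_in / groups input channels."""
    return (c_in // groups) * c_out * k * k


plain = conv_output_size(32, 3)                 # shrinks by k - 1
same = conv_output_size(32, 3, padding=1)       # padding (k - 1) / 2 keeps the size
pooled = conv_output_size(32, 2, stride=2)      # 2 x 2 window, stride 2 halves it
dilated = conv_output_size(32, 3, dilation=2)   # a 3-tap kernel spans 5 positions
```

<p>For example, a 3×3 convolution from 64 to 128 channels has 73728 weights, and splitting it into 4 groups leaves a quarter of that.</p>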
<h2 id="differentiating-convolutions">Differentiating convolutions</h2>
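<p>As noted above, a convolution can be implemented as a series of matrix-vector products and sums, but that is far too inefficient: the computational graph would contain many intermediate nodes, and those intermediate values would consume a great deal of memory. So rather than composing convolution from the operators of the autodiff library, we implement it as a single operator and derive its derivative by hand.</p>
<p>First define the convolution operation: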


<div>$$

z = \operatorname{conv}(x,W)

$$</div>

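How do we multiply the derivative of $z$ by the adjoint? That is the question. $z$ has two derivatives, $\frac{\partial z}{\partial x}$ and $\frac{\partial z}{\partial W}$. Formally these are derivatives of a third-order tensor with respect to third- and fourth-order tensors, which is rather complicated.</p>
<p>Consider first the simplest case, a matrix-vector product: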


<div>$$

z = Wx

$$</div>

The derivative of $z$ with respect to $x$ is then $W$, so the product with the adjoint is computed as:


<div>$$

W^T\bar{v}

$$</div>

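In other words, if the forward pass computes a matrix-vector product, the backward pass computes the product of the transposed matrix with the adjoint. So what is the "transpose" of a convolution?</p>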
<ul>
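<li>Convolution as a matrix operation I
Take 1D convolution as an example and consider the convolution below, where each cell is a vector or a matrix.
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202407251428228.png?x-oss-process=image/quality,q_90/format,webp">
Writing this convolution out as a matrix product gives: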


<div>$$

\begin{bmatrix}z_1\\z_2\\z_3\\z_4\\z_5\end{bmatrix}=x*w=\begin{bmatrix}w_2&amp;w_3&amp;0&amp;0&amp;0\\w_1&amp;w_2&amp;w_3&amp;0&amp;0\\0&amp;w_1&amp;w_2&amp;w_3&amp;0\\0&amp;0&amp;w_1&amp;w_2&amp;w_3\\0&amp;0&amp;0&amp;w_1&amp;w_2\end{bmatrix}\begin{bmatrix}x_1\\x_2\\x_3\\x_4\\x_5\end{bmatrix}

$$</div>

Calling the banded matrix above $\hat{W}$, we can easily write down $\hat{W}^T$:


<div>$$

\hat W^T=\begin{bmatrix}w_2&amp;w_1&amp;0&amp;0&amp;0\\w_3&amp;w_2&amp;w_1&amp;0&amp;0\\0&amp;w_3&amp;w_2&amp;w_1&amp;0\\0&amp;0&amp;w_3&amp;w_2&amp;w_1\\0&amp;0&amp;0&amp;w_3&amp;w_2\end{bmatrix}

$$</div>

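It is easy to see that this operator is the convolution with kernel $[w_3, w_2, w_1]$, i.e. the original kernel flipped. In other words, the product of the adjoint and the derivative can be written as: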


<div>$$

\bar{v}\frac{\partial \operatorname{conv}(x,w)}{\partial x} = \operatorname{conv}(\bar{v},\operatorname{flip}(w))

$$</div>
</li>
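<li>Convolution as a matrix operation II
Next consider the derivative of the convolution with respect to the parameters $w$. Again expanding into a matrix product gives: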


<div>$$

\begin{bmatrix}z_1\\z_2\\z_3\\z_4\\z_5\end{bmatrix}=x*w=\begin{bmatrix}0&amp;x_1&amp;x_2\\x_1&amp;x_2&amp;x_3\\x_2&amp;x_3&amp;x_4\\x_3&amp;x_4&amp;x_5\\x_4&amp;x_5&amp;0\end{bmatrix}\begin{bmatrix}w_1\\w_2\\w_3\end{bmatrix}

$$</div>

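Unlike in matrix operation I, the matrix $\hat{X}$ constructed here is dense, and this is the form commonly used when implementing the convolution operator. The matrix $\hat{X}$ is called the "im2col" matrix (image to column).</li>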
</ul>
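<p>Both matrix views, and the flipped-kernel identity, can be checked on the 5-element example. The helpers below are a sketch written for this note (all names are hypothetical); they assume an odd kernel length and zero padding so that the output has the same length as the input.</p>

```python
def conv1d(x, w):
    """Zero-padded 1D convolution: z_i = w_1 * x_{i-1} + w_2 * x_i + w_3 * x_{i+1} for k = 3."""
    n, k = len(x), len(w)
    pad = k // 2
    xp = [0.0] * pad + list(x) + [0.0] * pad
    return [sum(w[a] * xp[i + a] for a in range(k)) for i in range(n)]


def conv_matrix(w, n):
    """Banded n x n matrix W_hat such that conv1d(x, w) equals W_hat times x."""
    k, pad = len(w), len(w) // 2
    return [[w[j - i + pad] if 0 <= j - i + pad < k else 0.0 for j in range(n)]
            for i in range(n)]


def im2col(x, k):
    """Dense n x k matrix X_hat such that conv1d(x, w) equals X_hat times w."""
    pad = k // 2
    xp = [0.0] * pad + list(x) + [0.0] * pad
    return [xp[i:i + k] for i in range(len(x))]


def matvec(M, v):
    return [sum(m * vi for m, vi in zip(row, v)) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


x = [1, 2, 3, 4, 5]
w = [1, 10, 100]
v_bar = [1, 0, 2, 0, 3]                                  # an arbitrary adjoint
z = conv1d(x, w)
via_transpose = matvec(transpose(conv_matrix(w, 5)), v_bar)
via_flip = conv1d(v_bar, w[::-1])                        # convolve with the flipped kernel
```

<p>Here <code>via_transpose</code> and <code>via_flip</code> coincide, which is the identity stated above, and multiplying <code>im2col(x, 3)</code> by <code>w</code> reproduces <code>z</code>.</p>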
<h1 id="lecture-11-hardware-acceleration">Lecture 11: Hardware acceleration</h1>
<h2 id="general-acceleration-techniques">General acceleration techniques</h2>
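<p>A modern machine learning framework can be viewed as two layers: the upper layer is the computational graph, responsible for the forward pass, automatic differentiation, and backpropagation; the lower layer is a tensor linear algebra library that performs the low-level tensor computation. In needle we currently use numpy as the linear algebra library. This section introduces some common acceleration techniques.</p>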
<ul>
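<li>Vectorization
To add two arrays of length 256, a scalar implementation adds the 256 elements one at a time. Many processors, however, provide instructions that load from memory and operate on several elements at once, so the loop can be optimized as follows:</li>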
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C" data-lang="C"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">vecadd</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="mi">64</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="n">float4</span> <span class="n">a</span> <span class="o">=</span> <span class="nf">load_float4</span><span class="p">(</span><span class="n">A</span> <span class="o">+</span> <span class="n">i</span><span class="o">*</span><span class="mi">4</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">		<span class="n">float4</span> <span class="n">b</span> <span class="o">=</span> <span class="nf">load_float4</span><span class="p">(</span><span class="n">B</span> <span class="o">+</span> <span class="n">i</span><span class="o">*</span><span class="mi">4</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">		<span class="n">float4</span> <span class="n">c</span> <span class="o">=</span> <span class="nf">add_float4</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">		<span class="nf">store_float4</span><span class="p">(</span><span class="n">C</span> <span class="o">+</span> <span class="n">i</span><span class="o">*</span><span class="mi">4</span><span class="p">,</span> <span class="n">c</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>This requires the memory blocks holding A, B, and C to be aligned to 128 bits.</p>
<ul>
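<li>Data layout &amp; strides
Memory is linear, so a matrix can be laid out in two ways: row-major or column-major. Some older languages use column-major order, while modern languages tend to use row-major order.</li>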
</ul>
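<p>Many libraries also use a strided layout: alongside the tensor's data they store the step size to move along each dimension. In this layout, <code>a[i, j] = a_data[i * strides[0] + j * strides[1]]</code></p>
<p>This scheme supports many operations without copying data: slicing by changing the offset and shape, transposing by swapping the strides, and broadcasting by inserting a stride of 0.</p>
<p>The drawback is that memory accesses may no longer be contiguous, which makes vectorization harder, so many libraries first compact the data into a contiguous array before using it.</p>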
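<p>The three zero-copy tricks can be sketched with a small 2D view class over a flat Python list. The class <code>StridedView</code> and its methods are hypothetical names for this illustration, and strides are counted in elements here (some libraries, such as numpy, count them in bytes).</p>

```python
class StridedView:
    """A 2D view over a flat buffer: a[i, j] = data[offset + i * strides[0] + j * strides[1]]."""

    def __init__(self, data, shape, strides, offset=0):
        self.data = data          # shared, never copied
        self.shape = shape
        self.strides = strides
        self.offset = offset

    def __getitem__(self, idx):
        i, j = idx
        return self.data[self.offset + i * self.strides[0] + j * self.strides[1]]

    def transpose(self):
        """Swap shape and strides; the buffer is untouched."""
        return StridedView(self.data, self.shape[::-1], self.strides[::-1], self.offset)

    def slice_rows(self, start, stop):
        """Rows start..stop-1: only the offset and the shape change."""
        return StridedView(self.data, (stop - start, self.shape[1]), self.strides,
                           self.offset + start * self.strides[0])

    def tolist(self):
        return [[self[i, j] for j in range(self.shape[1])] for i in range(self.shape[0])]


def broadcast_row(data, rows):
    """View a length-n buffer as rows x n by giving the new axis a stride of 0."""
    return StridedView(data, (rows, len(data)), (0, 1))


buffer = list(range(6))
a = StridedView(buffer, shape=(2, 3), strides=(3, 1))   # row-major 2 x 3
```

<p>All the views returned by <code>transpose</code>, <code>slice_rows</code> and <code>broadcast_row</code> share the original buffer, which is why reading them in order can jump around in memory.</p>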
<ul>
<li>Parallelization
OpenMP can distribute the computation across multiple cores to run in parallel:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C" data-lang="C"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">vecadd</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="cp">#pragma omp parallel for
</span></span></span><span class="line"><span class="cl">	<span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="mi">64</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="n">float4</span> <span class="n">a</span> <span class="o">=</span> <span class="nf">load_float4</span><span class="p">(</span><span class="n">A</span> <span class="o">+</span> <span class="n">i</span><span class="o">*</span><span class="mi">4</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">		<span class="n">float4</span> <span class="n">b</span> <span class="o">=</span> <span class="nf">load_float4</span><span class="p">(</span><span class="n">B</span> <span class="o">+</span> <span class="n">i</span><span class="o">*</span><span class="mi">4</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">		<span class="n">float4</span> <span class="n">c</span> <span class="o">=</span> <span class="nf">add_float4</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">		<span class="nf">store_float4</span><span class="p">(</span><span class="n">C</span> <span class="o">+</span> <span class="n">i</span><span class="o">*</span><span class="mi">4</span><span class="p">,</span> <span class="n">c</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="case-study-matrix-multiplication">Case study: matrix multiplication</h2>
<p>This section discusses how to optimize matrix multiplication.</p>
<ul>
<li>Vanilla matrix multiplication
The most straightforward approach uses three nested loops, with complexity $O(n^3)$:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="kt">float</span> <span class="n">A</span><span class="p">[</span><span class="n">n</span><span class="p">][</span><span class="n">n</span><span class="p">],</span> <span class="n">B</span><span class="p">[</span><span class="n">n</span><span class="p">][</span><span class="n">n</span><span class="p">],</span> <span class="n">C</span><span class="p">[</span><span class="n">n</span><span class="p">][</span><span class="n">n</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">n</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">j</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">j</span><span class="o">&lt;</span><span class="n">n</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="n">C</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">k</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">k</span><span class="o">&lt;</span><span class="n">n</span><span class="p">;</span> <span class="n">k</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">			<span class="n">C</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span><span class="p">][</span><span class="n">j</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">		<span class="p">}</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
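</div><p>In a modern memory hierarchy, the L1 cache is roughly 200 times faster than DRAM, so optimizing how data is read can significantly speed up computation. With that in mind, we can keep intermediate values in registers:</p>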
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">dram</span> <span class="kt">float</span> <span class="n">A</span><span class="p">[</span><span class="n">n</span><span class="p">][</span><span class="n">n</span><span class="p">],</span> <span class="n">B</span><span class="p">[</span><span class="n">n</span><span class="p">][</span><span class="n">n</span><span class="p">],</span> <span class="n">C</span><span class="p">[</span><span class="n">n</span><span class="p">][</span><span class="n">n</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">n</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">j</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">j</span><span class="o">&lt;</span><span class="n">n</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="k">register</span> <span class="kt">float</span> <span class="n">c</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">k</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">k</span><span class="o">&lt;</span><span class="n">n</span><span class="p">;</span> <span class="n">k</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">			<span class="k">register</span> <span class="kt">float</span> <span class="n">a</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">k</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">			<span class="k">register</span> <span class="kt">float</span> <span class="n">b</span> <span class="o">=</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span><span class="p">][</span><span class="n">j</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">			<span class="n">c</span> <span class="o">+=</span> <span class="n">a</span><span class="o">*</span><span class="n">b</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="p">}</span>
</span></span><span class="line"><span class="cl">		<span class="n">C</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">c</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>In the code above, A and B are each loaded into registers $n^3$ times, and the computation needs 3 registers.</p>
<ul>
<li>Register tiled matrix multiplication
The idea is to split the result into tiles and compute one tile at a time:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">dram</span> <span class="kt">float</span> <span class="n">A</span><span class="p">[</span><span class="n">n</span><span class="o">/</span><span class="n">v1</span><span class="p">][</span><span class="n">n</span><span class="o">/</span><span class="n">v3</span><span class="p">][</span><span class="n">v1</span><span class="p">][</span><span class="n">v3</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="n">dram</span> <span class="kt">float</span> <span class="n">B</span><span class="p">[</span><span class="n">n</span><span class="o">/</span><span class="n">v2</span><span class="p">][</span><span class="n">n</span><span class="o">/</span><span class="n">v3</span><span class="p">][</span><span class="n">v2</span><span class="p">][</span><span class="n">v3</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="n">dram</span> <span class="kt">float</span> <span class="n">C</span><span class="p">[</span><span class="n">n</span><span class="o">/</span><span class="n">v1</span><span class="p">][</span><span class="n">n</span><span class="o">/</span><span class="n">v2</span><span class="p">][</span><span class="n">v1</span><span class="p">][</span><span class="n">v2</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">n</span><span class="o">/</span><span class="n">v1</span><span class="p">;</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">n</span><span class="o">/</span><span class="n">v2</span><span class="p">;</span> <span class="o">++</span><span class="n">j</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="k">register</span> <span class="kt">float</span> <span class="n">c</span><span class="p">[</span><span class="n">v1</span><span class="p">][</span><span class="n">v2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">n</span><span class="o">/</span><span class="n">v3</span><span class="p">;</span> <span class="o">++</span><span class="n">k</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="k">register</span> <span class="kt">float</span> <span class="n">a</span><span class="p">[</span><span class="n">v1</span><span class="p">][</span><span class="n">v3</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">k</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">            <span class="k">register</span> <span class="kt">float</span> <span class="n">b</span><span class="p">[</span><span class="n">v2</span><span class="p">][</span><span class="n">v3</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="n">k</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">            <span class="n">c</span> <span class="o">+=</span> <span class="nf">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">.</span><span class="n">T</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="n">C</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">c</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
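</div><p>In the code above, the result matrix C is partitioned into $v_1\times v_2$ tiles. Computing one tile requires $v_1$ rows of A and $v_2$ columns of B, and these two strips can be further split into chunks of length $v_3$. The two outer loops walk over the tiles of C and initialize $v_1 \times v_2$ registers to hold the current tile. The third loop steps through the $v_3$-sized chunks, multiplies the corresponding sub-blocks, and accumulates the result into those registers; when it finishes, one tile of C is complete.</p>
<p>The load cost is $n^3/v_2$ for A and $n^3/v_1$ for B: each of the $(n/v_1)(n/v_2)$ tiles loads $v_1 n$ elements of A and $v_2 n$ elements of B. The register cost is $v_1 \times v_3$ for A, $v_2\times v_3$ for B, and $v_1\times v_2$ for C. Since $v_3$ does not affect the load cost, we can set $v_3 = 1$ and then make $v_1$ and $v_2$ as large as the register budget allows.</p>
<p>The cost drops because matrix multiplication reuses elements: by computing one tile at a time, each value needed within the tile is loaded once and then reused across the whole tile.</p>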
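<p>To make the load-cost argument concrete, here is a minimal runnable sketch of register tiling with $v_1 = v_2 = 2$ and $v_3 = 1$. It is an illustration under stated assumptions rather than the lecture's implementation: it uses flat row-major arrays instead of the pre-tiled layout in the pseudocode, the function name <code>matmul_register_tiled</code> and the counters <code>loads_a</code> and <code>loads_b</code> are hypothetical, and whether the small local arrays actually live in registers is left to the compiler.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-c" data-lang="c">#include &lt;stddef.h&gt;

#define V1 2   /* tile rows of C kept in the accumulator */
#define V2 2   /* tile columns of C kept in the accumulator */

/* Counters for scalar loads from A and B, used only to check the cost model. */
static size_t loads_a = 0, loads_b = 0;

/* C = A * B for row-major n x n matrices; n must be a multiple of V1 and V2.
   With v3 = 1, each k step loads V1 values of A and V2 values of B and
   performs V1 * V2 multiply-adds on the accumulator tile. */
void matmul_register_tiled(const float *A, const float *B, float *C, size_t n) {
    loads_a = loads_b = 0;
    for (size_t i = 0; i &lt; n; i += V1) {
        for (size_t j = 0; j &lt; n; j += V2) {
            float c[V1][V2] = {{0}};
            for (size_t k = 0; k &lt; n; ++k) {
                float a[V1], b[V2];
                for (size_t x = 0; x &lt; V1; ++x) { a[x] = A[(i + x) * n + k]; ++loads_a; }
                for (size_t y = 0; y &lt; V2; ++y) { b[y] = B[k * n + (j + y)]; ++loads_b; }
                for (size_t x = 0; x &lt; V1; ++x)
                    for (size_t y = 0; y &lt; V2; ++y)
                        c[x][y] += a[x] * b[y];
            }
            /* Write the finished tile back to C. */
            for (size_t x = 0; x &lt; V1; ++x)
                for (size_t y = 0; y &lt; V2; ++y)
                    C[(i + x) * n + (j + y)] = c[x][y];
        }
    }
}
</code></pre></div>
<p>After a call, <code>loads_a</code> equals $n^3/v_2$ and <code>loads_b</code> equals $n^3/v_1$, half of the $n^3$ loads each that the untiled version performs.</p>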
<ul>
<li>Cache line aware tiling
The previous step used registers for acceleration; here we use the cache instead. The implementation is:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">dram</span> <span class="kt">float</span> <span class="n">A</span><span class="p">[</span><span class="n">n</span><span class="o">/</span><span class="n">b1</span><span class="p">][</span><span class="n">b1</span><span class="p">][</span><span class="n">n</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="n">dram</span> <span class="kt">float</span> <span class="n">B</span><span class="p">[</span><span class="n">n</span><span class="o">/</span><span class="n">b2</span><span class="p">][</span><span class="n">b2</span><span class="p">][</span><span class="n">n</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="n">dram</span> <span class="kt">float</span> <span class="n">C</span><span class="p">[</span><span class="n">n</span><span class="o">/</span><span class="n">b1</span><span class="p">][</span><span class="n">n</span><span class="o">/</span><span class="n">b2</span><span class="p">][</span><span class="n">b1</span><span class="p">][</span><span class="n">b2</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">n</span><span class="o">/</span><span class="n">b1</span><span class="p">;</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">l1cache</span> <span class="kt">float</span> <span class="n">a</span><span class="p">[</span><span class="n">b1</span><span class="p">][</span><span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">n</span><span class="o">/</span><span class="n">b2</span><span class="p">;</span> <span class="o">++</span><span class="n">j</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">l1cache</span> <span class="kt">float</span> <span class="n">b</span><span class="p">[</span><span class="n">b2</span><span class="p">][</span><span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span><span class="p">[</span><span class="n">j</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="n">C</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="nf">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">.</span><span class="n">T</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
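</div><p>In the code above, the result matrix C is partitioned into $b_1 \times b_2$ blocks, and A and B are partitioned by rows and by columns respectively. Two loops iterate over the blocks of C, and the computation of each block can itself be accelerated with register tiling.</p>
<p>Here the load cost is $n^2$ for A and $n^3/b_1$ for B. There are two constraints: $b_1 n + b_2 n &lt; \text{L1 cache size}$, and $b_1 \bmod v_1 = b_2 \bmod v_2 = 0$.</p>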
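<p>The two constraints can be checked mechanically. The sketch below assumes the L1 capacity is expressed in floats; the function name <code>cache_tiles_valid</code> and the parameter <code>l1_floats</code> are hypothetical and introduced only for illustration.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-c" data-lang="c">#include &lt;stddef.h&gt;

/* Returns 1 if block sizes (b1, b2) satisfy both constraints:
   (1) the A strip (b1 x n) and the B strip (b2 x n) fit in L1 together, and
   (2) b1 and b2 are multiples of the register tile sizes v1 and v2.
   l1_floats is the L1 capacity measured in floats (an assumed unit). */
int cache_tiles_valid(size_t n, size_t b1, size_t b2,
                      size_t v1, size_t v2, size_t l1_floats) {
    int fits = b1 * n + b2 * n &lt; l1_floats;
    int divisible = (b1 % v1 == 0) &amp;&amp; (b2 % v2 == 0);
    return fits &amp;&amp; divisible;
}
</code></pre></div>
<p>For example, with a 32 KB L1 cache (8192 floats) and $n = 256$, the choice $b_1 = b_2 = 8$ occupies 4096 floats and is valid, while $b_1 = b_2 = 16$ needs exactly 8192 floats and violates the strict inequality.</p>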
<ul>
<li>Put it together
Expanding the <code>dot</code> operation of the cache version with the register version gives the final tiled multiplication:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">dram</span> <span class="kt">float</span> <span class="n">A</span><span class="p">[</span><span class="n">n</span><span class="o">/</span><span class="n">b1</span><span class="p">][</span><span class="n">b1</span><span class="o">/</span><span class="n">v1</span><span class="p">][</span><span class="n">n</span><span class="p">][</span><span class="n">v1</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="n">dram</span> <span class="kt">float</span> <span class="n">B</span><span class="p">[</span><span class="n">n</span><span class="o">/</span><span class="n">b2</span><span class="p">][</span><span class="n">b2</span><span class="o">/</span><span class="n">v2</span><span class="p">][</span><span class="n">n</span><span class="p">][</span><span class="n">v2</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">n</span><span class="o">/</span><span class="n">b1</span><span class="p">;</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">l1cache</span> <span class="kt">float</span> <span class="n">a</span><span class="p">[</span><span class="n">b1</span><span class="o">/</span><span class="n">v1</span><span class="p">][</span><span class="n">n</span><span class="p">][</span><span class="n">v1</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">n</span><span class="o">/</span><span class="n">b2</span><span class="p">;</span> <span class="o">++</span><span class="n">j</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">l1cache</span> <span class="kt">float</span> <span class="n">b</span><span class="p">[</span><span class="n">b2</span><span class="o">/</span><span class="n">v2</span><span class="p">][</span><span class="n">n</span><span class="p">][</span><span class="n">v2</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span><span class="p">[</span><span class="n">j</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">x</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">x</span> <span class="o">&lt;</span> <span class="n">b1</span><span class="o">/</span><span class="n">v1</span><span class="p">;</span> <span class="o">++</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">y</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">y</span> <span class="o">&lt;</span> <span class="n">b2</span><span class="o">/</span><span class="n">v2</span><span class="p">;</span> <span class="o">++</span><span class="n">y</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="k">register</span> <span class="kt">float</span> <span class="n">c</span><span class="p">[</span><span class="n">v1</span><span class="p">][</span><span class="n">v2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">                <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">n</span><span class="p">;</span> <span class="o">++</span><span class="n">k</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                    <span class="k">register</span> <span class="kt">float</span> <span class="n">ar</span><span class="p">[</span><span class="n">v1</span><span class="p">]</span> <span class="o">=</span> <span class="n">a</span><span class="p">[</span><span class="n">x</span><span class="p">][</span><span class="n">k</span><span class="p">][</span><span class="o">:</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">                    <span class="k">register</span> <span class="kt">float</span> <span class="n">br</span><span class="p">[</span><span class="n">v2</span><span class="p">]</span> <span class="o">=</span> <span class="n">b</span><span class="p">[</span><span class="n">y</span><span class="p">][</span><span class="n">k</span><span class="p">][</span><span class="o">:</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">                    <span class="n">c</span> <span class="o">+=</span> <span class="nf">dot</span><span class="p">(</span><span class="n">ar</span><span class="p">,</span> <span class="n">br</span><span class="p">.</span><span class="n">T</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">                <span class="p">}</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
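</div><p>The data load cost of the code above, where $\text{speed}_{l1}$ and $\text{speed}_{dram}$ denote the time per element loaded from the L1 cache and from DRAM, is:


<div>$$

\text{speed}_{l1}\cdot\left(\frac{n^3}{v_2}&#43;\frac{n^3}{v_1}\right)&#43;\text{speed}_{dram}\cdot\left(n^2&#43;\frac{n^3}{b_1}\right)

$$</div>
</p>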
<h1 id="lecture-12-gpu-acceleration">Lecture 12: GPU acceleration</h1>
<h2 id="gpu-programming">GPU programming</h2>
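<p>As the figure below shows, a CPU is a general-purpose processor that can flexibly handle different tasks, and each core has its own control unit. Some workloads, such as graphics rendering, involve a large amount of repeated work, for example applying the same operation to every pixel. GPUs are built for exactly this kind of task: they have a large number of execution units that can execute the same instruction in bulk. Applying GPUs to deep learning can yield speedups of 10x to 100x.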
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202407260953795.png?x-oss-process=image/quality,q_90/format,webp"></p>
<ul>
<li>GPU programming model: SIMT
This section uses CUDA terminology, but other programming models usually have corresponding concepts.</li>
</ul>
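<p>In SIMT (single instruction, multiple threads), all threads execute the same code but can take different paths through it. Threads are grouped into blocks, and threads within the same block share memory. Blocks are grouped into a launch grid; launching a kernel means executing it over a grid.</p>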
<ul>
<li>Example: vector add
The following code shows vector addition on the CPU and on the GPU:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">VecAddCPU</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span> <span class="kt">int</span> <span class="n">n</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">n</span><span class="p">;</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">C</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">B</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">VecAddKernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span> <span class="kt">int</span> <span class="n">n</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">i</span> <span class="o">&lt;</span> <span class="n">n</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">C</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">B</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
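</div><p>The GPU version shows that every thread executes the same instructions; what differs from thread to thread is its execution context, that is, the values of built-in variables such as <code>blockIdx</code> and <code>threadIdx</code>.</p>
<p>To run the GPU code above, the host side has to do the following:</p>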
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">VecAddCUDA</span><span class="p">(</span><span class="kt">float</span> <span class="o">*</span><span class="n">Acpu</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">Bcpu</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">Ccpu</span><span class="p">,</span> <span class="kt">int</span> <span class="n">n</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="o">*</span><span class="n">dA</span><span class="p">,</span> <span class="o">*</span><span class="n">dB</span><span class="p">,</span> <span class="o">*</span><span class="n">dC</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="nf">cudaMalloc</span><span class="p">(</span><span class="o">&amp;</span><span class="n">dA</span><span class="p">,</span> <span class="n">n</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">    <span class="nf">cudaMalloc</span><span class="p">(</span><span class="o">&amp;</span><span class="n">dB</span><span class="p">,</span> <span class="n">n</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">    <span class="nf">cudaMalloc</span><span class="p">(</span><span class="o">&amp;</span><span class="n">dC</span><span class="p">,</span> <span class="n">n</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="nf">cudaMemcpy</span><span class="p">(</span><span class="n">dA</span><span class="p">,</span> <span class="n">Acpu</span><span class="p">,</span> <span class="n">n</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">),</span> <span class="n">cudaMemcpyHostToDevice</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="nf">cudaMemcpy</span><span class="p">(</span><span class="n">dB</span><span class="p">,</span> <span class="n">Bcpu</span><span class="p">,</span> <span class="n">n</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">),</span> <span class="n">cudaMemcpyHostToDevice</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">threads_per_block</span> <span class="o">=</span> <span class="mi">512</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">nblocks</span> <span class="o">=</span> <span class="p">(</span><span class="n">n</span> <span class="o">+</span> <span class="n">threads_per_block</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">threads_per_block</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">VecAddKernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">nblocks</span><span class="p">,</span> <span class="n">threads_per_block</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">dA</span><span class="p">,</span> <span class="n">dB</span><span class="p">,</span> <span class="n">dC</span><span class="p">,</span> <span class="n">n</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="nf">cudaMemcpy</span><span class="p">(</span><span class="n">Ccpu</span><span class="p">,</span> <span class="n">dC</span><span class="p">,</span> <span class="n">n</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">),</span> <span class="n">cudaMemcpyDeviceToHost</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="nf">cudaFree</span><span class="p">(</span><span class="n">dA</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="nf">cudaFree</span><span class="p">(</span><span class="n">dB</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="nf">cudaFree</span><span class="p">(</span><span class="n">dC</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
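</div><p>The function takes three arrays that live in CPU memory. It allocates device memory of matching size on the GPU and copies the two operands to the device. It then derives the number of blocks to launch from the data size, runs the kernel, copies the result back to CPU memory, and frees the device memory.</p>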
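<p>The block-count arithmetic and the per-thread index mapping can be checked without a GPU. The following is a minimal CPU sketch, not the lecture's code: <code>blocks_needed</code> and <code>vec_add_emulated</code> are hypothetical helper names, and the two nested loops merely stand in for the threads that the GPU would run in parallel.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-c" data-lang="c">/* Number of blocks needed so that nblocks * threads_per_block &gt;= n
   (integer ceiling division, as in VecAddCUDA). */
int blocks_needed(int n, int threads_per_block) {
    return (n + threads_per_block - 1) / threads_per_block;
}

/* CPU stand-in for VecAddKernel&lt;&lt;&lt;nblocks, threads_per_block&gt;&gt;&gt;:
   the two loops enumerate every (blockIdx.x, threadIdx.x) pair.
   Returns how many threads were idle because their global index
   fell past the end of the arrays. */
int vec_add_emulated(const float *A, const float *B, float *C,
                     int n, int threads_per_block) {
    int nblocks = blocks_needed(n, threads_per_block);
    int idle = 0;
    for (int block_idx = 0; block_idx &lt; nblocks; ++block_idx) {
        for (int thread_idx = 0; thread_idx &lt; threads_per_block; ++thread_idx) {
            /* Same formula as blockDim.x * blockIdx.x + threadIdx.x. */
            int i = threads_per_block * block_idx + thread_idx;
            if (i &lt; n) {
                C[i] = A[i] + B[i];
            } else {
                ++idle;  /* filtered out by the bounds check */
            }
        }
    }
    return idle;
}
</code></pre></div>
<p>With <code>n = 1000</code> and 512 threads per block, this launches 2 blocks and leaves 24 threads idle, which is why the <code>if (i &lt; n)</code> guard in the kernel is required.</p>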
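<p>In practice, copying memory between host and device is very time-consuming, so we want to keep data resident in device memory for the whole computation instead of copying it back and forth.</p>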
<ul>
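<li>Example: window sum
A window sum is a convolution whose weights are all 1. A naive implementation looks like this (its indexing implies that the input <code>A</code> holds <code>n + 2 * RADIUS</code> elements):</li>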
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="cp">#define RADIUS 2
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">WindowSumSimpleKernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">B</span><span class="p">,</span> <span class="kt">int</span> <span class="n">n</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">out_idx</span> <span class="o">=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">out_idx</span> <span class="o">&lt;</span> <span class="n">n</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kt">float</span> <span class="n">sum</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">dx</span> <span class="o">=</span> <span class="o">-</span><span class="n">RADIUS</span><span class="p">;</span> <span class="n">dx</span> <span class="o">&lt;=</span> <span class="n">RADIUS</span><span class="p">;</span> <span class="o">++</span><span class="n">dx</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="n">sum</span> <span class="o">+=</span> <span class="n">A</span><span class="p">[</span><span class="n">dx</span> <span class="o">+</span> <span class="n">out_idx</span> <span class="o">+</span> <span class="n">RADIUS</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="n">B</span><span class="p">[</span><span class="n">out_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">sum</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
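</div><p>This algorithm is clearly inefficient: neighboring threads read the same elements over and over, so with <code>RADIUS</code> set to 2 it issues $5n$ loads from global memory.</p>
<p>Shared memory can be used to optimize this: all the data a block needs is first read into shared memory. The loading work is divided among the block's threads and done in parallel, which significantly reduces the time spent on memory loads. The kernel below assumes that <code>THREADS_PER_BLOCK</code> is a compile-time constant equal to the block size used at launch.</p>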
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C" data-lang="C"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">WindowSumSharedKernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">int</span> <span class="n">n</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">temp</span><span class="p">[</span><span class="n">THREADS_PER_BLOCK</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">RADIUS</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">base</span> <span class="o">=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">out_idx</span> <span class="o">=</span> <span class="n">base</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">base</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">&lt;</span> <span class="n">n</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">RADIUS</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">temp</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">base</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">&lt;</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">RADIUS</span> <span class="o">&amp;&amp;</span> <span class="n">base</span> <span class="o">+</span> <span class="n">THREADS_PER_BLOCK</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">&lt;</span> <span class="n">n</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">RADIUS</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">temp</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">THREADS_PER_BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">base</span> <span class="o">+</span> <span class="n">THREADS_PER_BLOCK</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span><span class="n">out_idx</span> <span class="o">&lt;</span> <span class="n">n</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="kt">float</span> <span class="n">sum</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">dx</span> <span class="o">=</span> <span class="o">-</span><span class="n">RADIUS</span><span class="p">;</span> <span class="n">dx</span> <span class="o">&lt;=</span> <span class="n">RADIUS</span><span class="p">;</span> <span class="o">++</span><span class="n">dx</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="n">sum</span> <span class="o">+=</span> <span class="n">temp</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">dx</span> <span class="o">+</span> <span class="n">RADIUS</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="n">B</span><span class="p">[</span><span class="n">out_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">sum</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The <code>__syncthreads()</code> barrier ensures that every thread in the block has finished loading its data before any thread starts computing the window sum.</p>
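<p>The saving in global-memory traffic can be counted on the CPU. The sketch below is an illustration under stated assumptions rather than the lecture's code: <code>window_sum_naive</code> and <code>window_sum_tiled</code> are hypothetical names, <code>THREADS_PER_BLOCK</code> is set to 4 only to keep the numbers small, the input is assumed to hold <code>n + 2 * RADIUS</code> elements, and a local array <code>temp</code> plays the role of shared memory.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-c" data-lang="c">#define RADIUS 2
#define THREADS_PER_BLOCK 4  /* tiny block size, chosen only for illustration */

/* Naive scheme: every output re-reads its whole window from A.
   A holds n + 2 * RADIUS inputs, B receives n outputs.
   Returns the number of reads from A ("global memory"). */
long window_sum_naive(const float *A, float *B, int n) {
    long loads = 0;
    for (int out_idx = 0; out_idx &lt; n; ++out_idx) {
        float sum = 0.0f;
        for (int dx = -RADIUS; dx &lt;= RADIUS; ++dx) {
            sum += A[out_idx + dx + RADIUS];
            ++loads;
        }
        B[out_idx] = sum;
    }
    return loads;
}

/* Tiled scheme: each block first stages its inputs in temp (the stand-in
   for __shared__ memory), then every output reads only from temp. */
long window_sum_tiled(const float *A, float *B, int n) {
    long loads = 0;
    int nblocks = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
    for (int block = 0; block &lt; nblocks; ++block) {
        float temp[THREADS_PER_BLOCK + 2 * RADIUS];
        int base = THREADS_PER_BLOCK * block;
        /* Phase 1: load the tile plus its halo, staying inside A. */
        for (int t = 0; t &lt; THREADS_PER_BLOCK + 2 * RADIUS; ++t) {
            if (base + t &lt; n + 2 * RADIUS) {
                temp[t] = A[base + t];
                ++loads;
            }
        }
        /* On the GPU, __syncthreads() separates the two phases. */
        /* Phase 2: compute outputs from the staged copy only. */
        for (int t = 0; t &lt; THREADS_PER_BLOCK; ++t) {
            int out_idx = base + t;
            if (out_idx &lt; n) {
                float sum = 0.0f;
                for (int dx = -RADIUS; dx &lt;= RADIUS; ++dx) {
                    sum += temp[t + dx + RADIUS];
                }
                B[out_idx] = sum;
            }
        }
    }
    return loads;
}
</code></pre></div>
<p>For <code>n = 10</code> outputs, the naive version performs 50 reads from <code>A</code> while the tiled version performs 22, and both produce the same result. With a realistic block size the halo of <code>2 * RADIUS</code> elements becomes negligible, so the tiled count approaches one read per input element.</p>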
<h2 id="case-study-matrix-multiplication-on-gpu">Case study: matrix multiplication on GPU</h2>
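<p>At the granularity of a single thread, we can implement a register-tiled matrix multiplication on the GPU. The slice notation is pseudocode, and with this indexing the kernel computes $C = A^\top B$:</p>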
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">mm</span><span class="p">(</span><span class="kt">float</span> <span class="n">A</span><span class="p">[</span><span class="n">N</span><span class="p">][</span><span class="n">N</span><span class="p">],</span> <span class="kt">float</span> <span class="n">B</span><span class="p">[</span><span class="n">N</span><span class="p">][</span><span class="n">N</span><span class="p">],</span> <span class="kt">float</span> <span class="n">C</span><span class="p">[</span><span class="n">N</span><span class="p">][</span><span class="n">N</span><span class="p">])</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">ybase</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">y</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">xbase</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="n">c</span><span class="p">[</span><span class="n">V</span><span class="p">][</span><span class="n">V</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="mi">0</span><span class="p">};</span>
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="n">a</span><span class="p">[</span><span class="n">V</span><span class="p">],</span> <span class="n">b</span><span class="p">[</span><span class="n">V</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="o">++</span><span class="n">k</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">a</span><span class="p">[</span><span class="o">:</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="n">ybase</span><span class="o">*</span><span class="nl">V</span> <span class="p">:</span> <span class="n">ybase</span><span class="o">*</span><span class="n">V</span> <span class="o">+</span> <span class="n">V</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="n">b</span><span class="p">[</span><span class="o">:</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="n">xbase</span><span class="o">*</span><span class="nl">V</span> <span class="p">:</span> <span class="n">xbase</span><span class="o">*</span><span class="n">V</span> <span class="o">+</span> <span class="n">V</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">y</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">y</span> <span class="o">&lt;</span> <span class="n">V</span><span class="p">;</span> <span class="o">++</span><span class="n">y</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">x</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">x</span> <span class="o">&lt;</span> <span class="n">V</span><span class="p">;</span> <span class="o">++</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="n">c</span><span class="p">[</span><span class="n">y</span><span class="p">][</span><span class="n">x</span><span class="p">]</span> <span class="o">+=</span> <span class="n">a</span><span class="p">[</span><span class="n">y</span><span class="p">]</span> <span class="o">*</span> <span class="n">b</span><span class="p">[</span><span class="n">x</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="n">C</span><span class="p">[</span><span class="n">ybase</span> <span class="o">*</span> <span class="nl">V</span> <span class="p">:</span> <span class="n">ybase</span> <span class="o">*</span> <span class="n">V</span> <span class="o">+</span> <span class="n">V</span><span class="p">,</span> <span class="n">xbase</span> <span class="o">*</span> <span class="nl">V</span> <span class="p">:</span> <span class="n">xbase</span> <span class="o">*</span> <span class="n">V</span> <span class="o">+</span> <span class="n">V</span><span class="p">]</span> <span class="o">=</span> <span class="n">c</span><span class="p">[</span><span class="o">:</span><span class="p">,</span><span class="o">:</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Each thread computes one $V \times V$ tile of the result, that is, one of the blocks in the figure below.
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202407261324561.png?x-oss-process=image/quality,q_90/format,webp">
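We can also assign the computation of one tile to an entire thread block, so that the block's threads can use shared memory to cooperatively load the data they need.</p>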
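<p>Before the shared-memory version, the benefit of the register-tiled scheme above can be made concrete with a CPU sketch. This is an illustration, not the lecture's kernel: <code>mm_register_tiled</code> is a hypothetical name, <code>N = 8</code> and <code>V = 2</code> are arbitrary small values, and each <code>(ybase, xbase)</code> loop iteration stands in for one GPU thread. It computes $C = A^\top B$ as the pseudocode does and counts the reads from <code>A</code> and <code>B</code>.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-c" data-lang="c">#define N 8  /* matrix size, kept tiny for illustration */
#define V 2  /* each emulated thread owns a V x V tile of C */

/* Returns the number of element reads from A and B. */
long mm_register_tiled(float A[N][N], float B[N][N], float C[N][N]) {
    long loads = 0;
    for (int ybase = 0; ybase &lt; N / V; ++ybase) {
        for (int xbase = 0; xbase &lt; N / V; ++xbase) {
            float c[V][V] = {{0.0f}};  /* accumulator tile ("registers") */
            float a[V], b[V];
            for (int k = 0; k &lt; N; ++k) {
                /* Load V values from row k of A and of B. */
                for (int v = 0; v &lt; V; ++v) {
                    a[v] = A[k][ybase * V + v];
                    b[v] = B[k][xbase * V + v];
                    loads += 2;
                }
                /* Reuse each loaded value V times in the outer product. */
                for (int y = 0; y &lt; V; ++y) {
                    for (int x = 0; x &lt; V; ++x) {
                        c[y][x] += a[y] * b[x];
                    }
                }
            }
            /* Write the finished tile back to C. */
            for (int y = 0; y &lt; V; ++y) {
                for (int x = 0; x &lt; V; ++x) {
                    C[ybase * V + y][xbase * V + x] = c[y][x];
                }
            }
        }
    }
    return loads;
}
</code></pre></div>
<p>Each emulated thread reads $2V$ values per step of <code>k</code> and performs $V^2$ multiply-adds with them, so the total number of reads is $2N^3/V$, compared with $2N^3$ when $V = 1$. For the values above that is 512 reads instead of 1024.</p>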
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">mm</span><span class="p">(</span><span class="kt">float</span> <span class="n">A</span><span class="p">[</span><span class="n">N</span><span class="p">][</span><span class="n">N</span><span class="p">],</span> <span class="kt">float</span> <span class="n">B</span><span class="p">[</span><span class="n">N</span><span class="p">][</span><span class="n">N</span><span class="p">],</span> <span class="kt">float</span> <span class="n">C</span><span class="p">[</span><span class="n">N</span><span class="p">][</span><span class="n">N</span><span class="p">])</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">sA</span><span class="p">[</span><span class="n">S</span><span class="p">][</span><span class="n">L</span><span class="p">],</span> <span class="n">sB</span><span class="p">[</span><span class="n">S</span><span class="p">][</span><span class="n">L</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="n">c</span><span class="p">[</span><span class="n">V</span><span class="p">][</span><span class="n">V</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="mi">0</span><span class="p">};</span>
</span></span><span class="line"><span class="cl">    <span class="kt">float</span> <span class="n">a</span><span class="p">[</span><span class="n">V</span><span class="p">],</span> <span class="n">b</span><span class="p">[</span><span class="n">V</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">yblock</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">xblock</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">ko</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">ko</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="n">ko</span> <span class="o">+=</span> <span class="n">S</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="c1">// needs to be implemented by thread cooperative fetching
</span></span></span><span class="line"><span class="cl">        <span class="n">sA</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="nl">ko</span> <span class="p">:</span> <span class="n">ko</span> <span class="o">+</span> <span class="n">S</span><span class="p">,</span> <span class="n">yblock</span> <span class="o">*</span> <span class="nl">L</span> <span class="p">:</span> <span class="n">yblock</span> <span class="o">*</span> <span class="n">L</span> <span class="o">+</span> <span class="n">L</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="n">sB</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span><span class="p">[</span><span class="nl">ko</span> <span class="p">:</span> <span class="n">ko</span> <span class="o">+</span> <span class="n">S</span><span class="p">,</span> <span class="n">xblock</span> <span class="o">*</span> <span class="nl">L</span> <span class="p">:</span> <span class="n">xblock</span> <span class="o">*</span> <span class="n">L</span> <span class="o">+</span> <span class="n">L</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="nf">__syncthreads</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">ki</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">ki</span> <span class="o">&lt;</span> <span class="n">S</span><span class="p">;</span> <span class="o">++</span><span class="n">ki</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="n">a</span><span class="p">[</span><span class="o">:</span><span class="p">]</span> <span class="o">=</span> <span class="n">sA</span><span class="p">[</span><span class="n">ki</span><span class="p">,</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="nl">V</span> <span class="p">:</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">V</span> <span class="o">+</span> <span class="n">V</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">            <span class="n">b</span><span class="p">[</span><span class="o">:</span><span class="p">]</span> <span class="o">=</span> <span class="n">sB</span><span class="p">[</span><span class="n">ki</span><span class="p">,</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="nl">V</span> <span class="p">:</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">V</span> <span class="o">+</span> <span class="n">V</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">y</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">y</span> <span class="o">&lt;</span> <span class="n">V</span><span class="p">;</span> <span class="o">++</span><span class="n">y</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">x</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">x</span> <span class="o">&lt;</span> <span class="n">V</span><span class="p">;</span> <span class="o">++</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">                    <span class="n">c</span><span class="p">[</span><span class="n">y</span><span class="p">][</span><span class="n">x</span><span class="p">]</span> <span class="o">+=</span> <span class="n">a</span><span class="p">[</span><span class="n">y</span><span class="p">]</span> <span class="o">*</span> <span class="n">b</span><span class="p">[</span><span class="n">x</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">                <span class="p">}</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">ybase</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">y</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">xbase</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">C</span><span class="p">[</span><span class="n">ybase</span> <span class="o">*</span> <span class="nl">V</span> <span class="p">:</span> <span class="n">ybase</span> <span class="o">*</span> <span class="n">V</span> <span class="o">+</span> <span class="n">V</span><span class="p">,</span> <span class="n">xbase</span> <span class="o">*</span> <span class="nl">V</span> <span class="p">:</span> <span class="n">xbase</span> <span class="o">*</span> <span class="n">V</span> <span class="o">+</span> <span class="n">V</span><span class="p">]</span> <span class="o">=</span> <span class="n">c</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="o">:</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
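</div><p>In the code above, each load from global memory into shared memory is reused L times (every value of A and B staged in shared memory contributes to L entries of the L×L output tile), and each load from shared memory into registers is reused V times (within the tile, every thread computes its own V×V sub-block).
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202407261448550.png?x-oss-process=image/quality,q_90/format,webp">
The threads of a block cooperatively load the data into shared memory as follows:</p>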
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-c" data-lang="c"><span class="line"><span class="cl"><span class="n">sA</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="nl">k</span> <span class="p">:</span> <span class="n">k</span> <span class="o">+</span> <span class="n">S</span><span class="p">,</span> <span class="n">yblock</span> <span class="o">*</span> <span class="nl">L</span> <span class="p">:</span> <span class="n">yblock</span> <span class="o">*</span> <span class="n">L</span> <span class="o">+</span> <span class="n">L</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">int</span> <span class="n">nthreads</span> <span class="o">=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">int</span> <span class="n">tid</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">L</span> <span class="o">*</span> <span class="n">S</span> <span class="o">/</span> <span class="n">nthreads</span><span class="p">;</span> <span class="o">++</span><span class="n">j</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">y</span> <span class="o">=</span> <span class="p">(</span><span class="n">j</span> <span class="o">*</span> <span class="n">nthreads</span> <span class="o">+</span> <span class="n">tid</span><span class="p">)</span> <span class="o">/</span> <span class="n">L</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">int</span> <span class="n">x</span> <span class="o">=</span> <span class="p">(</span><span class="n">j</span> <span class="o">*</span> <span class="n">nthreads</span> <span class="o">+</span> <span class="n">tid</span><span class="p">)</span> <span class="o">%</span> <span class="n">L</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">sA</span><span class="p">[</span><span class="n">y</span><span class="p">,</span> <span class="n">x</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">k</span> <span class="o">+</span> <span class="n">y</span><span class="p">,</span> <span class="n">yblock</span> <span class="o">*</span> <span class="n">L</span> <span class="o">+</span> <span class="n">x</span><span class="p">];</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
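</div><h1 id="lecture-13-hardware-acceleration-implemetation">Lecture 13: Hardware Acceleration Implementation</h1>
<p>This is a lab session that walks through the code skeleton of the low-level CPU and GPU backends of the needle library.</p>
<p>No notes were taken for this lecture; its content can be learned by completing hw3.</p>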
<h1 id="lecture-14-implementing-convolutions">Lecture 14: Implementing Convolutions</h1>
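<p>This lecture covers how the convolution operator is actually implemented.</p>
<h2 id="存储格式-storage-order">Storage Order</h2>
<p>For image data or hidden layers we need to store <code>batch_size*channel*height*width</code>, i.e. <code>B*C*H*W</code>, elements. In this course the storage format is:</p>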
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">float</span> <span class="n">Z</span><span class="p">[</span><span class="n">BATCHES</span><span class="p">][</span><span class="n">HEIGHT</span><span class="p">][</span><span class="n">WIDTH</span><span class="p">][</span><span class="n">CHANNELS</span><span class="p">];</span>
</span></span></code></pre></td></tr></table>
</div>
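</div><p>This layout is called NHWC (N stands for number). PyTorch defaults to NCHW, and later versions also support NHWC. The layout affects operator performance: convolutions are faster in NHWC, whereas Batch Norm is faster in NCHW.</p>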
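<p>To make the difference between the two layouts concrete, here is a minimal NumPy sketch. The shapes and the helper names <code>offset_nhwc</code> and <code>offset_nchw</code> are assumptions made for illustration; they are not part of the course code. It shows where the same logical element lands in memory under each layout.</p>

```python
import numpy as np

# Hypothetical small shapes, chosen only for illustration.
N, C, H, W = 2, 3, 4, 5

# Start from an NCHW array and build a contiguous NHWC copy of the same data.
z_nchw = np.arange(N * C * H * W, dtype=np.float32).reshape(N, C, H, W)
z_nhwc = np.ascontiguousarray(z_nchw.transpose(0, 2, 3, 1))

def offset_nhwc(n, h, w, c):
    """Flat element offset of Z[n][h][w][c] in the NHWC layout."""
    return ((n * H + h) * W + w) * C + c

def offset_nchw(n, c, h, w):
    """Flat element offset of Z[n][c][h][w] in the NCHW layout."""
    return ((n * C + c) * H + h) * W + w

n, h, w, c = 1, 2, 3, 1
# Both layouts hold the same logical element, at different flat offsets.
assert z_nhwc.ravel()[offset_nhwc(n, h, w, c)] == z_nchw.ravel()[offset_nchw(n, c, h, w)]

# Strides in elements: in NHWC the channel axis is the fastest-varying one.
strides_nhwc = tuple(s // z_nhwc.itemsize for s in z_nhwc.strides)
strides_nchw = tuple(s // z_nchw.itemsize for s in z_nchw.strides)
print(strides_nhwc)  # (H*W*C, W*C, C, 1)
print(strides_nchw)  # (C*H*W, H*W, W, 1)
```

<p>In NHWC all channels of one pixel sit next to each other, which is the access pattern used when a convolution is written as a product with a (C_in, C_out) matrix.</p>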
<p>A convolution kernel needs <code>k*k*C_in*C_out</code> elements. In this course it is stored as:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">float</span> <span class="n">weights</span><span class="p">[</span><span class="n">KERNEL_SIZE</span><span class="p">][</span><span class="n">KERNEL_SIZE</span><span class="p">][</span><span class="n">IN_CHANNELS</span><span class="p">][</span><span class="n">OUT_CHANNELS</span><span class="p">];</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>PyTorch stores it as <code>(C_out, C_in, k, k)</code>.</p>
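<p>Converting between the two kernel layouts is just a permutation of axes. The sketch below uses a random NumPy array as a stand-in for a PyTorch weight; the sizes and variable names are hypothetical.</p>

```python
import numpy as np

# Hypothetical sizes for illustration.
C_out, C_in, K = 16, 8, 3

# A kernel stored the PyTorch way: (C_out, C_in, K, K).
w_torch_layout = np.random.randn(C_out, C_in, K, K).astype(np.float32)

# Move the axes to (K, K, C_in, C_out), the layout used in these notes.
w_notes_layout = np.ascontiguousarray(w_torch_layout.transpose(2, 3, 1, 0))

assert w_notes_layout.shape == (K, K, C_in, C_out)
# The same logical weight is reachable through either indexing scheme.
assert w_notes_layout[1, 2, 5, 7] == w_torch_layout[7, 5, 1, 2]

# The inverse permutation restores the PyTorch ordering.
w_back = w_notes_layout.transpose(3, 2, 0, 1)
assert np.array_equal(w_back, w_torch_layout)
```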
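<h2 id="for循环实现卷积-convolutions-with-simple-loops">Convolutions with simple loops</h2>
<p>Convolution can be written as plain loops. From the outermost to the innermost, the loops iterate over batch, channel_in, channel_out, out_row and out_column, plus two more loops over the kernel window, seven loops in total:</p>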
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">conv_naive</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">weight</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">N</span><span class="p">,</span><span class="n">H</span><span class="p">,</span><span class="n">W</span><span class="p">,</span><span class="n">C_in</span> <span class="o">=</span> <span class="n">Z</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">    <span class="n">K</span><span class="p">,</span><span class="n">_</span><span class="p">,</span><span class="n">_</span><span class="p">,</span><span class="n">C_out</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="n">out</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">N</span><span class="p">,</span><span class="n">H</span><span class="o">-</span><span class="n">K</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span><span class="n">W</span><span class="o">-</span><span class="n">K</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span><span class="n">C_out</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">N</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">c_in</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">C_in</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="n">c_out</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">C_out</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">H</span><span class="o">-</span><span class="n">K</span><span class="o">+</span><span class="mi">1</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                    <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">W</span><span class="o">-</span><span class="n">K</span><span class="o">+</span><span class="mi">1</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">K</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                            <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">K</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                                <span class="n">out</span><span class="p">[</span><span class="n">n</span><span class="p">,</span><span class="n">y</span><span class="p">,</span><span class="n">x</span><span class="p">,</span><span class="n">c_out</span><span class="p">]</span> <span class="o">+=</span> <span class="n">Z</span><span class="p">[</span><span class="n">n</span><span class="p">,</span><span class="n">y</span><span class="o">+</span><span class="n">i</span><span class="p">,</span><span class="n">x</span><span class="o">+</span><span class="n">j</span><span class="p">,</span><span class="n">c_in</span><span class="p">]</span> <span class="o">*</span> <span class="n">weight</span><span class="p">[</span><span class="n">i</span><span class="p">,</span><span class="n">j</span><span class="p">,</span><span class="n">c_in</span><span class="p">,</span><span class="n">c_out</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">out</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>This seven-loop convolution takes about 3 seconds, whereas PyTorch needs only 1.2 milliseconds, a gap of roughly 2500×.</p>
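<p>One way to see why the loop version is slow is to count how many times the innermost statement runs. The helper <code>conv_mac_count</code> and the problem size below are assumptions for illustration; the notes do not state the shapes behind the timings.</p>

```python
def conv_mac_count(N, H, W, C_in, C_out, K):
    """Number of multiply-accumulate steps executed by the seven nested loops."""
    return N * C_in * C_out * (H - K + 1) * (W - K + 1) * K * K

# Hypothetical problem size (not taken from the notes).
macs = conv_mac_count(N=10, H=32, W=32, C_in=8, C_out=16, K=3)
print(macs)
```

<p>Every one of these steps goes through the Python interpreter and NumPy scalar indexing, which is where most of the time is spent.</p>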
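<h2 id="矩阵乘法实现卷积-convolutions-as-matrix-multiplications">Convolutions as matrix multiplications</h2>
<p>Every spatial position [i, j, :, :] of the kernel is a matrix of shape (c_in, c_out). Applying it to one pixel (p, q, m, :) of the input, that is, to a vector of length c_in, is exactly a vector-matrix product.</p>
<p>In particular, when the kernel is 1×1 the entire convolution is a single matrix multiplication:</p>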
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">W1</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">,</span><span class="mi">8</span><span class="p">,</span><span class="mi">16</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">out</span> <span class="o">=</span> <span class="n">conv_reference</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span><span class="n">W1</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
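</div><p>How does this generalize beyond 1×1 kernels? A K×K kernel can be viewed as a collection of 1×1 kernels; we iterate over them and accumulate the results. Note that each 1×1 kernel acts on a different window of the image, so the input must be sliced accordingly:</p>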
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">conv_matrix_mult</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">weight</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">N</span><span class="p">,</span><span class="n">H</span><span class="p">,</span><span class="n">W</span><span class="p">,</span><span class="n">C_in</span> <span class="o">=</span> <span class="n">Z</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">    <span class="n">K</span><span class="p">,</span><span class="n">_</span><span class="p">,</span><span class="n">_</span><span class="p">,</span><span class="n">C_out</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">N</span><span class="p">,</span><span class="n">H</span><span class="o">-</span><span class="n">K</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span><span class="n">W</span><span class="o">-</span><span class="n">K</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span><span class="n">C_out</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">K</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">K</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">            <span class="n">out</span> <span class="o">+=</span> <span class="n">Z</span><span class="p">[:,</span><span class="n">i</span><span class="p">:</span><span class="n">i</span><span class="o">+</span><span class="n">H</span><span class="o">-</span><span class="n">K</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span><span class="n">j</span><span class="p">:</span><span class="n">j</span><span class="o">+</span><span class="n">W</span><span class="o">-</span><span class="n">K</span><span class="o">+</span><span class="mi">1</span><span class="p">,:]</span> <span class="o">@</span> <span class="n">weight</span><span class="p">[</span><span class="n">i</span><span class="p">,</span><span class="n">j</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">out</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>This version takes 17 milliseconds, compared with 1.2 milliseconds for PyTorch, a gap of about 14×.</p>
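<p>The following sketch checks the matrix-multiplication formulation against a straightforward loop version, and also checks the 1×1 special case. The reference function <code>conv_loops</code>, the random shapes and the seed are assumptions for illustration; <code>conv_matrix_mult</code> repeats the function from the notes so that the block is self-contained.</p>

```python
import numpy as np

def conv_loops(Z, weight):
    """Reference: direct loops over output positions and kernel offsets."""
    N, H, W, C_in = Z.shape
    K, _, _, C_out = weight.shape
    out = np.zeros((N, H - K + 1, W - K + 1, C_out))
    for y in range(H - K + 1):
        for x in range(W - K + 1):
            for i in range(K):
                for j in range(K):
                    # (N, C_in) @ (C_in, C_out) accumulates one kernel tap.
                    out[:, y, x, :] += Z[:, y + i, x + j, :] @ weight[i, j]
    return out

def conv_matrix_mult(Z, weight):
    """Same computation, one batched matmul per kernel tap (as in the notes)."""
    N, H, W, C_in = Z.shape
    K, _, _, C_out = weight.shape
    out = np.zeros((N, H - K + 1, W - K + 1, C_out))
    for i in range(K):
        for j in range(K):
            out += Z[:, i:i + H - K + 1, j:j + W - K + 1, :] @ weight[i, j]
    return out

rng = np.random.default_rng(0)
Z = rng.standard_normal((2, 6, 7, 3))       # hypothetical small NHWC input
weight = rng.standard_normal((3, 3, 3, 4))  # (K, K, C_in, C_out)

assert np.allclose(conv_loops(Z, weight), conv_matrix_mult(Z, weight))

# 1x1 special case: the whole convolution is a single matrix product.
W1 = rng.standard_normal((1, 1, 3, 4))
assert np.allclose(conv_matrix_mult(Z, W1), Z @ W1[0, 0])
```

<p>Only the two loops over the kernel window remain in Python; the work over batch, spatial positions and channels is handed to NumPy's matrix multiplication.</p>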
<h2 id="通过strides来操作矩阵-manipulating-matrices-via-strides">Manipulating matrices via strides</h2>
<p>A matrix is usually stored in memory as a plain row-major 2D array:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">float</span> <span class="n">A</span><span class="p">[</span><span class="n">M</span><span class="p">][</span><span class="n">N</span><span class="p">];</span>
</span></span></code></pre></td></tr></table>
</div>
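</div><p>However, efficient operators often process a matrix tile by tile. If the matrix is also stored tile by tile, these operators get better spatial locality:</p>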
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">float</span> <span class="n">A</span><span class="p">[</span><span class="n">M</span><span class="o">/</span><span class="n">TILE</span><span class="p">][</span><span class="n">N</span><span class="o">/</span><span class="n">TILE</span><span class="p">][</span><span class="n">TILE</span><span class="p">][</span><span class="n">TILE</span><span class="p">];</span>
</span></span></code></pre></td></tr></table>
</div>
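</div><p>NumPy provides a function that can turn a 2D array into this tiled layout: <code>np.lib.stride_tricks.as_strided</code><sup id="fnref:4"><a href="#fn:4" class="footnote-ref" role="doc-noteref">4</a></sup>.</p>
<p>Specifically, <code>as_strided</code> creates a view with a different shape and strides that shares the same underlying data as the original array.</p>
<p>As an example, consider the 6×6 matrix in the figure below, tiled into 2×2 blocks. We derive the strides backwards, starting from strides[3] (values are given in elements).
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202408252356974.webp?x-oss-process=image/quality,q_90/format,webp"></p>
<ul>
<li>strides[3] is the step to the next column inside a tile, i.e. from element 0 to element 1. The data is laid out row by row in memory, so the value is 1;</li>
<li>strides[2] is the step to the next row inside a tile, i.e. from element 0 to element 6. As the figure shows, this equals the number of columns N of the matrix, 6;</li>
<li>strides[1] is the step from one tile to the same position in the next tile of the same tile row, i.e. from element 0 to element 2. This equals the tile width TILE, 2;</li>
<li>strides[0] is the step from one tile to the same position in the next tile of the same tile column, i.e. from element 0 to element 12. This equals TILE*N, 12.</li>
</ul>
<p>With the strides determined, <code>as_strided</code> creates a tiled view of the original matrix, and <code>np.ascontiguousarray</code> makes a copy of it that is contiguous in memory:</p>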
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
</span></span><span class="line"><span class="cl"><span class="n">n</span> <span class="o">=</span> <span class="mi">6</span>
</span></span><span class="line"><span class="cl"><span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">n</span><span class="o">**</span><span class="mi">2</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">n</span><span class="p">,</span><span class="n">n</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">lib</span><span class="o">.</span><span class="n">stride_tricks</span><span class="o">.</span><span class="n">as_strided</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="n">strides</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">((</span><span class="mi">12</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">6</span><span class="p">,</span><span class="mi">1</span><span class="p">))</span><span class="o">*</span><span class="mi">4</span><span class="p">)</span>  <span class="c1"># NumPy strides are in bytes (4 bytes per float32)</span>
</span></span><span class="line"><span class="cl"><span class="n">C</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
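</div><p>----- The following is not course content -----
As an aside, tiling this way is not very elegant: NumPy advises against using such a low-level API to manipulate data layout directly. I asked GPT, which suggested a cleaner approach.</p>
<p>First reshape the (M, N) matrix to (M//TILE, TILE, N//TILE, TILE). This splits both the rows and the columns into blocks, so index (p, m, q, n) refers to the element in row m, column n of the sub-matrix at block row p, block column q. Then <code>transpose(0, 2, 1, 3)</code> reorders the axes into the tiled layout.</p>
<p>Why the indices are still correct after the reshape: reshaping a row-major array keeps the flat order of its elements, and index (p, m, q, n) of the reshaped array has flat offset (p*TILE + m)*N + (q*TILE + n), which is exactly element A[p*TILE + m, q*TILE + n] of the original matrix. Put differently, the reshape splits the row index and the column index independently, and the two splits do not interfere with each other. Combining the two steps gives the following code:</p>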
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">block_matrix</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">TILE</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">A</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">    <span class="k">assert</span> <span class="n">M</span> <span class="o">%</span> <span class="n">TILE</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">N</span> <span class="o">%</span> <span class="n">TILE</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&#34;matrix dimensions must be divisible by TILE&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="n">A_reshaped</span> <span class="o">=</span> <span class="n">A</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">M</span><span class="o">//</span><span class="n">TILE</span><span class="p">,</span> <span class="n">TILE</span><span class="p">,</span> <span class="n">N</span><span class="o">//</span><span class="n">TILE</span><span class="p">,</span> <span class="n">TILE</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">A_blocked</span> <span class="o">=</span> <span class="n">A_reshaped</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span><span class="n">A_blocked</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>----- The above is not course content -----</p>
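<p>As a sanity check, the two tiling approaches can be compared directly. This is a minimal sketch; the function names <code>tile_with_strides</code> and <code>tile_with_reshape</code> are hypothetical, and the stride-based version assumes a C-contiguous input.</p>

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def tile_with_strides(A, tile):
    """Tiled copy built from an as_strided view (A must be C-contiguous)."""
    M, N = A.shape
    s = A.itemsize  # NumPy strides are in bytes, so scale element strides
    view = as_strided(
        A,
        shape=(M // tile, N // tile, tile, tile),
        strides=(tile * N * s, tile * s, N * s, s),
    )
    return np.ascontiguousarray(view)

def tile_with_reshape(A, tile):
    """Same result via reshape + transpose, without manual stride arithmetic."""
    M, N = A.shape
    return np.ascontiguousarray(
        A.reshape(M // tile, tile, N // tile, tile).transpose(0, 2, 1, 3)
    )

A = np.arange(36, dtype=np.float32).reshape(6, 6)
B1 = tile_with_strides(A, 2)
B2 = tile_with_reshape(A, 2)

assert np.array_equal(B1, B2)
# Block (1, 2) is rows 2..3 and columns 4..5 of the original matrix.
assert np.array_equal(B1[1, 2], A[2:4, 4:6])
print(B1[0, 1])  # the tile to the right of the top-left one
```

<p>Using <code>A.itemsize</code> instead of a hard-coded 4 keeps the stride computation valid for other dtypes.</p>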
<h2 id="通过-im2col-来实现卷积-convolutions-via-im2col">Convolutions via im2col</h2>
<p>As mentioned in Lecture 10, the im2col technique turns a 1D convolution into a matrix operation:


<div>$$

\begin{bmatrix}z_1\\z_2\\z_3\\z_4\\z_5\end{bmatrix}=x*w=\begin{bmatrix}0&amp;x_1&amp;x_2\\x_1&amp;x_2&amp;x_3\\x_2&amp;x_3&amp;x_4\\x_3&amp;x_4&amp;x_5\\x_4&amp;x_5&amp;0\end{bmatrix}\begin{bmatrix}w_1\\w_2\\w_3\end{bmatrix}

$$</div>

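The same works for 2D convolution. Take a 3×3 kernel applied to a 6×6 matrix, which yields a 4×4 result. First extract the receptive field of every output position; together these receptive fields form a 4×4×3×3 tensor.</p>
<p>As the figure below shows, receptive field [0,0] is [0,1,2;6,7,8;12,13,14]. How do we turn the original matrix into this tensor? This is where <code>as_strided</code> from the previous section comes in. strides[0] is the step to the same position in the receptive field one row down, i.e. one full row of the original matrix, 6; strides[1] is the step to the receptive field one column to the right, 1; strides[2] is the step to the next row inside a receptive field, again the original row length 6; strides[3] is the step to the next column inside a receptive field, 1. Hence <code>B = np.lib.stride_tricks.as_strided(A, shape=(4,4,3,3), strides=4*(np.array((6,1,6,1))))</code> converts the input matrix A into the receptive-field tensor B.
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202408300948305.png?x-oss-process=image/quality,q_90/format,webp"></p>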
<p>Next, reshape each receptive field and the kernel into vectors and compute every output value as an inner product:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="p">(</span><span class="n">B</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span><span class="mi">9</span><span class="p">)</span> <span class="o">@</span> <span class="n">W</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">9</span><span class="p">))</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
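</div><p>Note that reshaping B is not free: the reshaped B cannot be expressed as a strided view of A's data, so the reshape allocates a new array whose size is $O(K^2)$ times that of the original input. For large K this costs a lot of memory. Modern implementations therefore often use lazy evaluation or other techniques, which are beyond the scope of this course.</p>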
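<p>The single-channel im2col steps above can be put together into a small runnable sketch. The kernel values are made up for illustration, and <code>np.shares_memory</code> is used only to show at which point NumPy has to copy the data.</p>

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

H = W_img = 6
K = 3
A = np.arange(H * W_img, dtype=np.float32).reshape(H, W_img)
kernel = np.arange(K * K, dtype=np.float32).reshape(K, K)  # hypothetical weights

out_h, out_w = H - K + 1, W_img - K + 1
row_stride, col_stride = A.strides  # in bytes

# View of all receptive fields: shape (4, 4, 3, 3), no data copied yet.
B = as_strided(A, shape=(out_h, out_w, K, K),
               strides=(row_stride, col_stride, row_stride, col_stride))
assert np.shares_memory(A, B)

# Flattening each receptive field cannot be expressed with strides, so NumPy copies.
cols = B.reshape(out_h * out_w, K * K)
assert not np.shares_memory(A, cols)

out = (cols @ kernel.reshape(K * K)).reshape(out_h, out_w)

# Check against a direct sliding-window computation.
expected = np.array([[np.sum(A[y:y + K, x:x + K] * kernel) for x in range(out_w)]
                     for y in range(out_h)])
assert np.allclose(out, expected)
print(cols.size / A.size)  # memory blow-up factor of the im2col matrix
```

<p>For this small case the im2col matrix already holds four times as many elements as the input; as the image grows, the factor approaches K×K.</p>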
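<h2 id="通过-im2col-来实现多通道卷积">Multi-channel convolutions via im2col</h2>
<p>For a batched multi-channel convolution the input has shape N×H×W×C_in and the receptive-field tensor has shape N×(H-K+1)×(W-K+1)×K×K×C_in. The K×K×C_in part is flattened into one dimension, and the kernel is flattened to match, giving shape (K×K×C_in)×C_out after the reshape.</p>
<p>The implementation is:</p>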
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">conv_im2col</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">weight</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">N</span><span class="p">,</span><span class="n">H</span><span class="p">,</span><span class="n">W</span><span class="p">,</span><span class="n">C_in</span> <span class="o">=</span> <span class="n">Z</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">    <span class="n">K</span><span class="p">,</span><span class="n">_</span><span class="p">,</span><span class="n">_</span><span class="p">,</span><span class="n">C_out</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">    <span class="n">Ns</span><span class="p">,</span> <span class="n">Hs</span><span class="p">,</span> <span class="n">Ws</span><span class="p">,</span> <span class="n">Cs</span> <span class="o">=</span> <span class="n">Z</span><span class="o">.</span><span class="n">strides</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="n">inner_dim</span> <span class="o">=</span> <span class="n">K</span> <span class="o">*</span> <span class="n">K</span> <span class="o">*</span> <span class="n">C_in</span>
</span></span><span class="line"><span class="cl">    <span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">lib</span><span class="o">.</span><span class="n">stride_tricks</span><span class="o">.</span><span class="n">as_strided</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">H</span><span class="o">-</span><span class="n">K</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">W</span><span class="o">-</span><span class="n">K</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">C_in</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">                                        <span class="n">strides</span> <span class="o">=</span> <span class="p">(</span><span class="n">Ns</span><span class="p">,</span> <span class="n">Hs</span><span class="p">,</span> <span class="n">Ws</span><span class="p">,</span> <span class="n">Hs</span><span class="p">,</span> <span class="n">Ws</span><span class="p">,</span> <span class="n">Cs</span><span class="p">))</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span><span class="n">inner_dim</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">out</span> <span class="o">=</span> <span class="n">A</span> <span class="o">@</span> <span class="n">weight</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">C_out</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">N</span><span class="p">,</span><span class="n">H</span><span class="o">-</span><span class="n">K</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span><span class="n">W</span><span class="o">-</span><span class="n">K</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span><span class="n">C_out</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="lecture-15-training-large-models">Lecture 15: Training Large Models</h1>
<h2 id="内存节省技术-techniques-for-memory-saving">Techniques for memory saving</h2>
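<p>GPU global memory has long been the bottleneck on model size; memory-saving techniques make it possible to train somewhat larger models.</p>
<p>A model's memory consumption comes mainly from three sources: the model weights, the optimizer state (momentum terms and the like), and the intermediate activations.</p>
<p>For inference, the activations need only two buffers, one holding a layer's input and one holding its output: the next layer reads the previous layer's output as its input and overwrites the previous layer's input with its own output. No intermediate activations have to be kept.</p>
<p>During training, however, computing each layer's gradient needs that layer's input, so every activation has to be kept in its own buffer and activation memory grows as $O(N)$, as shown below:
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202408302230612.png?x-oss-process=image/quality,q_90/format,webp"></p>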
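<p>One technique for reducing activation memory is checkpointing: only some activations are stored, for example every other one, as shown below:
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202408302240185.png?x-oss-process=image/quality,q_90/format,webp">
During the backward pass, an activation that was not stored is recomputed from the nearest stored activation before it. This trades time for space. For an $N$-layer network that stores one activation every $K$ layers, activation memory is $O(N/K)+O(K)$: about $N/K$ checkpoints plus the at most $K$ activations recomputed inside one segment. This is minimized at $K=\sqrt{N}$.</p>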
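<p>The trade-off can be checked numerically. The sketch below is only an illustration: the helper names <code>checkpoint_memory</code> and <code>best_interval</code> are hypothetical, and the cost model (number of checkpoints plus one recomputed segment of $K$ activations) is a simplifying assumption rather than a measurement of any real framework.</p>

```python
import math

def checkpoint_memory(num_layers: int, k: int) -> int:
    """Peak number of activations held under the assumed cost model:
    ceil(N / K) stored checkpoints plus K recomputed activations."""
    num_checkpoints = math.ceil(num_layers / k)
    return num_checkpoints + k

def best_interval(num_layers: int) -> int:
    """Brute-force the checkpoint interval K that minimizes the cost model."""
    return min(range(1, num_layers + 1),
               key=lambda k: checkpoint_memory(num_layers, k))

N = 100
k_star = best_interval(N)
print("best K:", k_star, "sqrt(N):", math.sqrt(N))
print("activations held:", checkpoint_memory(N, k_star), "vs", N, "without checkpointing")
```

<p>For $N=100$ the brute-force search lands on the same interval as $\sqrt{N}$, which is what the $O(N/K)+O(K)$ argument predicts.</p>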
<h2 id="并行和分布式训练-parallel-and-distributed-training">Parallel and distributed training</h2>
<ul>
<li>Computation graph partitioning</li>
</ul>
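<p>With multiple GPUs, training can be parallelized and distributed. One approach is to partition the computation graph, assign the parts to different workers, and pass data between workers over a communication protocol. In the figure below, the whole graph is split into two parts.
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202408310842410.png?x-oss-process=image/quality,q_90/format,webp">
Partitioning the graph alone yields little parallelism, but while worker1 processes the data it received from worker0, worker0 can already compute the next minibatch in parallel, and this pipelining gives high parallelism.</p>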
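<p>A toy schedule makes the pipelining argument concrete. This is a minimal sketch under idealized assumptions: every stage takes exactly one time step, only the forward pass is modeled, and the function name <code>pipeline_schedule</code> is hypothetical.</p>

```python
def pipeline_schedule(num_stages, num_batches):
    """Map each time step to the (worker, minibatch) pairs active at that step."""
    schedule = {}
    for batch in range(num_batches):
        for stage in range(num_stages):
            t = batch + stage  # minibatch b reaches stage s at step b + s
            schedule.setdefault(t, []).append((stage, batch))
    return schedule

sched = pipeline_schedule(num_stages=2, num_batches=4)
for t in sorted(sched):
    print("step", t, "->", sorted(sched[t]))

sequential_steps = 2 * 4      # one minibatch at a time through both stages
pipelined_steps = len(sched)  # stages overlap on different minibatches
print(sequential_steps, "steps without overlap,", pipelined_steps, "with pipelining")
```

<p>In the printed schedule, every step except the first and the last keeps both workers busy on different minibatches.</p>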
<ul>
<li>Data parallel training</li>
</ul>
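<p>Data parallel training splits a minibatch into smaller batches and assigns each one to a GPU, so every GPU runs the same model.</p>
<p>Distributed and parallel computing provides an allreduce primitive, which applies a reduction to data spread across multiple processes or nodes and then broadcasts the result back to all participants. With it, each GPU computes the gradient of its smaller batch, allreduce combines these into the gradient of the whole minibatch, and every GPU then takes the same gradient descent step.</p>
<p>The parameters can also be kept on a dedicated parameter server; other devices call its API whenever they need to read or update parameters. A parameter server does not have to wait for all workers to finish before updating, and it supports adding or removing workers dynamically, which makes the system more robust.</p>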
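<p>The claim that allreduce recovers the full-minibatch gradient can be simulated on a single machine with NumPy. The sketch below is an assumption-laden illustration, not a real communication library: <code>allreduce_mean</code> is a hypothetical stand-in that averages a list of arrays, the model is plain linear regression, and all shards have equal size.</p>

```python
import numpy as np

def grad_mse(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y) ** 2) with respect to w."""
    return X.T @ (X @ w - y) / X.shape[0]

def allreduce_mean(tensors):
    """Simulated allreduce: every worker receives the same averaged tensor."""
    mean = sum(tensors) / len(tensors)
    return [mean.copy() for _ in tensors]

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 3))   # one minibatch of 32 examples
y = rng.normal(size=32)
w = rng.normal(size=3)

num_workers = 4
X_shards = np.split(X, num_workers)   # 8 examples per simulated GPU
y_shards = np.split(y, num_workers)

local_grads = [grad_mse(w, Xs, ys) for Xs, ys in zip(X_shards, y_shards)]
synced = allreduce_mean(local_grads)
full_grad = grad_mse(w, X, y)
print(np.allclose(synced[0], full_grad))
```

<p>Because the shards have equal size, averaging the per-shard gradients reproduces the gradient of the whole minibatch, so all workers apply an identical update.</p>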
<ul>
<li>Communication computation overlap</li>
</ul>
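<p>Communication computation overlap means synchronizing with non-blocking communication, so that computation continues while waiting on I/O.</p>
<h1 id="lecture-16-generative-adversarial-network">Lecture 16: Generative Adversarial Network</h1>
<h2 id="生成对抗训练-generative-adversarial-training">Generative adversarial training</h2>
<p>In unsupervised learning, or more specifically generative modeling, the task is to turn random vectors into samples that follow the distribution of the dataset. This raises a question: how do we measure the distance between the generated samples and the target distribution? That measure serves as our objective function, so it must be differentiable for the model to be optimized against it.</p>
<p>The idea of adversarial training is to construct an oracle classifier D that distinguishes generated data from real data; the output of D is the probability that its input is real data. For any input z, the generator G outputs G(z), and D judges it as D(G(z)). The generator's goal is to make the discriminator misjudge as often as possible, so its objective is: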


<div>$$

\max_G\{-E_{z\sim Noise}\log{(1-D(G(z)))}\}

$$</div>
</p>
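<p>Note that no ready-made discriminator D exists. We can build it with a neural network as well; its goal is to classify as accurately as possible, so its loss function is: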


<div>$$

\min_D\{-E_{z\sim Noise}\log{(1-D(G(z)))}-E_{x\sim Data}\log{D(x)}\}

$$</div>
</p>
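<p>The two objectives above can be evaluated directly on discriminator outputs. The following NumPy sketch assumes D returns the probability that its input is real; the function names and the sample probabilities are made up for illustration.</p>

```python
import numpy as np

def generator_objective(d_fake):
    """-E[log(1 - D(G(z)))], the quantity the generator maximizes."""
    return -np.mean(np.log(1.0 - d_fake))

def discriminator_loss(d_real, d_fake):
    """-E[log(1 - D(G(z)))] - E[log D(x)], the quantity the discriminator minimizes."""
    return -np.mean(np.log(1.0 - d_fake)) - np.mean(np.log(d_real))

# A discriminator that separates real from fake well has a low loss ...
confident = discriminator_loss(np.array([0.9, 0.95]), np.array([0.1, 0.05]))
# ... while one that outputs 0.5 everywhere pays 2 * log(2).
fooled = discriminator_loss(np.array([0.5, 0.5]), np.array([0.5, 0.5]))
print(confident, fooled)

# The generator objective grows as D is pushed toward calling fakes real.
print(generator_objective(np.array([0.1])), generator_objective(np.array([0.9])))
```

<p>The first term is shared by both objectives with opposite goals, which is what makes the training adversarial.</p>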
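<h2 id="将对抗训练作为深度学习中的一个模块-adversarial-training-as-a-module-in-deep-learning-models">Adversarial training as a module in deep learning models</h2>
<p>Next we consider how to modularize adversarial training. The whole discriminator can be implemented as a loss function. It differs from the loss functions we implemented earlier, though: the discriminator's parameters must be updated on every backward pass.</p>
<p>[This lecture does not seem to explain concretely how to modularize it; the rest appears to introduce variants of GAN.]</p>
<p>DCGAN uses a module called deconvolution (transposed convolution, Conv2dTranspose), whose role is upsampling.</p>
<p>CycleGAN is a model for style transfer. One supervised approach to style transfer is to collect a paired dataset of images before and after the transfer and train on it. Such paired datasets are hard to obtain, however. How can we train without supervision on unpaired data? With GANs: a generator G produces the style-transferred image and is trained adversarially against a discriminator. A second generator F maps the image back and is trained against its own discriminator. The full CycleGAN model must also enforce cycle consistency: passing an image from the dataset through G and then F should recover the original image, so the cycle consistency loss is the L2 norm between the two images.</p>
<p>The next lecture covers the concrete implementation of GANs.</p>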
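<p>The cycle consistency term can be sketched in a few lines of NumPy. This is a toy illustration under loud assumptions: G and F are fixed linear maps on 2-D points instead of image-to-image networks, and the names <code>F_inverse</code> and <code>F_identity</code> are hypothetical. The loss follows the L2 norm described in these notes.</p>

```python
import numpy as np

def cycle_consistency_loss(x, G, F):
    """Mean L2 distance between each sample x and its reconstruction F(G(x))."""
    reconstructed = F(G(x))
    return float(np.mean(np.linalg.norm(reconstructed - x, axis=1)))

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 2))              # stand-in for a batch of images
M = np.array([[2.0, 0.0], [1.0, 1.0]])   # toy "style transfer" map

def G(a):
    return a @ M

def F_inverse(b):
    return b @ np.linalg.inv(M)          # exactly undoes G

def F_identity(b):
    return b                             # does not undo G

loss_good = cycle_consistency_loss(x, G, F_inverse)
loss_bad = cycle_consistency_loss(x, G, F_identity)
print(loss_good, loss_bad)
```

<p>The loss is essentially zero only when F inverts G, which is the property the cycle consistency term rewards.</p>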
<h1 id="lecture-17-generative-adversarial-networks-implementations">Lecture 17: Generative Adversarial Networks implementations</h1>
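<p>In this lecture we look at a concrete implementation of a GAN.</p>
<p>The lecture uses a two-dimensional Gaussian as the real dataset and trains a generator to produce samples from that distribution. The training data is prepared as follows:</p>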
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="o">-</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]])</span>
</span></span><span class="line"><span class="cl"><span class="n">mu</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
</span></span><span class="line"><span class="cl"><span class="c1"># total number of samples to generate</span>
</span></span><span class="line"><span class="cl"><span class="n">num_sample</span> <span class="o">=</span> <span class="mi">3200</span>
</span></span><span class="line"><span class="cl"><span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">num_sample</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span> <span class="o">@</span> <span class="n">A</span> <span class="o">+</span> <span class="n">mu</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The generator only needs a single fully connected layer:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">model_G</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">sample_G</span><span class="p">(</span><span class="n">model_G</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">Z</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">num_samples</span><span class="p">,</span> <span class="mi">2</span><span class="p">)))</span>
</span></span><span class="line"><span class="cl">    <span class="n">fake_X</span> <span class="o">=</span> <span class="n">model_G</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">fake_X</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The discriminator is a three-layer perceptron, trained with softmax loss:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">model_D</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">20</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">    <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
</span></span><span class="line"><span class="cl">    <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">10</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">    <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
</span></span><span class="line"><span class="cl">    <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">loss_D</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">SoftmaxLoss</span><span class="p">()</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Optimizing the generator G means sampling some data G(z) from G and computing the loss between the output D(G(z)) and the label 1:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">opt_G</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model_G</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">update_G</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">model_G</span><span class="p">,</span> <span class="n">model_D</span><span class="p">,</span> <span class="n">loss_D</span><span class="p">,</span> <span class="n">opt_G</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">fake_X</span> <span class="o">=</span> <span class="n">model_G</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">fake_Y</span> <span class="o">=</span> <span class="n">model_D</span><span class="p">(</span><span class="n">fake_X</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">batch_size</span> <span class="o">=</span> <span class="n">Z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="n">ones</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;int32&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_D</span><span class="p">(</span><span class="n">fake_Y</span><span class="p">,</span> <span class="n">ones</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">opt_G</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Likewise, updating the discriminator means computing the loss between D(x) and the label 1 plus the loss between D(G(z)) and the label 0, where x is real data:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">opt_D</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model_D</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">update_D</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Z</span><span class="p">,</span> <span class="n">model_G</span><span class="p">,</span> <span class="n">model_D</span><span class="p">,</span> <span class="n">loss_D</span><span class="p">,</span> <span class="n">opt_D</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">fake_X</span> <span class="o">=</span> <span class="n">model_G</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">fake_Y</span> <span class="o">=</span> <span class="n">model_D</span><span class="p">(</span><span class="n">fake_X</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">real_Y</span> <span class="o">=</span> <span class="n">model_D</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">assert</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">Z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="n">batch_size</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="n">ones</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;int32&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">zeros</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;int32&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_D</span><span class="p">(</span><span class="n">real_Y</span><span class="p">,</span> <span class="n">ones</span><span class="p">)</span> <span class="o">+</span> <span class="n">loss_D</span><span class="p">(</span><span class="n">fake_Y</span><span class="p">,</span> <span class="n">zeros</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">opt_D</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
</span></span></code></pre></td></tr></table>
</div>
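</div><p>In each training iteration, random vectors are fed to the generator, the generator's output is fed to the discriminator, and the parameters of both are updated in turn. Note that in the code below, epoch counts the number of batches trained, not the number of full passes over the training set:</p>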
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">train_gan</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">num_epochs</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">assert</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">%</span> <span class="n">batch_size</span> <span class="o">==</span> <span class="mi">0</span>
</span></span><span class="line"><span class="cl">    <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_epochs</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">begin</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">epoch</span><span class="p">)</span> <span class="o">%</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">X</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="n">begin</span><span class="p">:</span> <span class="n">begin</span><span class="o">+</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">:]</span>
</span></span><span class="line"><span class="cl">        <span class="n">Z</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        <span class="n">X</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">Z</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">update_D</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Z</span><span class="p">,</span> <span class="n">model_G</span><span class="p">,</span> <span class="n">model_D</span><span class="p">,</span> <span class="n">loss_D</span><span class="p">,</span> <span class="n">opt_D</span><span class="p">)</span> 
</span></span><span class="line"><span class="cl">        <span class="n">update_G</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">model_G</span><span class="p">,</span> <span class="n">model_D</span><span class="p">,</span> <span class="n">loss_D</span><span class="p">,</span> <span class="n">opt_G</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">train_gan</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">2000</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
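</div><p>That is the whole process of training a GAN. Next we consider how to modularize the GAN loss. A GAN loss takes the generator's output and returns a loss value. Because the generator updates its own parameters as soon as it receives that value, the GAN loss must update the discriminator's parameters implicitly inside itself:</p>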
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">GANLoss</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_D</span><span class="p">,</span> <span class="n">opt_D</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">model_D</span> <span class="o">=</span> <span class="n">model_D</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">opt_D</span> <span class="o">=</span> <span class="n">opt_D</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">loss_D</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">SoftmaxLoss</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">_update_D</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">real_X</span><span class="p">,</span> <span class="n">fake_X</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">real_Y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_D</span><span class="p">(</span><span class="n">real_X</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">fake_Y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_D</span><span class="p">(</span><span class="n">fake_X</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
</span></span><span class="line"><span class="cl">        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">real_X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">ones</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">zeros</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_D</span><span class="p">(</span><span class="n">real_Y</span><span class="p">,</span> <span class="n">ones</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_D</span><span class="p">(</span><span class="n">fake_Y</span><span class="p">,</span> <span class="n">zeros</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">opt_D</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fake_X</span><span class="p">,</span> <span class="n">real_X</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">_update_D</span><span class="p">(</span><span class="n">real_X</span><span class="p">,</span> <span class="n">fake_X</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">fake_Y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_D</span><span class="p">(</span><span class="n">fake_X</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">real_X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">ones</span> <span class="o">=</span> <span class="n">ndl</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&#34;float32&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_D</span><span class="p">(</span><span class="n">fake_Y</span><span class="p">,</span> <span class="n">ones</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">loss</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="lecture-18-sequence-modeling-and-recurrent-networks">Lecture 18: Sequence Modeling and Recurrent Networks</h1>
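<h2 id="序列建模-sequence-modeling">Sequence modeling</h2>
<p>In the previous models we made an implicit assumption: the samples (x, y) are independent and identically distributed. In practice many tasks violate this. In sequence data, and in time series especially, each input/output pair depends on the ones that came before it.</p>
<p>One class of predictive models for sequence data is the autoregressive model, whose basic idea is to predict future values from the sequence's own history.</p>
<h2 id="循环神经网络-recurrent-neural-networks">Recurrent neural networks</h2>
<p>Recurrent networks can also model sequence data. The idea of an RNN is to build a network that captures the temporal information in the input sequence.</p>
<p>In the figure below, $h$ denotes the hidden state. Each hidden state is computed by a nonlinear transformation of the previous hidden state and the current input $x_t$, and the output for the current input is a nonlinear transformation of the corresponding hidden state.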
That is:


<div>$$

\begin{align*}  
h_t &amp;= f(h_{t-1},x_t)\\  
y_t &amp;=g(h_t)  
\end{align*}

$$</div>
</p>
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409041829344.png?x-oss-process=image/quality,q_90/format,webp"></p>
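<p>In theory, if the model is built well, this scheme lets the prediction of $y_t$ draw on the information from all earlier time steps.</p>
<p>Training an RNN requires paired x and y as the dataset, and the loss is the sum of the losses between each prediction and its ground-truth value. This loss is clearly hard to differentiate by pen and paper, but thanks to the automatic differentiation system we built earlier, we do not need to compute any gradient by hand.</p>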
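<p>The recurrence and the summed loss can be sketched in NumPy as follows. This is a minimal forward pass only, and it assumes specific choices that the notes leave open: $f$ is a tanh of an affine map, $g$ is an affine readout, and the per-step loss is a squared error. All parameter names are made up for the example.</p>

```python
import numpy as np

def rnn_forward(X, h0, W_hh, W_hx, b_h, W_yh, b_y):
    """Run a vanilla RNN over X of shape (T, n); returns outputs of shape (T, k)."""
    h = h0
    outputs = []
    for x_t in X:
        h = np.tanh(W_hh @ h + W_hx @ x_t + b_h)  # h_t = f(h_{t-1}, x_t)
        outputs.append(W_yh @ h + b_y)            # y_t = g(h_t)
    return np.stack(outputs), h

def sequence_loss(Y_pred, Y_true):
    """Sum of per-step squared errors over the whole sequence."""
    return float(np.sum((Y_pred - Y_true) ** 2))

rng = np.random.default_rng(0)
T, n, d, k = 5, 3, 4, 2                 # steps, input size, hidden size, output size
X = rng.normal(size=(T, n))
Y = rng.normal(size=(T, k))
params = dict(
    W_hh=rng.normal(size=(d, d)) * 0.1,
    W_hx=rng.normal(size=(d, n)) * 0.1,
    b_h=np.zeros(d),
    W_yh=rng.normal(size=(k, d)) * 0.1,
    b_y=np.zeros(k),
)
Y_pred, h_T = rnn_forward(X, np.zeros(d), **params)
loss = sequence_loss(Y_pred, Y)
```

<p>Because each $h_t$ depends only on $h_{t-1}$ and $x_t$, changing the last input leaves all earlier outputs untouched, while every output can depend on all inputs before it.</p>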
<p>Several RNNs can be stacked on top of each other to obtain a stacked RNN, as shown below:
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409041908220.png?x-oss-process=image/quality,q_90/format,webp"></p>
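<p>RNNs are prone to exploding and vanishing activations/gradients during training. An earlier lecture noted that parameter initialization matters a great deal when training very deep networks. The problem is even more severe for RNNs, because an RNN unrolled over a long sequence is effectively a very deep network.</p>
<p>One way to approach the gradient problem is to look at the activation function. One issue with ReLU is that its output is unbounded. However, switching to a bounded activation such as sigmoid or tanh does not solve the problem; in particular, it does not fix vanishing activations/gradients. As shown below, for tanh an input near 0 gives an output that is also near 0, which makes the hidden state vanish; and for both functions the gradient is very small when the input is around -5 or 5, which makes the gradient vanish.</p>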
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409041923270.png?x-oss-process=image/quality,q_90/format,webp"></p>
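<p>To put rough numbers on the saturation effect, the following sketch evaluates the derivatives of tanh and sigmoid at 0 and at 5 using only the standard library. The choice of 20 steps for the repeated product is an arbitrary assumption for illustration.</p>

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x):
    s = sigmoid(x)
    return s * (1.0 - s)

def tanh_grad(x):
    return 1.0 - math.tanh(x) ** 2

for x in (0.0, 5.0):
    print(f"x={x}: tanh'={tanh_grad(x):.2e}, sigmoid'={sigmoid_grad(x):.2e}")

# Backpropagating through many saturated steps multiplies these small factors.
steps = 20
shrink = tanh_grad(5.0) ** steps
```

<p>At 0 the derivative of tanh is 1 and that of sigmoid is 0.25; at 5 both are far smaller, and chaining many such factors drives the gradient toward zero.</p>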
<h2 id="lstms">LSTMs</h2>
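<p>LSTMs mitigate, to some extent, the vanishing and exploding gradient problems of RNNs. An LSTM modifies the hidden unit of the original RNN. As shown below, it splits the original hidden state into two components: a hidden state and a cell state.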
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409041935116.png?x-oss-process=image/quality,q_90/format,webp"></p>
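<p>The LSTM also prescribes specific update formulas for the hidden state and the cell state. To state them compactly it introduces several intermediate variables: the forget gate, the input gate, the output gate, and a candidate state $g_t$. The updates for these variables and for the states are:</p>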


<div>$$

\begin{align*}  
&amp;\begin{bmatrix}i_t\\f_t\\g_t\\o_t\end{bmatrix}=\begin{pmatrix}\text{sigmoid}\\\text{sigmoid}\\\text{tanh}\\\text{sigmoid}\end{pmatrix}(W_{hh}h_{t-1}&#43;W_{hx}x_t&#43;b_h) \\  
&amp;c_t=c_{t-1}\circ f_t&#43;i_t\circ g_t \\  
&amp;h_t=\tanh(c_t)\circ o_t\\  
&amp;i_t,f_t,g_t,o_t,c_t,h_t \in \mathbb{R}^d\\  
&amp;W_{hh},W_{hx}\in \mathbb{R}^{4d\times d}  
\end{align*}

$$</div>

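<p>$W_{hh},W_{hx}\in \mathbb{R}^{4d\times d}$ means that each of the four intermediate variables has its own independent block of weights.</p>
<p>Where does this formula come from, and why does it work? Many works have tried to explain it, but most are one author's view. Professor Zico Kolter's explanation is: after the sigmoid, $f_t$ is a value in $[0,1]$ that decides how much of the cell state at each position to keep from the previous step; $i_t$ is likewise a value in $[0,1]$ and $g_t$ is a bounded term, so their product decides whether to add some new information at each position of the cell state; and the update for $h_t$ produces a bounded value, which is meant to guard against exploding or vanishing gradients.</p>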
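<p>The gating intuition can be checked on a toy example. The sketch below applies the state update with gates that are forced fully open or fully closed; the helper name <code>cell_update</code> and the hand-set gate values are assumptions for illustration, since a real LSTM computes the gates from $h_{t-1}$ and $x_t$.</p>

```python
import numpy as np

def cell_update(c_prev, f, i, g, o):
    """One LSTM state update given already-activated gates."""
    c = f * c_prev + i * g   # keep part of the old state, add new information
    h = o * np.tanh(c)       # bounded output
    return h, c

c_prev = np.array([2.0, -3.0, 0.5])
g = np.array([0.9, 0.9, 0.9])
o = np.ones(3)

# Forget gate open, input gate closed: the cell state is carried over unchanged.
_, c_keep = cell_update(c_prev, f=np.ones(3), i=np.zeros(3), g=g, o=o)

# Forget gate closed, input gate open: the cell state is overwritten by g.
_, c_write = cell_update(c_prev, f=np.zeros(3), i=np.ones(3), g=g, o=o)

# Even a very large cell state yields a hidden state with magnitude at most 1.
h_big, _ = cell_update(c_prev * 100, f=np.ones(3), i=np.zeros(3), g=g, o=o)
```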
<h2 id="beyond-simple-sequential-models">Beyond &ldquo;simple&rdquo; sequential models</h2>
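<p>RNNs can do much more than model a single sequence. Take sentence translation: a sequence-to-sequence architecture uses two RNNs, one that reads the source sentence and extracts intermediate states, and another that produces the translated sentence from the last intermediate state.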
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409042233558.png?x-oss-process=image/quality,q_90/format,webp"></p>
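<p>This means an RNN can act as an encoder that extracts and encodes semantic information, or as a decoder that decodes it.</p>
<p>One variant is the bidirectional RNN, in which the output at position $i$ depends on the inputs both before and after it. It can perform well on tasks such as cloze (fill-in-the-blank) tests.</p>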
<h1 id="lecture-19-lstm-implementation">Lecture 19: LSTM Implementation</h1>
<h2 id="lstm-cell">LSTM cell</h2>
<p>In this lecture we implement an LSTM in NumPy. We start with the LSTM cell. A cell holds a hidden state and a cell state, and its update equations are:


<div>$$

\begin{align*}  
        i_t &amp;= \sigma(W_{ii} x_t &#43; b_{ii} &#43; W_{hi} h_{t-1} &#43; b_{hi}) \\  
        f_t &amp;= \sigma(W_{if} x_t &#43; b_{if} &#43; W_{hf} h_{t-1} &#43; b_{hf}) \\  
        g_t &amp;= \tanh(W_{ig} x_t &#43; b_{ig} &#43; W_{hg} h_{t-1} &#43; b_{hg}) \\  
        o_t &amp;= \sigma(W_{io} x_t &#43; b_{io} &#43; W_{ho} h_{t-1} &#43; b_{ho}) \\  
        c_t &amp;= f_t \odot c_{t-1} &#43; i_t \odot g_t \\  
        h_t &amp;= o_t \odot \tanh(c_t) \\  
    \end{align*}

$$</div>
</p>
<p>As in the previous lecture, these equations can be written in matrix form:


<div>$$

\begin{align*}  
&amp;\begin{bmatrix}i_t\\f_t\\g_t\\o_t\end{bmatrix}=\begin{pmatrix}\text{sigmoid}\\\text{sigmoid}\\\text{tanh}\\\text{sigmoid}\end{pmatrix}(W_{hh}h_{t-1}&#43;W_{hx}x_t&#43;b_h) \\  
&amp;c_t=c_{t-1}\circ f_t&#43;i_t\circ g_t \\  
&amp;h_t=\tanh(c_t)\circ o_t\\  
&amp;i_t,f_t,g_t,o_t,c_t,h_t \in \mathbb{R}^d\\  
&amp;W_{hh},W_{hx}\in \mathbb{R}^{4d\times d}  
\end{align*}

$$</div>
</p>
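<p>PyTorch already ships an LSTM implementation. If we instantiate a $20\times100$ cell, i.e. an input vector of length 20 and a hidden state of length 100, then $W_{hh}$ and $W_{hx}$ have shapes $400\times 100$ and $400\times 20$.</p>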
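<p>The shape bookkeeping can be verified with plain NumPy arrays. The sketch below uses random stand-in weights with the shapes quoted above (it does not load anything from PyTorch) and shows that a single pre-activation vector of length 400 splits into the four gate blocks of length 100.</p>

```python
import numpy as np

input_size, hidden_size = 20, 100
rng = np.random.default_rng(0)

W_hh = rng.normal(size=(4 * hidden_size, hidden_size))  # 400 x 100
W_hx = rng.normal(size=(4 * hidden_size, input_size))   # 400 x 20
b_h = np.zeros(4 * hidden_size)

x = rng.normal(size=input_size)
h = np.zeros(hidden_size)

pre = W_hh @ h + W_hx @ x + b_h   # all four gates computed at once
i, f, g, o = np.split(pre, 4)     # four blocks of length hidden_size
```

<p>In PyTorch the corresponding parameters of <code>torch.nn.LSTMCell(20, 100)</code> are exposed as <code>weight_hh</code> and <code>weight_ih</code>.</p>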
<p>From the update equations above we get the following way to compute one LSTM cell:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="mi">1</span><span class="o">/</span><span class="p">(</span><span class="mi">1</span><span class="o">+</span><span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">lstm_cell</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">W_hh</span><span class="p">,</span> <span class="n">W_ih</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">i</span><span class="p">,</span><span class="n">f</span><span class="p">,</span><span class="n">g</span><span class="p">,</span><span class="n">o</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">W_ih</span><span class="nd">@x</span> <span class="o">+</span> <span class="n">W_hh</span><span class="nd">@h</span> <span class="o">+</span> <span class="n">b</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">i</span><span class="p">,</span><span class="n">f</span><span class="p">,</span><span class="n">g</span><span class="p">,</span><span class="n">o</span> <span class="o">=</span> <span class="n">sigmoid</span><span class="p">(</span><span class="n">i</span><span class="p">),</span> <span class="n">sigmoid</span><span class="p">(</span><span class="n">f</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">g</span><span class="p">),</span> <span class="n">sigmoid</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">c_out</span> <span class="o">=</span> <span class="n">f</span><span class="o">*</span><span class="n">c</span> <span class="o">+</span> <span class="n">i</span><span class="o">*</span><span class="n">g</span>
</span></span><span class="line"><span class="cl">    <span class="n">h_out</span> <span class="o">=</span> <span class="n">o</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">c_out</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">h_out</span><span class="p">,</span> <span class="n">c_out</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="full-sequence-lstm">Full sequence LSTM</h2>
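<p>Following PyTorch's convention, our LSTM returns all hidden states together with the final cell state. Something not mentioned earlier: the parameters are shared across all time steps. An LSTM therefore just applies <code>lstm_cell</code> repeatedly along the sequence:</p>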
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">lstm</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">W_hh</span><span class="p">,</span> <span class="n">W_ih</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">H</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">h</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
</span></span><span class="line"><span class="cl">        <span class="n">h</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="n">lstm_cell</span><span class="p">(</span><span class="n">X</span><span class="p">[</span><span class="n">t</span><span class="p">],</span> <span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">W_hh</span><span class="p">,</span> <span class="n">W_ih</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">H</span><span class="p">[</span><span class="n">t</span><span class="p">,:]</span> <span class="o">=</span> <span class="n">h</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">H</span><span class="p">,</span> <span class="n">c</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="batching-efficiently">Batching efficiently</h2>
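<p>Next we consider batching the LSTM. A natural layout puts the batch as the first dimension and stacks the inputs, i.e. <code>X[NUM_BATCHES][NUM_TIMESTEPS][INPUT_SIZE]</code>, which is known as the NTC format. With this layout, time step i accesses <code>X[:,i,:]</code>; note that these elements are not contiguous in memory, so the cache hit rate is poor.</p>
<p>Putting the time dimension first, i.e. using the TNC format, solves this problem.</p>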
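<p>The difference between the two layouts can be observed directly in NumPy. The sketch below, with arbitrarily chosen sizes, takes the slice for one time step from each layout and inspects whether it is a contiguous block of memory.</p>

```python
import numpy as np

N, T, C = 8, 16, 4                                  # batch size, time steps, input size
X_ntc = np.arange(N * T * C, dtype=float).reshape(N, T, C)   # batch-first layout
X_tnc = np.ascontiguousarray(X_ntc.swapaxes(0, 1))  # time-first copy of the same data

t = 3
step_ntc = X_ntc[:, t, :]   # strided view: consecutive rows are T*C elements apart
step_tnc = X_tnc[t]         # one contiguous (N, C) block

print(step_ntc.flags["C_CONTIGUOUS"], step_tnc.flags["C_CONTIGUOUS"])
```

<p>Both slices hold the same values; only the TNC slice sits in one contiguous chunk.</p>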
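<p>The rest of the code barely changes. In the matrix multiplications the batched input <code>X[t]</code> (a two-dimensional N×C slice) now goes on the left of the <code>@</code> operator, so the weight matrices are transposed relative to the earlier version:</p>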
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">lstm_cell</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">W_hh</span><span class="p">,</span> <span class="n">W_ih</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">i</span><span class="p">,</span><span class="n">f</span><span class="p">,</span><span class="n">g</span><span class="p">,</span><span class="n">o</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">x</span><span class="nd">@W_ih</span> <span class="o">+</span> <span class="n">h</span><span class="nd">@W_hh</span> <span class="o">+</span> <span class="n">b</span><span class="p">[</span><span class="kc">None</span><span class="p">,:],</span> <span class="mi">4</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">i</span><span class="p">,</span><span class="n">f</span><span class="p">,</span><span class="n">g</span><span class="p">,</span><span class="n">o</span> <span class="o">=</span> <span class="n">sigmoid</span><span class="p">(</span><span class="n">i</span><span class="p">),</span> <span class="n">sigmoid</span><span class="p">(</span><span class="n">f</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">g</span><span class="p">),</span> <span class="n">sigmoid</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">c_out</span> <span class="o">=</span> <span class="n">f</span><span class="o">*</span><span class="n">c</span> <span class="o">+</span> <span class="n">i</span><span class="o">*</span><span class="n">g</span>
</span></span><span class="line"><span class="cl">    <span class="n">h_out</span> <span class="o">=</span> <span class="n">o</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">c_out</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">h_out</span><span class="p">,</span> <span class="n">c_out</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">lstm</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">W_hh</span><span class="p">,</span> <span class="n">W_ih</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">H</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">h</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
</span></span><span class="line"><span class="cl">        <span class="n">h</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="n">lstm_cell</span><span class="p">(</span><span class="n">X</span><span class="p">[</span><span class="n">t</span><span class="p">],</span> <span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">W_hh</span><span class="p">,</span> <span class="n">W_ih</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">H</span><span class="p">[</span><span class="n">t</span><span class="p">,:,:]</span> <span class="o">=</span> <span class="n">h</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">H</span><span class="p">,</span> <span class="n">c</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="training-lstms">Training LSTMs</h2>
<p>Training a single-layer LSTM is straightforward, so let's go straight to the code:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">train_lstm</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">h0</span><span class="p">,</span> <span class="n">c0</span><span class="p">,</span> <span class="n">parameters</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">H</span><span class="p">,</span> <span class="n">cn</span> <span class="o">=</span> <span class="n">lstm</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">h0</span><span class="p">,</span> <span class="n">c0</span><span class="p">,</span> <span class="n">parameters</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">l</span> <span class="o">=</span> <span class="n">loss</span><span class="p">(</span><span class="n">H</span><span class="p">,</span> <span class="n">Y</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">l</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">opt</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
</span></span></code></pre></td></tr></table>
</div>
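</div><p>Training a multi-layer LSTM is not hard either. We can propagate forward along either the depth or the time dimension first, and then along the other. The example code goes through time first, then depth:</p>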
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">train_lstm</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">h0</span><span class="p">,</span> <span class="n">c0</span><span class="p">,</span> <span class="n">parameters</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">H</span> <span class="o">=</span> <span class="n">X</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">depth</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">H</span><span class="p">,</span> <span class="n">cn</span> <span class="o">=</span> <span class="n">lstm</span><span class="p">(</span><span class="n">H</span><span class="p">,</span> <span class="n">h0</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">c0</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">parameters</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">    <span class="n">l</span> <span class="o">=</span> <span class="n">loss</span><span class="p">(</span><span class="n">H</span><span class="p">,</span> <span class="n">Y</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">l</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">opt</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
</span></span></code></pre></td></tr></table>
</div>
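</div><p>Now for the main event. If the sequence is very long, a single forward pass has to store a great many intermediate values and may run out of GPU memory. How do we deal with this?</p>
<p>We can cut the sequence into blocks of a fixed length. We first compute the loss on the first block and backpropagate, then continue the forward pass on the next block, passing the final hidden state and cell state of the first block in as the initial state of the second, then backpropagate again&hellip;</p>
<p>This continues until the whole sequence has been processed. If the parameters are held fixed, the forward computation of the truncated version is identical to that of the full version, which is why <code>lstm</code> needs to return the final cell state. The gradients, however, are truncated at block boundaries, because the carried-over states are passed on as detached <code>.data</code>. Note also that the code below takes an optimizer step after every block. The process can be written as:</p>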
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">train_lstm</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">h0</span><span class="p">,</span> <span class="n">c0</span><span class="p">,</span> <span class="n">parameters</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">H</span><span class="p">,</span> <span class="n">cn</span> <span class="o">=</span> <span class="n">lstm</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">h0</span><span class="p">,</span> <span class="n">c0</span><span class="p">,</span> <span class="n">parameters</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">l</span> <span class="o">=</span> <span class="n">loss</span><span class="p">(</span><span class="n">H</span><span class="p">,</span> <span class="n">Y</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">l</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">opt</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">H</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">cn</span><span class="o">.</span><span class="n">data</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">h0</span><span class="p">,</span> <span class="n">c0</span> <span class="o">=</span> <span class="n">zeros</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="n">BLOCK_SIZE</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">h0</span><span class="p">,</span> <span class="n">c0</span> <span class="o">=</span> <span class="n">train_lstm</span><span class="p">(</span><span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="p">:</span><span class="n">i</span><span class="o">+</span><span class="n">BLOCK_SIZE</span><span class="p">],</span> <span class="n">Y</span><span class="p">[</span><span class="n">i</span><span class="p">:</span><span class="n">i</span><span class="o">+</span><span class="n">BLOCK_SIZE</span><span class="p">],</span> <span class="n">h0</span><span class="p">,</span> <span class="n">c0</span><span class="p">,</span> <span class="n">parameters</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="lecture-20-transformers-and-attention">Lecture 20: Transformers and Attention</h1>
<h2 id="两种为时间序列建模的方法-the-two-approaches-to-time-series-modeling">The two approaches to time series modeling</h2>
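<p>RNNs model time series through what is called a latent state: the hidden state at time t summarizes all the information up to and including time t. The advantage is that, in theory, it can aggregate information over an unbounded history; the drawbacks are that it struggles to remember information from far in the past and that it suffers from exploding and vanishing gradients.</p>
<p>The other approach is called direct prediction: the output at time t is predicted directly from the sequence at and before time t. The advantage is that, for most outputs, the computational path from the relevant inputs is short; the drawbacks are that there is no explicit state representation and that, in practice, the usable sequence length is finite. The Transformer models time series by direct prediction.</p>
<p>[The discussion of CNNs for time series modeling and their pros and cons is skipped here.]</p>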
<h2 id="自注意力机制和transformer-self-attention-and-transformers">Self-attention and transformers</h2>
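<p>Attention, broadly, refers to any mechanism that takes a weighted sum of states. The weights should not be set by hand; they are computed from learnable parameters and then normalized by a softmax.</p>
<p>Self-attention, as the name suggests, lets the states themselves determine the weights and then sums the states with those weights.</p>
<p>In self-attention, $K$, $Q$ and $V$ are three matrices of the same shape, $K,Q,V\in \mathbb{R}^{T\times d}$, each obtained by multiplying the input $X$ by a different weight matrix, $W_K$, $W_Q$ and $W_V$ respectively. The self-attention operator is defined as: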


<div>$$

\text{SelfAttention}(K,Q,V) = \text{softmax}(\frac{KQ^T}{\sqrt{d}})V

$$</div>

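where the softmax is applied to each row.</p>
<p>Let us try to understand what this formula does. First, each row of $K$, $Q$ and $V$ is a linear transformation of the corresponding row of $X$, so no row of $K$, $Q$ or $V$ contains information from other time steps (each row of $X$ is the input at one time step). $KQ^T$ is a $T\times T$ matrix whose entry in row $i$, column $j$ is the inner product of row $i$ of $K$ and row $j$ of $Q$ (that is, column $j$ of $Q^T$); this is where different time steps first interact. For row $i$ of $KQ^T$, the size of each entry reflects, to some degree, how similar each row of $Q$ is to row $i$ of $K$; a softmax over that row turns it into weights. Multiplying this weight matrix by $V$ gives the self-attention output: each output row is a weighted sum of the rows of $V$, and this is where information from different time steps is actually mixed.</p>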
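<p>Self-attention has the following properties:</p>
<ul>
<li>It is equivariant (rather than invariant) to permutations of $K$, $Q$ and $V$. If their rows are permuted in the same way, the output is unchanged apart from the same permutation of its rows.</li>
<li>It acts across all time steps, which means self-attention can mix information from different time steps.</li>
<li>Its computational cost is $O(T^2d)$.</li>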
</ul>
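<p>The operator and its permutation equivariance can be checked numerically. The sketch below follows the $\text{softmax}(KQ^T/\sqrt{d})V$ convention used in these notes; the random inputs and the helper names are assumptions for the example.</p>

```python
import numpy as np

def softmax(Z):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

def self_attention(K, Q, V):
    d = K.shape[-1]
    weights = softmax(K @ Q.T / np.sqrt(d))   # (T, T), each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
T, d = 6, 4
K, Q, V = (rng.normal(size=(T, d)) for _ in range(3))
out, weights = self_attention(K, Q, V)

# Permute the rows (time steps) of K, Q and V in the same way.
perm = rng.permutation(T)
out_perm, _ = self_attention(K[perm], Q[perm], V[perm])
```

<p>Here <code>out_perm</code> equals <code>out[perm]</code>: reordering the time steps only reorders the output rows.</p>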
<p>The structure of a Transformer block is shown below:
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409080919833.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>In equations, the block computes:


<div>$$

\begin{align*}  
\tilde{Z} &amp;:= \text{SelfAttention}\big(Z^{(i)}W_K,Z^{(i)}W_Q,Z^{(i)}W_V\big) \\  
&amp;= \mathrm{softmax}\left(\frac{Z^{(i)}W_KW_Q^T(Z^{(i)})^T}{d^{1/2}}\right)Z^{(i)}W_V \\  
\tilde{Z} &amp;:= \text{LayerNorm}\bigg(Z^{(i)} \boldsymbol{&#43;}\tilde{Z}\bigg) \\  
Z^{(i&#43;1)} &amp;:= \text{LayerNorm}(\mathrm{ReLU}(\tilde{Z}W)&#43;\tilde{Z})  
\end{align*}

$$</div>
</p>
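<p>The block equations translate almost line by line into NumPy. The following is a minimal single-head sketch under simplifying assumptions: LayerNorm has no learned scale or shift, the feed-forward part is the single ReLU layer from the formula, and the weights are random stand-ins.</p>

```python
import numpy as np

def softmax(Z):
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

def layer_norm(Z, eps=1e-5):
    """Normalize each row to zero mean and unit variance (no learned parameters)."""
    mean = Z.mean(axis=-1, keepdims=True)
    var = Z.var(axis=-1, keepdims=True)
    return (Z - mean) / np.sqrt(var + eps)

def transformer_block(Z, W_K, W_Q, W_V, W):
    d = W_K.shape[-1]
    K, Q, V = Z @ W_K, Z @ W_Q, Z @ W_V
    attn = softmax(K @ Q.T / np.sqrt(d)) @ V
    Z_tilde = layer_norm(Z + attn)                 # residual connection + LayerNorm
    ff = np.maximum(Z_tilde @ W, 0)                # ReLU layer
    return layer_norm(ff + Z_tilde)                # second residual + LayerNorm

rng = np.random.default_rng(0)
T, d = 5, 8
Z = rng.normal(size=(T, d))
W_K, W_Q, W_V, W = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(4))
Z_next = transformer_block(Z, W_K, W_Q, W_V, W)
```

<p>The output has the same shape as the input, so blocks of this kind can be stacked directly.</p>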
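<p>The advantages of the Transformer are:</p>
<ul>
<li>it can mix information from all time steps within a single block;</li>
<li>it needs no additional parameters as the number of time steps grows.</li>
</ul>
<p>Its drawbacks are:</p>
<ul>
<li>every output depends on the inputs at all time steps, including future ones;</li>
<li>the input has no notion of order: if the time steps are shuffled before being fed to the Transformer, the result is the same apart from the same shuffling of its rows.</li>
</ul>
<p>Two techniques address these drawbacks.</p>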
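<p>The first is masked self-attention. As noted above, in the self-attention formula the softmax of $KQ^T$ is a dense matrix; each row is a set of weights that sums over the states at all time steps. Masked self-attention subtracts infinity from the strictly upper-triangular part of $KQ^T$ before the softmax, so the corresponding weights become 0. Each time step then takes a weighted sum only over itself and earlier time steps, which prevents access to future information.</p>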
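<p>A causal mask of this kind can be built with <code>np.triu</code>. The sketch below is a minimal single-head example with random inputs; the helper name <code>causal_mask</code> is an assumption for illustration.</p>

```python
import numpy as np

def softmax(Z):
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

def causal_mask(T):
    """0 on and below the diagonal, -inf strictly above it."""
    return np.triu(np.full((T, T), -np.inf), k=1)

rng = np.random.default_rng(0)
T, d = 4, 3
K, Q, V = (rng.normal(size=(T, d)) for _ in range(3))

# Adding -inf before the softmax makes the masked weights exactly 0.
weights = softmax(K @ Q.T / np.sqrt(d) + causal_mask(T))
out = weights @ V
```

<p>With the mask in place, changing the last row of <code>V</code> affects only the last output row, so earlier time steps cannot see the future.</p>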
<p>To address the lack of order in the input, position encoding is introduced: a matrix that encodes time information is added to the input, as follows:


<div>$$

X\in\mathbb{R}^{T\times n}= \begin{bmatrix}-&amp;x_1^\top&amp;-\\-&amp;x_2^\top&amp;-\\&amp;\vdots&amp;\\-&amp;x_T^\top&amp;-\end{bmatrix}&#43;\begin{bmatrix}\sin(\omega_1\cdot1)&amp;\cdots&amp;\sin(\omega_n\cdot1)\\\sin(\omega_1\cdot2)&amp;\cdots&amp;\sin(\omega_n\cdot2)\\\vdots&amp;\ddots&amp;\vdots\\\sin(\omega_1\cdot T)&amp;\cdots&amp;\sin(\omega_n\cdot T)\end{bmatrix}

$$</div>

Typically the frequencies $\omega_i$ are chosen on a logarithmic schedule.</p>
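<p>The encoding matrix in the formula can be generated in a few lines. In the sketch below the frequency schedule (evenly spaced on a log scale between 1 and 1/<code>max_period</code>) and the default <code>max_period</code> value are assumptions for illustration; the notes only say that the frequencies follow a logarithmic schedule.</p>

```python
import numpy as np

def position_encoding(T, n, max_period=10000.0):
    """Return a (T, n) matrix whose row t-1, column i holds sin(omega_i * t), t = 1..T."""
    # Assumed schedule: n frequencies log-spaced from 1 down to 1/max_period.
    omega = np.logspace(0.0, -np.log10(max_period), num=n)
    t = np.arange(1, T + 1)[:, None]       # column of time indices 1..T
    return np.sin(omega[None, :] * t)

T, n = 10, 8
X = np.zeros((T, n))                       # stand-in for the input rows x_t
P = position_encoding(T, n)
X_pos = X + P                              # input with time information added
```

<p>Each row of <code>P</code> is different, so two otherwise identical inputs at different time steps no longer look the same to the attention layers.</p>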
<h1 id="lecture-21-transformer-implementation">Lecture 21: Transformer Implementation</h1>
<p>In this lecture we implement the Transformer in NumPy.</p>
<h2 id="自注意力机制-self-attention">Self-attention</h2>
<p>The self-attention formula is:


<div>$$

Y = \left(\mathrm{softmax}\left(\frac{X W_K W_Q^T X^T}{\sqrt{d}}\right)X W_V \right) W_o

$$</div>

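The only difference from the previous lecture is an extra linear transformation $W_o$ applied before the output.</p>
<p>Note that the formula multiplies X by three separate weight matrices to obtain K, Q and V. These three matrix products can be fused into one: concatenate the three weight matrices, multiply by X once, and split the result back into K, Q and V. A self-attention module then looks like:</p>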
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">self_attention</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">mask</span><span class="p">,</span> <span class="n">W_KQV</span><span class="p">,</span> <span class="n">W_out</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">K</span><span class="p">,</span><span class="n">Q</span><span class="p">,</span><span class="n">V</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">X</span><span class="nd">@W_KQV</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">attn</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">K</span><span class="nd">@Q.swapaxes</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span><span class="o">-</span><span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="o">+</span> <span class="n">mask</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">attn</span><span class="nd">@V@W_out</span><span class="p">,</span> <span class="n">attn</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="minibatching-with-batch-matrix-multiply">Minibatching with batch matrix multiply</h2>
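<p>Self-attention does not propagate forward step by step in time, so X keeps the usual BTD (batch, time, dimension) layout in memory. Implementing self-attention over a batch brings in the notion of batched matrix multiplication.</p>
<p>Concretely, in the $K@Q^T$ step of the formula both K and Q have shape B×T×d. For self-attention, what we want is K[i,:,:] multiplied by the transpose of Q[i,:,:] for every i, and batched matrix multiplication happens to be defined exactly this way: the last two dimensions of the operands must satisfy the usual matrix-multiplication shape rule, while every other dimension must either match, or be absent or equal to 1 so that it broadcasts.</p>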
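<p>The batched behaviour of the <code>@</code> operator can be confirmed against an explicit loop. The sizes and random inputs in the sketch below are arbitrary choices for illustration.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, d = 3, 5, 4
K = rng.normal(size=(B, T, d))
Q = rng.normal(size=(B, T, d))

# Transpose only the last two axes, then multiply batch element by batch element.
scores = K @ Q.swapaxes(-1, -2)            # (B, T, d) @ (B, d, T) gives (B, T, T)

# The same result computed one batch element at a time.
looped = np.stack([K[i] @ Q[i].T for i in range(B)])

# A missing leading dimension is broadcast across the batch.
W = rng.normal(size=(d, d))
projected = K @ W                          # (B, T, d) @ (d, d) gives (B, T, d)
```

<p>This is why the single-sequence code can be reused for batches as long as transposes are written with <code>swapaxes(-1, -2)</code> instead of <code>.T</code>.</p>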
<h2 id="multihead-attention-多头注意力">Multihead attention</h2>
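<p>The motivation for multi-head attention comes from the $K@Q^T$ step, where every entry of the result is the inner product of two vectors of length d. Multi-head attention splits each row of K, Q and V into h parts, runs attention on each part separately, and then concatenates the results. Each entry of $K@Q^T$ is then the inner product of two vectors of length d/h, and each head gets its own attention matrix.</p>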
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">multihead_attention</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">mask</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">W_KQV</span><span class="p">,</span> <span class="n">W_out</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">N</span><span class="p">,</span><span class="n">T</span><span class="p">,</span><span class="n">d</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span>
</span></span><span class="line"><span class="cl">    <span class="n">K</span><span class="p">,</span><span class="n">Q</span><span class="p">,</span><span class="n">V</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">X</span><span class="nd">@W_KQV</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">K</span><span class="p">,</span><span class="n">Q</span><span class="p">,</span><span class="n">V</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">N</span><span class="p">,</span><span class="n">T</span><span class="p">,</span><span class="n">heads</span><span class="p">,</span><span class="n">d</span><span class="o">//</span><span class="n">heads</span><span class="p">)</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">)</span> <span class="k">for</span> <span class="n">a</span> <span class="ow">in</span> <span class="p">(</span><span class="n">K</span><span class="p">,</span><span class="n">Q</span><span class="p">,</span><span class="n">V</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="n">attn</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">K</span><span class="nd">@Q.swapaxes</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span><span class="o">-</span><span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">d</span><span class="o">//</span><span class="n">heads</span><span class="p">)</span> <span class="o">+</span> <span class="n">mask</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="p">(</span><span class="n">attn</span><span class="nd">@V</span><span class="p">)</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">N</span><span class="p">,</span><span class="n">T</span><span class="p">,</span><span class="n">d</span><span class="p">)</span> <span class="o">@</span> <span class="n">W_out</span><span class="p">,</span> <span class="n">attn</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="transformer-block">Transformer block</h2>
<p>一个Transformer Block结构如下图所示：
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409080919833.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>应用已经实现的各个组件，我们可以轻松地写出一个支持多头自注意力的Transformer块：</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">layer_norm</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">eps</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="p">(</span><span class="n">Z</span> <span class="o">-</span> <span class="n">Z</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">Z</span><span class="o">.</span><span class="n">var</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">relu</span><span class="p">(</span><span class="n">Z</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">transformer</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">mask</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">W_KQV</span><span class="p">,</span> <span class="n">W_out</span><span class="p">,</span> <span class="n">W_ff1</span><span class="p">,</span> <span class="n">W_ff2</span><span class="p">,</span> <span class="n">eps</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">Z</span> <span class="o">=</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">multihead_attention</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">mask</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">W_KQV</span><span class="p">,</span> <span class="n">W_out</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">X</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">Z</span> <span class="o">+</span> <span class="n">relu</span><span class="p">(</span><span class="n">Z</span><span class="nd">@W_ff1</span><span class="p">)</span><span class="nd">@W_ff2</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="lecture-23-moel-deployment">Lecture 23 Model Deployment</h1>
<h2 id="模型部署概览-model-deployment-overview">模型部署概览 Model deployment overview</h2>
<p>在特定的设备上部署训练好的模型是一件比较麻烦的事情，其受设备的影响很大。现在有一些用于部署推理的框架，例如NVIDIA设备上的TensorRT，在嵌入式设备上有ARMComputeLib和TFLite，苹果有CoreML。</p>
<p>上述框架都需要一种推理模型格式的输入，这个输入能够描述模型的计算流程，这种格式目前有ONNX、CoreML和TFLite。模型通过Python编写，其好处是提高了编码效率，带来的缺点就是可能某些模型没办法完美转换为上述通用格式。</p>
<p>许多推理框架都是以计算图解释器的形式组织的，其通过预分配和重用内存、算子融合、精度量化等优化手段，实现更高效的推理。但同样，其也有很多限制，例如他们支持的算子类别是有限的。</p>
<h2 id="机器学习编译-machine-learning-compilation">机器学习编译 Machine learning compilation</h2>
<p>机器学习编译试图打破需要为每种设备定制推理库的现状，其目标是将输入的深度学习模型转换为可以直接在终端上运行的代码。</p>
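<p>An ML program can be viewed as a module made up of several functions that call one another. The format shown in the figure below is called an intermediate representation (IR), and the module shown there is called an IR module.</p>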
<p><img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409081958517.png?x-oss-process=image/quality,q_90/format,webp"></p>
<p>ML编译的流程大致有：</p>
<ul>
<li>
<p>从深度学习框架中导入模型；
<img loading="lazy" src="https://pics.zhouxin.space/202409082011257.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
<li>
<p>对IR模块进行变换，算子融合
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409082011947.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
<li>
<p>将中间状态翻译为更低级的循环代码
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409082011782.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
<li>
<p>进行更低级的变换，进行算子融合
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409082013131.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
<li>
<p>进行代码生成
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/202409082014722.png?x-oss-process=image/quality,q_90/format,webp">
本讲后续内容和下一讲均为介绍MLC，计划后面继续学习MLC，这里就不浅尝辄止了。</p>
</li>
</ul>
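<p>To make the ideas of operator fusion and lowering to loop code more concrete, here is a toy sketch in plain Python. It is only an illustration under assumed function names (<code>add_then_relu_unfused</code>, <code>add_then_relu_fused</code>); it is not the output of any ML compiler.</p>

```python
import numpy as np

def add_then_relu_unfused(x, b):
    """Two separate loop nests with an intermediate buffer between them."""
    tmp = np.empty_like(x)
    for i in range(x.shape[0]):
        tmp[i] = x[i] + b[i]            # operator 1: elementwise add
    out = np.empty_like(x)
    for i in range(x.shape[0]):
        out[i] = max(tmp[i], 0.0)       # operator 2: ReLU
    return out

def add_then_relu_fused(x, b):
    """One fused loop: no intermediate buffer, one pass over the data."""
    out = np.empty_like(x)
    for i in range(x.shape[0]):
        out[i] = max(x[i] + b[i], 0.0)
    return out

x = np.array([-1.0, 0.5, 2.0])
b = np.array([0.5, -1.0, 1.0])
unfused = add_then_relu_unfused(x, b)
fused = add_then_relu_fused(x, b)
```

<p>Both versions compute the same values. The fused one avoids writing and re-reading <code>tmp</code>, which is the kind of memory traffic that fusion is meant to save.</p>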
<p>全文完。</p>
<h1 id="参考文档">参考文档</h1>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p><a href="https://blog.csdn.net/qq_36892712/article/details/133774755">指数移动平均EMA_ema移动平均数怎么算-CSDN博客</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p><a href="https://zhuanlan.zhihu.com/p/22810533">zhuanlan.zhihu.com/p/22810533</a>&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:3">
<p><a href="https://www.ruder.io/optimizing-gradient-descent/#Nesterov%20accelerated%20gradient">An overview of gradient descent optimization algorithms</a>&#160;<a href="#fnref:3" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:4">
<p><a href="https://numpy.org/doc/stable/reference/generated/numpy.lib.stride_tricks.as_strided.html">numpy.lib.stride_tricks.as_strided — NumPy v2.1 Manual</a>&#160;<a href="#fnref:4" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>科目一考试知识点</title>
      <link>https://www.zhouxin.space/wiki/tips-on-driver-lisence-subject-one/</link>
      <pubDate>Mon, 29 Apr 2024 16:00:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/wiki/tips-on-driver-lisence-subject-one/</guid>
      <description>&lt;h1 id=&#34;考证相关常识&#34;&gt;考证相关常识&lt;/h1&gt;
&lt;ul&gt;
&lt;li&gt;小型客车：18 周岁 +，可初次申领&lt;/li&gt;
&lt;li&gt;大型货车：20 周岁 +，可初次申领&lt;/li&gt;
&lt;li&gt;中型客车、大型货车：20 周岁 +，货车可初次申领，客车要增驾&lt;/li&gt;
&lt;li&gt;大型客车、重型牵引挂车：22 周岁 +，只可增驾&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;增驾&#34;&gt;增驾&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;大型货车增驾重型牵引挂车和中型客车，两年且两年内无 12 分扣满&lt;/li&gt;
&lt;/ul&gt;
&lt;h1 id=&#34;驾驶证相关常识&#34;&gt;驾驶证相关常识&lt;/h1&gt;
&lt;ul&gt;
&lt;li&gt;到期前&lt;strong&gt;90&lt;/strong&gt;日内，向&lt;strong&gt;驾驶证核发地&lt;/strong&gt;或以外车管所申请换证，否则会被注销。&lt;/li&gt;
&lt;li&gt;驾驶人记分未达到满分，有罚款尚未缴纳的，记分&lt;strong&gt;转入下一记分周期&lt;/strong&gt;。&lt;/li&gt;
&lt;li&gt;实习期记满 12 分，&lt;strong&gt;注销&lt;/strong&gt;准驾车型资格。&lt;/li&gt;
&lt;li&gt;实习期上高速，应当有&lt;strong&gt;3 年&lt;/strong&gt;以上驾驶人陪同。&lt;/li&gt;
&lt;li&gt;一次有两个以上违法行为记分的，应分别计算&lt;strong&gt;累加&lt;/strong&gt;分值&lt;/li&gt;
&lt;li&gt;小型汽车驾驶人发生交通事故造成人员死亡，承担同等以上责任未被吊销驾驶证的，记分周期结束 30 天内要审验。&lt;/li&gt;
&lt;li&gt;一个周期累计记满 36 分或者三次记满 12 分，要重新参加科目二和科目三&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;禁止考证&#34;&gt;禁止考证&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;考试过程作弊（未取得），&lt;strong&gt;二千元以下&lt;/strong&gt;罚款，&lt;strong&gt;一年&lt;/strong&gt;内不得再考&lt;/li&gt;
&lt;li&gt;作弊取得了驾驶证，&lt;strong&gt;三年&lt;/strong&gt;内不得再考&lt;/li&gt;
&lt;li&gt;无证驾驶造成重伤或者死亡，&lt;strong&gt;十年&lt;/strong&gt;不得考证&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;车辆代号&#34;&gt;车辆代号&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;A1：大型客车&lt;/li&gt;
&lt;li&gt;A2：牵引车&lt;/li&gt;
&lt;li&gt;A3：城市公交车&lt;/li&gt;
&lt;li&gt;B1：中型客车&lt;/li&gt;
&lt;li&gt;B2：大型货车&lt;/li&gt;
&lt;li&gt;C1：小型汽车&lt;/li&gt;
&lt;li&gt;C2：自动挡汽车&lt;/li&gt;
&lt;li&gt;C3：低速载货汽车&lt;/li&gt;
&lt;li&gt;C4：三轮汽车&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;满分学习&#34;&gt;满分学习&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;现场 + 网络不少于五天&lt;/li&gt;
&lt;li&gt;现场不少于两天&lt;/li&gt;
&lt;li&gt;每日不少于三小时&lt;/li&gt;
&lt;li&gt;一个积分周期内两次记满 12 分，理论考试合格后重新参加道路考试&lt;/li&gt;
&lt;li&gt;30日内拒绝参加，公告其驾驶证停止使用&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;知识教育&#34;&gt;知识教育&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;机动车驾驶人发生人身伤亡交通事故负有同等以上责任，参加为期两天的学习&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;记分扣免&#34;&gt;记分扣免&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Complete one hour of on-site study and pass the exam: 2 penalty points are removed&lt;/li&gt;
&lt;li&gt;弄虚作假罚款 1000 以下&lt;/li&gt;
&lt;/ul&gt;
&lt;h1 id=&#34;赔偿责任&#34;&gt;赔偿责任&lt;/h1&gt;
&lt;ul&gt;
&lt;li&gt;行人故意碰撞机动车，机动车无需担责。&lt;/li&gt;
&lt;/ul&gt;
&lt;h1 id=&#34;罚款扣分相关&#34;&gt;罚款扣分相关&lt;/h1&gt;
&lt;h2 id=&#34;扣分&#34;&gt;扣分&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;一分
&lt;ul&gt;
&lt;li&gt;不按规定使用灯光&lt;/li&gt;
&lt;li&gt;没带驾驶证&lt;/li&gt;
&lt;li&gt;不按规定会车&lt;/li&gt;
&lt;li&gt;违反禁令标志、禁止标线&lt;/li&gt;
&lt;li&gt;驾驶未按规定定期进行安全技术检验的非特殊车辆&lt;/li&gt;
&lt;li&gt;没系安全带&lt;/li&gt;
&lt;li&gt;非高速掉头、倒车&lt;/li&gt;
&lt;li&gt;载货汽车超重 30% 以下&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;三分
&lt;ul&gt;
&lt;li&gt;不避让校车&lt;/li&gt;
&lt;li&gt;高速公路上行驶低于规定最低时速的记 3 分&lt;/li&gt;
&lt;li&gt;不按规定超车、让行&lt;/li&gt;
&lt;li&gt;在非高速逆行&lt;/li&gt;
&lt;li&gt;校车、客运汽车超载 20% 以下&lt;/li&gt;
&lt;li&gt;不按规定安装车牌&lt;/li&gt;
&lt;li&gt;发生故障不按规定使用灯光和警告标志&lt;/li&gt;
&lt;li&gt;高速上不按规定车道行驶&lt;/li&gt;
&lt;li&gt;普通车超载20-50%&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;六分
&lt;ul&gt;
&lt;li&gt;造成致人轻微伤或者财产损失的交通事故后逃逸&lt;/li&gt;
&lt;li&gt;普通车辆高速超速 20%~50%&lt;/li&gt;
&lt;li&gt;运载爆炸物品未标识、未按照指定路线行驶&lt;/li&gt;
&lt;li&gt;普通汽车超载 50%-100%&lt;/li&gt;
&lt;li&gt;驾驶证被扣期间驾车&lt;/li&gt;
&lt;li&gt;普通汽车在普通道路超速 50%&lt;/li&gt;
&lt;li&gt;载货汽车超重 50% 以上&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;九分
&lt;ul&gt;
&lt;li&gt;未悬挂机动车号牌或者故意遮挡、污损机动车号牌&lt;/li&gt;
&lt;li&gt;驾驶与准驾车型不符的汽车&lt;/li&gt;
&lt;li&gt;七座以上汽车超载50%-100%&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;十二分
&lt;ul&gt;
&lt;li&gt;造成致人轻伤以上或者死亡的交通事故后逃逸&lt;/li&gt;
&lt;li&gt;普通车高速超速 50%&lt;/li&gt;
&lt;li&gt;普通车超载100%以上&lt;/li&gt;
&lt;li&gt;校车超载20%以上&lt;/li&gt;
&lt;li&gt;中型以上客车高速超速20%&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;罚款&#34;&gt;罚款&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;考试过程作弊（未取得），&lt;strong&gt;二千元以下&lt;/strong&gt;罚款，&lt;strong&gt;一年&lt;/strong&gt;内不得再考。&lt;/li&gt;
&lt;li&gt;逾期不参加审验仍然驾驶机动车，罚 200-500。&lt;/li&gt;
&lt;li&gt;Driving with an expired driver's license: fine of 200-1000 yuan&lt;/li&gt;
&lt;li&gt;酒后构成重大事故犯罪，吊销驾驶证，终生不得再申请&lt;/li&gt;
&lt;li&gt;不按规定停放150&lt;/li&gt;
&lt;li&gt;与准驾车型不符200-1000&lt;/li&gt;
&lt;li&gt;记分满12分仍驾驶200-1000&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;刑事责任&#34;&gt;刑事责任&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;重伤、死亡：三年以下&lt;/li&gt;
&lt;li&gt;死亡且逃逸：三年以上，七年以下&lt;/li&gt;
&lt;li&gt;因逃逸而死亡：七年以上&lt;/li&gt;
&lt;li&gt;追逐竞驶：拘役和罚金&lt;/li&gt;
&lt;/ul&gt;
&lt;h1 id=&#34;车辆速度与距离&#34;&gt;车辆速度与距离&lt;/h1&gt;
&lt;h2 id=&#34;车辆速度&#34;&gt;车辆速度&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;公路无中心线 40，有中心线 70，双车道为90&lt;/li&gt;
&lt;li&gt;城市无中心线 30，有中心线 50，双车道为60&lt;/li&gt;
&lt;li&gt;单位院内20&lt;/li&gt;
&lt;li&gt;限速 30：通过铁路口、急弯路、掉头、下坡、能见度 50 米、冰雪泥泞、牵引其他车辆&lt;/li&gt;
&lt;li&gt;两道高速：左侧 100-120，右侧 60-120&lt;/li&gt;
&lt;li&gt;三道高速：左侧 110-120，中间 90-120，右侧 60-120&lt;/li&gt;
&lt;li&gt;能见度小于200m：50速度，100距离&lt;/li&gt;
&lt;li&gt;能见度小于100米：40速度，50距离&lt;/li&gt;
&lt;li&gt;能见度小于50米：20速度，驶出高速路&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;车辆距离&#34;&gt;车辆距离&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;高速路车速大于等于 100，安全距离 100m+&lt;/li&gt;
&lt;li&gt;高速路车速小于 100，安全距离 50m+&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;停车距离&#34;&gt;停车距离&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;50米：交叉路口、铁路口、急弯路、窄路、隧道、陡坡、桥梁&lt;/li&gt;
&lt;li&gt;30米：公交站、急救站、加油站、消防栓&lt;/li&gt;
&lt;/ul&gt;
&lt;h1 id=&#34;灯光使用&#34;&gt;灯光使用&lt;/h1&gt;
&lt;ul&gt;
&lt;li&gt;夜晚会车150m外将远光灯改为近光灯&lt;/li&gt;
&lt;/ul&gt;
&lt;h1 id=&#34;其他&#34;&gt;其他&lt;/h1&gt;
&lt;ul&gt;
&lt;li&gt;收到事故认定书10日内提出书面调解申请&lt;/li&gt;
&lt;li&gt;现场未报警，事后要求处理，应当在10日内提供证据&lt;/li&gt;
&lt;li&gt;自适应巡航：Adaptive Cruise Control，ACC&lt;/li&gt;
&lt;/ul&gt;
&lt;h1 id=&#34;相关标志&#34;&gt;相关标志&lt;/h1&gt;
&lt;h2 id=&#34;警告&#34;&gt;警告&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;注意潮汐车道&lt;br&gt;
&lt;img alt=&#34;image.png&#34; loading=&#34;lazy&#34; src=&#34;https://pics.zhouxin.space/20240429163504.png?x-oss-process=image/quality,q_90/format,webp&#34;&gt;&lt;/p&gt;</description>
      <content:encoded><![CDATA[<h1 id="考证相关常识">考证相关常识</h1>
<ul>
<li>小型客车：18 周岁 +，可初次申领</li>
<li>大型货车：20 周岁 +，可初次申领</li>
<li>中型客车、大型货车：20 周岁 +，货车可初次申领，客车要增驾</li>
<li>大型客车、重型牵引挂车：22 周岁 +，只可增驾</li>
</ul>
<h2 id="增驾">增驾</h2>
<ul>
<li>大型货车增驾重型牵引挂车和中型客车，两年且两年内无 12 分扣满</li>
</ul>
<h1 id="驾驶证相关常识">驾驶证相关常识</h1>
<ul>
<li>到期前<strong>90</strong>日内，向<strong>驾驶证核发地</strong>或以外车管所申请换证，否则会被注销。</li>
<li>驾驶人记分未达到满分，有罚款尚未缴纳的，记分<strong>转入下一记分周期</strong>。</li>
<li>实习期记满 12 分，<strong>注销</strong>准驾车型资格。</li>
<li>实习期上高速，应当有<strong>3 年</strong>以上驾驶人陪同。</li>
<li>一次有两个以上违法行为记分的，应分别计算<strong>累加</strong>分值</li>
<li>小型汽车驾驶人发生交通事故造成人员死亡，承担同等以上责任未被吊销驾驶证的，记分周期结束 30 天内要审验。</li>
<li>一个周期累计记满 36 分或者三次记满 12 分，要重新参加科目二和科目三</li>
</ul>
<h2 id="禁止考证">禁止考证</h2>
<ul>
<li>考试过程作弊（未取得），<strong>二千元以下</strong>罚款，<strong>一年</strong>内不得再考</li>
<li>作弊取得了驾驶证，<strong>三年</strong>内不得再考</li>
<li>无证驾驶造成重伤或者死亡，<strong>十年</strong>不得考证</li>
</ul>
<h2 id="车辆代号">车辆代号</h2>
<ul>
<li>A1：大型客车</li>
<li>A2：牵引车</li>
<li>A3：城市公交车</li>
<li>B1：中型客车</li>
<li>B2：大型货车</li>
<li>C1：小型汽车</li>
<li>C2：自动挡汽车</li>
<li>C3：低速载货汽车</li>
<li>C4：三轮汽车</li>
</ul>
<h2 id="满分学习">满分学习</h2>
<ul>
<li>现场 + 网络不少于五天</li>
<li>现场不少于两天</li>
<li>每日不少于三小时</li>
<li>一个积分周期内两次记满 12 分，理论考试合格后重新参加道路考试</li>
<li>30日内拒绝参加，公告其驾驶证停止使用</li>
</ul>
<h2 id="知识教育">知识教育</h2>
<ul>
<li>机动车驾驶人发生人身伤亡交通事故负有同等以上责任，参加为期两天的学习</li>
</ul>
<h2 id="记分扣免">记分扣免</h2>
<ul>
<li>Complete one hour of on-site study and pass the exam: 2 penalty points are removed</li>
<li>弄虚作假罚款 1000 以下</li>
</ul>
<h1 id="赔偿责任">赔偿责任</h1>
<ul>
<li>行人故意碰撞机动车，机动车无需担责。</li>
</ul>
<h1 id="罚款扣分相关">罚款扣分相关</h1>
<h2 id="扣分">扣分</h2>
<ul>
<li>一分
<ul>
<li>不按规定使用灯光</li>
<li>没带驾驶证</li>
<li>不按规定会车</li>
<li>违反禁令标志、禁止标线</li>
<li>驾驶未按规定定期进行安全技术检验的非特殊车辆</li>
<li>没系安全带</li>
<li>非高速掉头、倒车</li>
<li>载货汽车超重 30% 以下</li>
</ul>
</li>
<li>三分
<ul>
<li>不避让校车</li>
<li>高速公路上行驶低于规定最低时速的记 3 分</li>
<li>不按规定超车、让行</li>
<li>在非高速逆行</li>
<li>校车、客运汽车超载 20% 以下</li>
<li>不按规定安装车牌</li>
<li>发生故障不按规定使用灯光和警告标志</li>
<li>高速上不按规定车道行驶</li>
<li>普通车超载20-50%</li>
</ul>
</li>
<li>六分
<ul>
<li>造成致人轻微伤或者财产损失的交通事故后逃逸</li>
<li>普通车辆高速超速 20%~50%</li>
<li>运载爆炸物品未标识、未按照指定路线行驶</li>
<li>普通汽车超载 50%-100%</li>
<li>驾驶证被扣期间驾车</li>
<li>普通汽车在普通道路超速 50%</li>
<li>载货汽车超重 50% 以上</li>
</ul>
</li>
<li>九分
<ul>
<li>未悬挂机动车号牌或者故意遮挡、污损机动车号牌</li>
<li>驾驶与准驾车型不符的汽车</li>
<li>七座以上汽车超载50%-100%</li>
</ul>
</li>
<li>十二分
<ul>
<li>造成致人轻伤以上或者死亡的交通事故后逃逸</li>
<li>普通车高速超速 50%</li>
<li>普通车超载100%以上</li>
<li>校车超载20%以上</li>
<li>中型以上客车高速超速20%</li>
</ul>
</li>
</ul>
<h2 id="罚款">罚款</h2>
<ul>
<li>考试过程作弊（未取得），<strong>二千元以下</strong>罚款，<strong>一年</strong>内不得再考。</li>
<li>逾期不参加审验仍然驾驶机动车，罚 200-500。</li>
<li>Driving with an expired driver's license: fine of 200-1000 yuan</li>
<li>酒后构成重大事故犯罪，吊销驾驶证，终生不得再申请</li>
<li>不按规定停放150</li>
<li>与准驾车型不符200-1000</li>
<li>记分满12分仍驾驶200-1000</li>
</ul>
<h2 id="刑事责任">刑事责任</h2>
<ul>
<li>重伤、死亡：三年以下</li>
<li>死亡且逃逸：三年以上，七年以下</li>
<li>因逃逸而死亡：七年以上</li>
<li>追逐竞驶：拘役和罚金</li>
</ul>
<h1 id="车辆速度与距离">车辆速度与距离</h1>
<h2 id="车辆速度">车辆速度</h2>
<ul>
<li>公路无中心线 40，有中心线 70，双车道为90</li>
<li>城市无中心线 30，有中心线 50，双车道为60</li>
<li>单位院内20</li>
<li>限速 30：通过铁路口、急弯路、掉头、下坡、能见度 50 米、冰雪泥泞、牵引其他车辆</li>
<li>两道高速：左侧 100-120，右侧 60-120</li>
<li>三道高速：左侧 110-120，中间 90-120，右侧 60-120</li>
<li>能见度小于200m：50速度，100距离</li>
<li>能见度小于100米：40速度，50距离</li>
<li>能见度小于50米：20速度，驶出高速路</li>
</ul>
<h2 id="车辆距离">车辆距离</h2>
<ul>
<li>高速路车速大于等于 100，安全距离 100m+</li>
<li>高速路车速小于 100，安全距离 50m+</li>
</ul>
<h2 id="停车距离">停车距离</h2>
<ul>
<li>50米：交叉路口、铁路口、急弯路、窄路、隧道、陡坡、桥梁</li>
<li>30米：公交站、急救站、加油站、消防栓</li>
</ul>
<h1 id="灯光使用">灯光使用</h1>
<ul>
<li>夜晚会车150m外将远光灯改为近光灯</li>
</ul>
<h1 id="其他">其他</h1>
<ul>
<li>收到事故认定书10日内提出书面调解申请</li>
<li>现场未报警，事后要求处理，应当在10日内提供证据</li>
<li>自适应巡航：Adaptive Cruise Control，ACC</li>
</ul>
<h1 id="相关标志">相关标志</h1>
<h2 id="警告">警告</h2>
<ul>
<li>
<p>注意潮汐车道<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429163504.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
<li>
<p>注意儿童<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429170305.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
<li>
<p>傍山险路<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429170349.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
<li>
<p>注意分离式道路<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429170427.png?x-oss-process=image/quality,q_90/format,webp"></p>
</li>
</ul>
<h2 id="标志">标志</h2>
<ul>
<li>禁止通行（行人和车辆）<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429170511.png?x-oss-process=image/quality,q_90/format,webp"></li>
<li>交叉路口预告<br>
<img loading="lazy" src="https://pics.zhouxin.space/20240429170642.png?x-oss-process=image/quality,q_90/format,webp"></li>
<li>隧道出口距离<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429170715.png?x-oss-process=image/quality,q_90/format,webp"></li>
<li>高速公路停车区预告<br>
<img loading="lazy" src="https://pics.zhouxin.space/20240429170800.png?x-oss-process=image/quality,q_90/format,webp"></li>
<li>硬路肩允许行驶路段即将结束（C）
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429173557.png?x-oss-process=image/quality,q_90/format,webp"></li>
<li>涵洞
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429175001.png?x-oss-process=image/quality,q_90/format,webp"></li>
</ul>
<h2 id="标线">标线</h2>
<ul>
<li>导向车道线<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429163618.png?x-oss-process=image/quality,q_90/format,webp"></li>
<li>路口导向线<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429170912.png?x-oss-process=image/quality,q_90/format,webp"></li>
<li>中心圈<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429171212.png?x-oss-process=image/quality,q_90/format,webp"></li>
<li>接近障碍物标线<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429171251.png?x-oss-process=image/quality,q_90/format,webp"></li>
<li>立面标记<br>
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429171327.png?x-oss-process=image/quality,q_90/format,webp"></li>
</ul>
<h2 id="仪表盘">仪表盘</h2>
<ul>
<li>冷却液不足<br>
<img loading="lazy" src="https://pics.zhouxin.space/20240429163217.png?x-oss-process=image/quality,q_90/format,webp"></li>
<li>充电电路故障
<img loading="lazy" src="https://pics.zhouxin.space/20240429172151.png?x-oss-process=image/quality,q_90/format,webp"></li>
<li>制动系统出现异常
<img alt="image.png" loading="lazy" src="https://pics.zhouxin.space/20240429174651.png?x-oss-process=image/quality,q_90/format,webp"></li>
</ul>
]]></content:encoded>
    </item>
    <item>
      <title>Effective Cpp 第三版学习笔记</title>
      <link>https://www.zhouxin.space/notes/notes-on-effective-cpp-3rd-ed/</link>
      <pubDate>Wed, 17 Apr 2024 18:23:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/notes-on-effective-cpp-3rd-ed/</guid>
      <description>&lt;h1 id=&#34;前言&#34;&gt;前言&lt;/h1&gt;
&lt;p&gt;本文是我在学习 Scott Meyers 的著作《Effective C++》第三版的笔记，鉴于豆瓣对于本书中文翻译褒贬不一，我直接看的英文原著。PDF 链接：&lt;a href=&#34;https://github.com/GunterMueller/Books-3/blob/master/Effective%20C%2B%2B%203rd%20ed.pdf&#34;&gt;Books-3/Effective C++ 3rd ed.pdf at master · GunterMueller/Books-3 · GitHub&lt;/a&gt;&lt;/p&gt;</description>
      <content:encoded><![CDATA[<h1 id="前言">前言</h1>
<p>本文是我在学习 Scott Meyers 的著作《Effective C++》第三版的笔记，鉴于豆瓣对于本书中文翻译褒贬不一，我直接看的英文原著。PDF 链接：<a href="https://github.com/GunterMueller/Books-3/blob/master/Effective%20C%2B%2B%203rd%20ed.pdf">Books-3/Effective C++ 3rd ed.pdf at master · GunterMueller/Books-3 · GitHub</a></p>
<blockquote>
<p>For Nancy, without whom nothing would be much worth doing.<br>
Wisdom and beauty form a very rare combination.</p>
</blockquote>
<p>作者在扉页留下了这句话，打破了我内心对程序员的一点点小刻板印象，我本以为这类话只会在文青写的书里出现呢。</p>
<h1 id="正文">正文</h1>
<h2 id="introduction">Introduction</h2>
<h3 id="lhs--rhs">lhs &amp; rhs</h3>
<blockquote>
<p>Two of my favorite parameter names, for example, are lhs and rhs. They stand for “left-hand side” and “right-hand side,” respectively.</p>
</blockquote>
<p>lhs 和 rhs 是作者常用的二元运算符的形参命名，分别表示运算符左边和右边的参数。</p>
<h2 id="accustoming-yourself-to-c">Accustoming yourself to C++</h2>
<h3 id="item-1-view-c-as-a-federation-of-languages">Item 1: View C++ as a federation of languages.</h3>
<blockquote>
<p>✦ Rules for effective C++ programming vary, depending on the part of C++ you are using.</p>
</blockquote>
<p>C++ 诞生于 C，但很快成长为一门“联合”的语言。它支持多种编程范式，包括面向过程、面向对象、泛型、函数式和元编程。</p>
<p>GPT 对以上几种编程范式的介绍：</p>
<blockquote>
<ol>
<li>面向过程编程 (Procedural Programming)<br>
面向过程编程是一种编程范式，它侧重于编写一系列计算步骤或过程的方法。这种范式通过函数或过程来操作数据，强调的是程序的动作而非数据的表现形式。这种方式适合于简单的任务，能够让程序员更直接地控制程序流和状态。</li>
</ol>
<p>优点：简单直接，容易理解和实现。<br>
缺点：随着程序规模的扩大，维护和扩展变得困难，难以管理大量的数据和函数。</p>
<ol start="2">
<li>面向对象编程 (Object-Oriented Programming, OOP)<br>
面向对象编程是基于“对象”的概念，其中对象是数据和操作这些数据的函数（称为方法）的集合。OOP 强调将现实世界的实体模型化为程序中的对象，每个对象都有其属性和行为。</li>
</ol>
<p>优点：提高了代码的复用性、封装性和可维护性。<br>
缺点：可能会导致程序结构和实现过于复杂，性能上有时也不如面向过程编程。</p>
<ol start="3">
<li>泛型编程 (Generic Programming)<br>
泛型编程是一种编程风格，通过抽象和重用来增强代码的灵活性和可重用性。它允许程序员编写与数据类型无关的代码。在泛型编程中，数据类型作为参数传递给算法或数据结构。</li>
</ol>
<p>优点：代码复用性高，类型安全。<br>
缺点：可能增加编程的复杂性，理解和设计泛型结构需要更深的抽象思维。</p>
<ol start="4">
<li>函数式编程 (Functional Programming, FP)<br>
函数式编程是一种把计算视为数学函数的求值，并避免使用程序状态及可变数据的编程范式。FP 强调无副作用的函数，支持高阶函数和函数组合，以及懒计算等特性。</li>
</ol>
<p>优点：代码更简洁，更容易推理，便于并行计算。<br>
缺点：与传统的命令式编程思维差异大，学习曲线较陡峭。</p>
<ol start="5">
<li>元编程 (Metaprogramming)<br>
元编程是一种编程技术，允许程序在运行时或编译时改变其自身结构。这通常涉及到代码生成代码的情况。</li>
</ol>
<p>优点：极大的灵活性和动态性。<br>
缺点：可能导致代码难以理解、维护和调试。</p>
<ol start="6">
<li>模板元编程 (Template Metaprogramming)<br>
模板元编程是一种在编译时执行计算的技术，常见于 C++ 的模板机制中。它通过模板实例化时的类型推导来执行算法，从而在编译时生成高度优化的代码。</li>
</ol>
<p>优点：生成高度优化的代码，无运行时开销。</p>
</blockquote>
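<p><strong>C++ can be viewed as a federation of four sublanguages, each of which is fairly simple on its own</strong>:</p>
<ul>
<li>C. C++ is largely compatible with C syntax, so most tasks that can be done in C can also be done in C++; thanks to the three other sublanguages that C lacks, C++ often handles them more comfortably.</li>
<li>Object-Oriented C++. This is the familiar “C with Classes”, that is, the object-oriented features that C++ adds.</li>
<li>Template C++. This is the generic programming part of C++, which also gave rise to a new programming paradigm: template metaprogramming.</li>
<li>STL. The STL has its own self-consistent conventions for containers, iterators, algorithms, and function objects; code that uses the STL needs to follow those conventions.</li>
</ul>
<p>Different sublanguages may call for different rules. For example, passing built-in C types by value is usually more efficient than passing them by reference, whereas for objects the opposite holds; and because STL iterators behave like C pointers, pass-by-value is again the right choice for them.</p>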
<h3 id="item-2-prefer-consts-enums-and-inlines-to-defines">Item 2: Prefer consts, enums, and inlines to defines.</h3>
<blockquote>
<p>✦ For simple constants, prefer const objects or enums to <code>#defines</code>.</p>
<p>✦ For function-like macros, prefer inline functions to <code>#defines</code>.</p>
</blockquote>
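<p>This item can be summarized as: prefer the compiler to the preprocessor.</p>
<p>One reason is that the compiler may never see constant names that were replaced during preprocessing, so those names do not enter the symbol table. If such a constant triggers an error or warning, the message shows the constant's value rather than the name used in the code, which makes the message harder to read.</p>
<p>A second reason is that a constant defined with the <code>const</code> keyword respects scope, whereas a name introduced by the <code>#define</code> preprocessor directive does not.</p>
<p>When replacing <code>#define</code> with constants, there are a few points to note:</p>
<ul>
<li>If you need a pointer to a constant, in most cases the pointer itself should not be re-pointed either, so it is a constant pointer to constant data and needs two <code>const</code> keywords: <code>const char* const name = &quot;Name&quot;</code>.</li>
<li>When the constant is a class member, it should also be declared static so that there is only one copy of it in memory rather than one per object, for example:</li>
</ul>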
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="hl"><span class="lnt">3
</span></span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">GamePlayer</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">private</span><span class="o">:</span>
</span></span><span class="line hl"><span class="cl">	<span class="k">static</span> <span class="k">const</span>  <span class="kt">int</span> <span class="n">NumTurns</span> <span class="o">=</span> <span class="mi">5</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="kt">int</span> <span class="n">scores</span><span class="p">[</span><span class="n">NumTurns</span><span class="p">];</span> 
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>部分很老的编译器可能不允许在类声明中定义静态变量的值，更加通用的做法是在类实现的文件中给出静态成员的值。但有例外：即编译器在编译这个类时就需要知道其静态变量的值，例如上述代码中，编译器需要知道 <code>scores</code> 数组的长度，因此要么在声明时就给出静态变量的值，要么使用曲线救国的方案：</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">GamePlayer</span><span class="p">{</span> 
</span></span><span class="line"><span class="cl">	<span class="k">private</span><span class="o">:</span> 
</span></span><span class="line"><span class="cl">		<span class="k">enum</span> <span class="p">{</span> <span class="n">NumTurns</span> <span class="o">=</span> <span class="mi">5</span> <span class="p">};</span>
</span></span><span class="line"><span class="cl">		<span class="kt">int</span> <span class="n">scores</span><span class="p">[</span><span class="n">NumTurns</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">		<span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>上述方案被称为“the enum hack”，了解它的价值在于：</p>
<ul>
<li>the enum hack 相比 <code>const</code> 更像传统的 <code>#define</code>，其不能取地址。</li>
<li>出于实践的考虑：确实有很多代码使用了这个技巧</li>
</ul>
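<p>Another reason to prefer the compiler to the preprocessor: people use macros to get function-like behavior without the overhead of a function call, but such macros perform no type checking, and every argument in the macro body has to be wrapped in parentheses. C++ provides the <code>inline</code> keyword for a similar effect: the compiler can expand an inline function at the call site, avoiding the call overhead, while the function still has ordinary function syntax and type checking.</p>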
<h3 id="item-3-use-const-whenever-possible">Item 3: Use const whenever possible.</h3>
<blockquote>
<p>✦ Declaring something const helps compilers detect usage errors. const can be applied to objects at any scope, to function parameters and return types, and to member functions as a whole.</p>
<p>✦ Compilers enforce bitwise constness, but you should program using logical constness.</p>
<p>✦ When const and non-const member functions have essentially identical implementations, code duplication can be avoided by having the non-const version call the const version.</p>
</blockquote>
<p>尽可能使用 <code>const</code> 关键字，它可以让编译器帮助防止变量被调用者或者其他代码修改。</p>
<p>当 <code>const</code> 关键字和指针相遇，有多种情况：</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">char</span> <span class="n">greeting</span><span class="p">[]</span> <span class="o">=</span> <span class="s">&#34;Hello&#34;</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">char</span> <span class="o">*</span><span class="n">p</span> <span class="o">=</span> <span class="n">greeting</span><span class="p">;</span> <span class="c1">// non-const pointer, non-const data
</span></span></span><span class="line"><span class="cl"><span class="k">const</span> <span class="kt">char</span> <span class="o">*</span><span class="n">p</span> <span class="o">=</span> <span class="n">greeting</span><span class="p">;</span> <span class="c1">// non-const pointer, const data
</span></span></span><span class="line"><span class="cl"><span class="kt">char</span> <span class="o">*</span> <span class="k">const</span> <span class="n">p</span> <span class="o">=</span> <span class="n">greeting</span><span class="p">;</span> <span class="c1">// const pointer, non-const data
</span></span></span><span class="line"><span class="cl"><span class="k">const</span> <span class="kt">char</span> <span class="o">*</span> <span class="k">const</span> <span class="n">p</span> <span class="o">=</span> <span class="n">greeting</span><span class="p">;</span> <span class="c1">// const pointer, const data
</span></span></span></code></pre></td></tr></table>
</div>
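</div><p>The rule can be summarized as follows: if <code>const</code> appears to the left of <code>*</code>, the pointed-to data is constant; if <code>const</code> appears to the right of <code>*</code>, the pointer itself is constant.</p>
<p>When <code>const</code> is to the left of <code>*</code>, it can appear either before or after the type name; the two forms are exactly equivalent:</p>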
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">const</span> <span class="kt">int</span> <span class="n">a</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">int</span> <span class="k">const</span> <span class="n">b</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="c1">// a and b are both ints that cannot be modified
</span></span></span></code></pre></td></tr></table>
</div>
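</div><p>Declaring an STL iterator <code>const</code> means that the iterator itself cannot be modified (it cannot be made to point elsewhere), not that it points to unmodifiable data. For an iterator that points to unmodifiable data, use the <code>const_iterator</code> type.</p>
<p>In a function declaration, <code>const</code> can qualify the return type, the parameter types, and (for member functions only) the function as a whole.</p>
<p>There is usually no reason to declare a return value <code>const</code>, but doing so can sometimes reduce client errors. For example, suppose a rational-number class <code>Rational</code> overloads <code>operator*</code> for multiplication. If the return value is not declared <code>const</code>, the following code is syntactically valid but meaningless:</p>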
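<p>A short sketch of the distinction, using <code>std::vector&lt;int&gt;</code> as an assumed container. The commented-out lines are the ones a compiler would reject:</p>

```cpp
#include <cassert>
#include <vector>

inline void iterator_constness_example() {
    std::vector<int> vec{1, 2, 3};

    // const iterator: acts like `T* const`. It cannot move, but the element can change.
    const std::vector<int>::iterator it = vec.begin();
    *it = 10;       // OK: modifies the element
    // ++it;        // error: the iterator itself is const

    // const_iterator: acts like `const T*`. It can move, but the element is read-only.
    std::vector<int>::const_iterator cit = vec.cbegin();
    ++cit;          // OK: moves the iterator
    // *cit = 20;   // error: cannot write through a const_iterator

    assert(vec[0] == 10);
    assert(*cit == 2);
}
```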
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">Rational</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">(</span><span class="n">a</span><span class="o">*</span><span class="n">b</span><span class="p">)</span> <span class="o">=</span> <span class="n">c</span><span class="p">;</span> <span class="c1">// assigns c to the temporary (a*b)
</span></span></span><span class="line"><span class="cl"><span class="k">if</span><span class="p">(</span><span class="n">a</span><span class="o">*</span><span class="n">b</span> <span class="o">=</span> <span class="n">c</span><span class="p">);</span> <span class="c1">// typo: one = is missing, == was intended
</span></span></span></code></pre></td></tr></table>
</div>
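</div><p>Declaring member functions <code>const</code> makes a class easier to use correctly. First, it tells callers which methods modify the object and which do not. Second, an object passed by reference-to-const can only have its const member functions called. In addition, two member functions whose signatures differ only in <code>const</code> are treated as overloads in C++.</p>
<p>There are two schools of thought about what <code>const</code> means for a member function:</p>
<ul>
<li>bitwise constness: a const member function must not modify any of the object's data members. This definition is strict and easy for compilers to enforce, and it is the one C++ adopts.</li>
<li>logical constness: a const member function may modify the object's data in ways clients cannot detect, for example private members.</li>
</ul>
<p>Logical constness is a reasonable notion too. Suppose we implement a <code>String</code> class with a <code>size()</code> method and cache the length in a private member <code>length</code>. Declaring <code>size()</code> as <code>const</code> is clearly right (otherwise a <code>const String</code> could not report its length), yet the first call to <code>size()</code> inevitably has to update <code>length</code>. That violates bitwise constness, even though it matches the programmer's intuition. In this situation we can declare the member <code>mutable</code>, which allows it to be modified inside const member functions.</p>
<p>As noted above, <code>const</code> can be used to overload member functions, so we may end up with two overloads like these:</p>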
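<p>The caching scenario can be sketched as follows. The class name <code>CachedText</code> and its members are hypothetical stand-ins for the <code>String</code> class described above, not code from the book:</p>

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

class CachedText {
public:
    explicit CachedText(const char* text)
        : text_(text), length_(0), lengthValid_(false) {}

    // Logically const: callers cannot observe the cache being filled.
    std::size_t size() const {
        if (!lengthValid_) {
            length_ = std::strlen(text_);  // allowed only because length_ is mutable
            lengthValid_ = true;
        }
        return length_;
    }

private:
    const char* text_;
    mutable std::size_t length_;   // may be modified inside const member functions
    mutable bool lengthValid_;
};

inline void mutable_cache_example() {
    const CachedText greeting("Hello");
    assert(greeting.size() == 5);  // first call computes and caches the length
    assert(greeting.size() == 5);  // second call returns the cached value
}
```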
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Vector</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span><span class="line"><span class="cl">	<span class="k">const</span> <span class="n">Element</span><span class="o">&amp;</span> <span class="k">operator</span> <span class="p">[](</span><span class="n">size_t</span> <span class="n">index</span><span class="p">)</span> <span class="k">const</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span><span class="c1">// bounds checking, validation, etc.
</span></span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">data</span><span class="p">[</span><span class="n">index</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl">	
</span></span><span class="line"><span class="cl">	<span class="n">Element</span><span class="o">&amp;</span> <span class="k">operator</span> <span class="p">[](</span><span class="n">size_t</span> <span class="n">index</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span><span class="c1">// bounds checking, validation, etc.
</span></span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">data</span><span class="p">[</span><span class="n">index</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
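</div><p>The const and non-const versions of the <code>[]</code> subscript operator have identical implementations, but so that const objects get a non-modifiable reference and non-const objects get a modifiable one, we are forced to write the code twice.</p>
<p>To avoid this pointless duplication, the non-const version can call the const version and then use <code>const_cast</code> to strip the <code>const</code> from the returned reference:</p>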
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Vector</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span><span class="line"><span class="cl">	<span class="k">const</span> <span class="n">Element</span><span class="o">&amp;</span> <span class="k">operator</span> <span class="p">[](</span><span class="n">size_t</span> <span class="n">index</span><span class="p">)</span> <span class="k">const</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span><span class="c1">// bounds checking, validation, etc.
</span></span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">data</span><span class="p">[</span><span class="n">index</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl">	
</span></span><span class="line"><span class="cl">	<span class="n">Element</span><span class="o">&amp;</span> <span class="k">operator</span> <span class="p">[](</span><span class="n">size_t</span> <span class="n">index</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="k">const_cast</span><span class="o">&lt;</span><span class="n">Element</span><span class="o">&amp;&gt;</span><span class="p">(</span> <span class="c1">// cast const Element&amp; to Element&amp;
</span></span></span><span class="line"><span class="cl">		<span class="k">static_cast</span><span class="o">&lt;</span><span class="k">const</span> <span class="n">Vector</span><span class="o">&amp;&gt;</span><span class="p">(</span><span class="o">*</span><span class="k">this</span><span class="p">)[</span><span class="n">index</span><span class="p">]);</span> <span class="c1">// cast *this to const so that the const overload is called
</span></span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h3 id="item-4-make-sure-that-objects-are-initialized-before-theyre-used">Item 4: Make sure that objects are initialized before they’re used.</h3>
<blockquote>
<p>✦ Manually initialize objects of built-in type, because C++ only sometimes initializes them itself.</p>
<p>✦ In a constructor, prefer use of the member initialization list to assignment inside the body of the constructor. List data members in the initialization list in the same order they’re declared in the class.</p>
<p>✦ Avoid initialization order problems across translation units by replacing non-local static objects with local static objects.</p>
</blockquote>
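<p>In C++, a complicated set of rules determines whether a variable is initialized for you when you define it. Reading an uninitialized variable, however, is undefined behavior that can crash the program or lead to painful debugging. The best approach is to always initialize a variable when you define it.</p>
<p>Non-member objects of built-in types must be initialized manually:</p>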
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">int</span> <span class="n">x</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">text</span> <span class="o">=</span> <span class="s">&#34;Hello World!&#34;</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">double</span> <span class="n">d</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="n">cin</span> <span class="o">&gt;&gt;</span> <span class="n">d</span><span class="p">;</span>
</span></span></code></pre></td></tr></table>
</div>
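</div><p>In almost every other case, initialization is the job of constructors. The rule is simple: every constructor should initialize every data member of the object.</p>
<p>Take care to distinguish initialization from assignment inside a constructor:</p>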
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="hl"><span class="lnt">12
</span></span><span class="hl"><span class="lnt">13
</span></span><span class="hl"><span class="lnt">14
</span></span><span class="hl"><span class="lnt">15
</span></span><span class="lnt">16
</span><span class="lnt">17
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">PhoneNumber</span> <span class="p">{</span> <span class="p">...</span> <span class="p">};</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">ABEntry</span> <span class="p">{</span> <span class="c1">// ABEntry = “Address Book Entry” 
</span></span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span> 
</span></span><span class="line"><span class="cl">	<span class="n">ABEntry</span><span class="p">(</span><span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="o">&amp;</span> <span class="n">name</span><span class="p">,</span> <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="o">&amp;</span> <span class="n">address</span><span class="p">,</span> <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">list</span><span class="o">&lt;</span><span class="n">PhoneNumber</span><span class="o">&gt;&amp;</span> <span class="n">phones</span><span class="p">);</span> 
</span></span><span class="line"><span class="cl"><span class="k">private</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">theName</span><span class="p">;</span> 
</span></span><span class="line"><span class="cl">	<span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">theAddress</span><span class="p">;</span> 
</span></span><span class="line"><span class="cl">	<span class="n">std</span><span class="o">::</span><span class="n">list</span><span class="o">&lt;</span><span class="n">PhoneNumber</span><span class="o">&gt;</span> <span class="n">thePhones</span><span class="p">;</span> 
</span></span><span class="line"><span class="cl">	<span class="kt">int</span> <span class="n">numTimesConsulted</span><span class="p">;</span> 
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl"><span class="n">ABEntry</span><span class="o">::</span><span class="n">ABEntry</span><span class="p">(</span><span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="o">&amp;</span> <span class="n">name</span><span class="p">,</span> <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="o">&amp;</span> <span class="n">address</span><span class="p">,</span> <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">list</span><span class="o">&lt;</span><span class="n">PhoneNumber</span><span class="o">&gt;&amp;</span> <span class="n">phones</span><span class="p">){</span> 
</span></span><span class="line hl"><span class="cl">	<span class="n">theName</span> <span class="o">=</span> <span class="n">name</span><span class="p">;</span>
</span></span><span class="line hl"><span class="cl">	<span class="n">theAddress</span> <span class="o">=</span> <span class="n">address</span><span class="p">;</span>
</span></span><span class="line hl"><span class="cl">	<span class="n">thePhones</span> <span class="o">=</span> <span class="n">phones</span><span class="p">;</span>
</span></span><span class="line hl"><span class="cl">	<span class="n">numTimesConsulted</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="c1">// these are all assignments, not initializations
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
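</div><p>Class members are initialized before the constructor body runs, so explicit initialization has to go in the member initialization list that precedes the body. The assignment-based version above first default-constructs each class-type member implicitly and then calls its copy assignment operator; with the initialization list, each member is copy-constructed directly, which saves the cost of the default construction. Members of built-in types, moreover, are not guaranteed to be initialized by default, so they must be given a value explicitly, in the initialization list or in the constructor body.</p>
<p>The order of initialization is fixed: base classes are initialized before derived classes, and data members are initialized in the order in which they are declared. Even if the initialization list names them in a different order, members are still initialized in declaration order.</p>
<p>Next comes the initialization of static objects. Static objects include global objects, objects defined at namespace scope, and objects declared <code>static</code> inside classes, functions, or at file scope. Static objects inside functions are called local static objects; the rest are non-local static objects. All static objects are destroyed when the program exits.</p>
<p>A <strong>translation unit</strong> is the source code that gives rise to a single object file: one source file plus all the headers it includes.</p>
<p>The author then gives an example that boils down to this: the initialization of a non-local static object in translation unit A uses a non-local static object from another translation unit B, but the relative initialization order of the two is not defined, so B's object may not yet be initialized when A's object needs it. To solve this, we can borrow the Singleton pattern: in B, define a global function (or a static member function) that defines a local static object and returns a reference to it.</p>
<p>As described, however, this solution is not safe in multithreaded environments: several threads may try to initialize the same local static object at the same time. One remedy is to call every such reference-returning function manually before any threads are started. (Since C++11, the initialization of local static objects is guaranteed to be thread-safe, so this caveat applies to older compilers.)</p>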
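<p>For comparison, here is a sketch of the same <code>ABEntry</code> constructor written with a member initialization list, with members listed in declaration order. The accessor functions and the empty <code>PhoneNumber</code> placeholder are additions for illustration only:</p>

```cpp
#include <cassert>
#include <list>
#include <string>

class PhoneNumber {};  // placeholder; the real class is elided in the notes

class ABEntry {
public:
    ABEntry(const std::string& name, const std::string& address,
            const std::list<PhoneNumber>& phones)
        : theName(name),         // copy-constructed directly, no default construction first
          theAddress(address),
          thePhones(phones),
          numTimesConsulted(0)   // built-in type: must be initialized explicitly
    {}

    const std::string& name() const { return theName; }       // illustrative accessor
    int timesConsulted() const { return numTimesConsulted; }  // illustrative accessor

private:
    // Members are initialized in this declaration order,
    // regardless of the order used in the initialization list.
    std::string theName;
    std::string theAddress;
    std::list<PhoneNumber> thePhones;
    int numTimesConsulted;
};

inline void init_list_example() {
    ABEntry entry("Ann", "1 Main St", std::list<PhoneNumber>());
    assert(entry.name() == "Ann");
    assert(entry.timesConsulted() == 0);
}
```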
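<p>The reference-returning function technique might look like the sketch below. It is collapsed into one file for brevity, whereas the real problem involves two translation units; the class names follow the book's file-system example as best recalled, and the return value of <code>numDisks()</code> is a made-up placeholder:</p>

```cpp
#include <cassert>
#include <cstddef>

class FileSystem {
public:
    std::size_t numDisks() const { return 1; }  // placeholder value for illustration
};

// Replaces a non-local static object such as `FileSystem tfs;`.
// The local static is initialized the first time control reaches its definition,
// so callers can never observe it uninitialized.
inline FileSystem& tfs() {
    static FileSystem fs;
    return fs;
}

class Directory {
public:
    Directory() : disks_(tfs().numDisks()) {}  // safe: tfs() initializes on first use
    std::size_t disks() const { return disks_; }
private:
    std::size_t disks_;
};

inline Directory& tempDir() {
    static Directory td;
    return td;
}

inline void local_static_example() {
    assert(&tfs() == &tfs());        // every call returns the same object
    assert(tempDir().disks() == 1);
}
```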
<h2 id="constructors-destructors-and-assignment-operators">Constructors, Destructors, and Assignment Operators</h2>
<h3 id="item-5-know-what-functions-c-silently-writes-and-calls">Item 5: Know what functions C++ silently writes and calls.</h3>
<blockquote>
<p>✦ Compilers may implicitly generate a class’s default constructor, copy constructor, copy assignment operator, and destructor.</p>
</blockquote>
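<p>By default, the compiler generates a public, inline default constructor, destructor, copy constructor, and copy assignment operator <strong>when they are needed</strong>. Every function the compiler generates for a class is non-virtual, with one exception: if a base class of the class declares a virtual destructor, the generated destructor is virtual as well. Otherwise derived objects could not be destroyed correctly through a base class pointer or reference.</p>
<p>The generated copy constructor copies every non-static data member. The copy assignment operator works in a similar way, but not every member can be assigned: for example, reference members, const members, or a base class whose copy assignment operator is private. In those cases the compiler refuses to generate the copy assignment operator.</p>
<h3 id="item-6-explicitly-disallow-the-use-of-compiler--generated-functions-you-do-not-want">Item 6: Explicitly disallow the use of compiler-generated functions you do not want.</h3>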
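<p>One way to observe this, assuming a C++11 compiler, is to query the type traits in <code>&lt;type_traits&gt;</code>. The <code>NamedObject</code> template below is a simplified sketch with one reference member and one const member:</p>

```cpp
#include <string>
#include <type_traits>

template <typename T>
class NamedObject {
public:
    NamedObject(std::string& name, const T& value)
        : nameValue(name), objectValue(value) {}
private:
    std::string& nameValue;  // reference member: cannot be reseated by assignment
    const T objectValue;     // const member: cannot be assigned to
};

// Copy construction is still generated, because every member can be copy-initialized.
static_assert(std::is_copy_constructible<NamedObject<int>>::value,
              "the copy constructor is generated");

// Copy assignment is not generated, because neither member can be assigned.
static_assert(!std::is_copy_assignable<NamedObject<int>>::value,
              "the copy assignment operator is not generated");
```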
<blockquote>
<p>✦ To disallow functionality automatically provided by compilers, declare the corresponding member functions private and give no implementations. Using a base class like Uncopyable is one way to do this.</p>
</blockquote>
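<p>Some classes should not allow two identical objects to exist, yet (before C++11) the language offers no keyword to suppress the generated copy constructor and copy assignment operator. One possible approach is to declare both as private, which stops clients from calling them, and to leave them unimplemented, which stops friends and member functions from calling them too.</p>
<p>Calling a function that is declared but never defined fails at link time. One way to move the error to compile time is to define a class <code>Uncopyable</code> that represents non-copyability and derive other classes from it:</p>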
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Uncopyable</span> <span class="p">{</span> 
</span></span><span class="line"><span class="cl"><span class="k">protected</span><span class="o">:</span> <span class="c1">// allow construction and destruction of derived objects... 
</span></span></span><span class="line"><span class="cl">	<span class="n">Uncopyable</span><span class="p">()</span> <span class="p">{}</span>
</span></span><span class="line"><span class="cl">	<span class="o">~</span><span class="n">Uncopyable</span><span class="p">()</span> <span class="p">{}</span>
</span></span><span class="line"><span class="cl"><span class="k">private</span><span class="o">:</span> <span class="c1">// ...but prevent copying 
</span></span></span><span class="line"><span class="cl">	<span class="n">Uncopyable</span><span class="p">(</span><span class="k">const</span> <span class="n">Uncopyable</span><span class="o">&amp;</span><span class="p">);</span> 
</span></span><span class="line"><span class="cl">	<span class="n">Uncopyable</span><span class="o">&amp;</span> <span class="k">operator</span><span class="o">=</span><span class="p">(</span><span class="k">const</span> <span class="n">Uncopyable</span><span class="o">&amp;</span><span class="p">);</span> 
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">UncopyableThing</span><span class="o">:</span> <span class="k">private</span> <span class="n">Uncopyable</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
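</div><p><code>UncopyableThing</code> contains no <code>Uncopyable</code> data member, but the technique works because when the compiler tries to generate the copying functions for it, they have to call the base class's copying functions, which are private and therefore inaccessible.</p>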
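<p>The book predates C++11. With a C++11 compiler, the same intent can be stated directly with <code>= delete</code>, which is an alternative to the technique above rather than what the book describes. A sketch, using a hypothetical class <code>HomeForSale</code>:</p>

```cpp
#include <type_traits>

class HomeForSale {
public:
    HomeForSale() = default;

    // Deleted copying functions: any attempt to copy is rejected at compile time,
    // including attempts made by members and friends.
    HomeForSale(const HomeForSale&) = delete;
    HomeForSale& operator=(const HomeForSale&) = delete;
};

static_assert(!std::is_copy_constructible<HomeForSale>::value,
              "copy construction is disabled");
static_assert(!std::is_copy_assignable<HomeForSale>::value,
              "copy assignment is disabled");
```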
<h3 id="item-7-declare-destructors-virtual-in-polymorphic-base-classes">Item 7: Declare destructors virtual in polymorphic base classes.</h3>
<blockquote>
<p>✦ Polymorphic base classes should declare virtual destructors. If a class has any virtual functions, it should have a virtual destructor.</p>
<p>✦ Classes not designed to be base classes or not designed to be used polymorphically should not declare virtual destructors.</p>
</blockquote>
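<p>If a derived object is deleted through a base class pointer and the base class has no virtual destructor, the behavior is undefined. In practice the object is typically only partially destroyed: the base class part is released while the derived part leaks.</p>
<p>The fix is simple: declare the base class destructor virtual. A class with virtual functions is almost certainly meant to be a base class, since those functions are intended to be overridden in derived classes, so its destructor should be virtual.</p>
<p>Declaring a virtual destructor in a class without virtual functions is a bad idea. Virtual functions cost extra memory (each object carries a virtual table pointer that points to the virtual function table), so an object that used to fit exactly in a register may double in size (a pointer is usually as wide as the machine word), and the class also loses layout compatibility with C.</p>
<p>Note that none of the STL containers has a virtual destructor, so do not use them as base classes. (C++11 introduced the <code>final</code> keyword, which lets a class forbid derivation.)</p>
<p>If the destructor is declared pure virtual, which makes the class abstract, the base class must still provide a definition for it, because the base class destructor is called after the derived class destructor finishes.</p>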
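<p>A small sketch of both points: deleting through a base pointer reaches the derived destructor, and a pure virtual destructor still needs a definition. The class names are hypothetical:</p>

```cpp
#include <cassert>

class AbstractWidget {
public:
    virtual ~AbstractWidget() = 0;  // pure virtual: makes the class abstract
};

// The definition is still required: derived destructors call it when they finish.
inline AbstractWidget::~AbstractWidget() {}

class TrackedWidget : public AbstractWidget {
public:
    explicit TrackedWidget(bool& destroyed) : destroyed_(destroyed) {}
    ~TrackedWidget() override { destroyed_ = true; }
private:
    bool& destroyed_;  // records that the derived destructor ran
};

inline void virtual_destructor_example() {
    bool destroyed = false;
    AbstractWidget* widget = new TrackedWidget(destroyed);
    delete widget;  // virtual dispatch runs ~TrackedWidget, then ~AbstractWidget
    assert(destroyed);
}
```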
<h3 id="item-8-prevent-exceptions-from-leaving-destructors">Item 8: Prevent exceptions from leaving destructors.</h3>
<blockquote>
<p>✦ Destructors should never emit exceptions. If functions called in a destructor may throw, the destructor should catch any exceptions, then swallow them or terminate the program.</p>
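<p>✦ If class clients need to be able to react to exceptions thrown during an operation, the class should provide a regular (i.e., non-destructor) function that performs the operation.</p>
</blockquote>
<p>Destructors should not emit exceptions. Otherwise, when an array of objects is destroyed and the destructors are called one after another, two exceptions can be active at the same time, at which point the program either terminates or exhibits undefined behavior.</p>
<p>Often, though, the code a destructor runs (releasing resources and so on) really can throw. If the destructor catches the exception, it has two options:</p>
<ul>
<li>Log the failure and terminate the program with <code>std::abort()</code>. This keeps the program from running into undefined behavior.</li>
<li>Log the failure and keep running. This may leave the program in an abnormal state, since an operation has failed.</li>
</ul>
<p>Neither option lets clients react to the exception. A better design exposes an explicit resource-release function so that clients can release the resource themselves and handle any exception. The destructor can still clean up after clients who forget to call it, but an exception raised there cannot be forwarded to the client and has to be handled with one of the two options above.</p>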
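<p>This design can be sketched as follows. <code>DBConnection</code> is a hypothetical resource whose <code>close()</code> can be made to fail, and the destructor here takes the "swallow" option; a real program would also log the failure:</p>

```cpp
#include <cassert>
#include <stdexcept>

class DBConnection {  // hypothetical resource whose close() may fail
public:
    explicit DBConnection(bool failOnClose) : failOnClose_(failOnClose) {}
    void close() {
        if (failOnClose_) throw std::runtime_error("close failed");
    }
private:
    bool failOnClose_;
};

class DBConn {  // RAII manager
public:
    explicit DBConn(DBConnection db) : db_(db), closed_(false) {}

    // Regular function: clients who care about failures call this and handle exceptions.
    void close() {
        db_.close();
        closed_ = true;
    }

    // Backup path: close if the client did not, but never let an exception escape.
    ~DBConn() {
        if (!closed_) {
            try {
                db_.close();
            } catch (...) {
                // swallow the exception here (alternatively, call std::abort())
            }
        }
    }

private:
    DBConnection db_;
    bool closed_;
};

inline void dbconn_example() {
    bool clientSawError = false;
    try {
        DBConn conn{DBConnection(true)};
        conn.close();  // throws: the client gets a chance to react
    } catch (const std::runtime_error&) {
        clientSawError = true;
    }
    assert(clientSawError);

    {
        DBConn forgotten{DBConnection(true)};
    }  // the destructor closes the connection and swallows the failure
}
```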
<h3 id="item-9-never-call-virtual-functions-during-construction-or-destruction">Item 9: Never call virtual functions during construction or destruction.</h3>
<blockquote>
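<p>✦ Don’t call virtual functions during construction or destruction, because such calls will never go to a more derived class than that of the currently executing constructor or destructor.</p>
</blockquote>
<p>Do not call virtual functions in constructors or destructors. Suppose there is a transaction base class <code>Transaction</code> with a pure virtual member function <code>Transaction::log</code>. The base class constructor calls this logging function, and concrete transaction classes are derived from it. Creating a concrete transaction object runs the base class constructor, which in turn calls the logging function. The function that gets called, however, is <strong>not</strong> the derived class's version but <code>Transaction::log</code>. The reason is that the derived class constructor has not run yet and the derived members are uninitialized, so if the virtual call were bound to the derived class, any use of derived members would be undefined behavior.</p>
<p>In fact, while a base class constructor is running during the construction of a derived object, the object's type is the base class, not the derived class, and that is what runtime type information (for example <code>typeid</code> or <code>dynamic_cast</code>) reports.</p>
<p>Destruction works the same way: once the base class destructor is entered, the object is treated as a base class object rather than a derived one.</p>
<p>So how can this requirement be met? Make <code>log</code> a non-virtual function that takes the log information as a <code>string</code>; have the base class constructor take that information as a parameter and call <code>log</code> explicitly; and have each derived class constructor pass the log information to the base class constructor. The information can be produced by a private static function of the derived class; it must be static so that it cannot touch non-static members, which are still uninitialized at that point.</p>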
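<p>The dispatch rule is easy to observe with a non-pure virtual function. In this sketch (hypothetical class names), the base constructor records which override it reached:</p>

```cpp
#include <cassert>
#include <string>

class Base {
public:
    Base() { nameSeenInCtor_ = name(); }  // virtual call made during construction
    virtual ~Base() {}
    virtual std::string name() const { return "Base"; }
    const std::string& nameSeenInCtor() const { return nameSeenInCtor_; }
private:
    std::string nameSeenInCtor_;
};

class Derived : public Base {
public:
    std::string name() const override { return "Derived"; }
};

inline void virtual_in_ctor_example() {
    Derived d;
    assert(d.nameSeenInCtor() == "Base");  // the call inside Base() resolved to Base::name
    assert(d.name() == "Derived");         // after construction, dispatch works as usual
}
```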
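<p>A sketch of that workaround, under the assumption that "logging" simply stores the message. <code>BuyTransaction</code>, <code>createLogString</code>, and the stored <code>lastLog_</code> member are illustrative names:</p>

```cpp
#include <cassert>
#include <string>

class Transaction {
public:
    // The base constructor receives the log information instead of asking for it virtually.
    explicit Transaction(const std::string& logInfo) { log(logInfo); }

    void log(const std::string& logInfo) { lastLog_ = logInfo; }  // non-virtual
    const std::string& lastLog() const { return lastLog_; }

private:
    std::string lastLog_;
};

class BuyTransaction : public Transaction {
public:
    explicit BuyTransaction(int shares)
        : Transaction(createLogString(shares)) {}  // pass the information up to the base

private:
    // Static, so it cannot touch non-static members that are not initialized yet.
    static std::string createLogString(int shares) {
        return "buy " + std::to_string(shares);
    }
};

inline void transaction_log_example() {
    BuyTransaction t(10);
    assert(t.lastLog() == "buy 10");
}
```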
<h3 id="item-10-have-assignment-operators-return-a-reference-to-this">Item 10: Have assignment operators return a reference to *this.</h3>
<blockquote>
<p>✦ Have assignment operators return a reference to *this.</p>
</blockquote>
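<p>Returning <code>*this</code> from an overloaded assignment operator makes chained assignment possible. This item is straightforward, so no further detail is needed.</p>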
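<p>A minimal illustration with a hypothetical <code>Counter</code> class; the same convention applies to compound assignment operators such as <code>+=</code>:</p>

```cpp
#include <cassert>

class Counter {
public:
    explicit Counter(int value = 0) : value_(value) {}

    Counter& operator=(const Counter& rhs) {
        value_ = rhs.value_;
        return *this;  // enables a = b = c
    }

    Counter& operator+=(const Counter& rhs) {
        value_ += rhs.value_;
        return *this;  // same convention for compound assignment
    }

    int value() const { return value_; }

private:
    int value_;
};

inline void chained_assignment_example() {
    Counter a, b, c(7);
    a = b = c;          // parsed as a = (b = c)
    assert(a.value() == 7);
    assert(b.value() == 7);

    (a += c) += c;      // 7 + 7 + 7
    assert(a.value() == 21);
}
```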
<h3 id="item-11-handle-assignment-to-self-in-operator">Item 11: Handle assignment to self in operator=.</h3>
<blockquote>
<p>✦ Make sure operator= is well-behaved when an object is assigned to itself. Techniques include comparing addresses of source and target objects, careful statement ordering, and copy-and-swap.</p>
<p>✦ Make sure that any function operating on more than one object behaves correctly if two or more of the objects are the same.</p>
</blockquote>
<p>Assigning an object to itself looks silly, but it happens more often than one might expect, for example:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="hl"><span class="lnt">6
</span></span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">vector</span><span class="o">&lt;</span><span class="n">Widget</span><span class="o">&gt;</span> <span class="n">widgets</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">widgets</span><span class="p">.</span><span class="n">size</span><span class="p">();</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="k">for</span><span class="p">(</span><span class="kt">int</span> <span class="n">j</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">j</span><span class="o">&lt;</span><span class="n">widgets</span><span class="p">.</span><span class="n">size</span><span class="p">();</span> <span class="n">j</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line hl"><span class="cl">	<span class="n">widgets</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">widgets</span><span class="p">[</span><span class="n">j</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
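</div><p>When <code>i</code> equals <code>j</code>, the object is assigned to itself.</p>
<p>Self-assignment can cause surprising failures: if the assignment operator first releases its own resource and then copies the resource from the right-hand side, a self-assignment ends up copying a resource that has already been released.</p>
<p>To solve this problem, check at the top of the assignment operator whether the two objects have the same address:</p>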
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="hl"><span class="lnt">2
</span></span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">Widget</span><span class="o">&amp;</span> <span class="n">Widget</span><span class="o">::</span><span class="k">operator</span><span class="o">=</span><span class="p">(</span><span class="k">const</span> <span class="n">Widget</span><span class="o">&amp;</span> <span class="n">rhs</span><span class="p">){</span>
</span></span><span class="line hl"><span class="cl">	<span class="k">if</span><span class="p">(</span><span class="k">this</span> <span class="o">==</span> <span class="o">&amp;</span><span class="n">rhs</span><span class="p">)</span>    <span class="k">return</span> <span class="o">*</span><span class="k">this</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h3 id="item-12-copy-all-parts-of-an-object">Item 12: Copy all parts of an object.</h3>
<blockquote>
<p>✦ Copying functions should be sure to copy all of an object’s data members and all of its base class parts.</p>
<p>✦ Don’t try to implement one of the copying functions in terms of the other. Instead, put common functionality in a third function that both call.</p>
</blockquote>
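<p>If we write a class's copying functions by hand and later add a data member for some reason, we must remember to update the copying functions (and the constructors) as well; the compiler will not issue any warning.</p>
<p>In addition, the copying functions of a derived class must explicitly invoke the corresponding base class copying functions. Otherwise the base class part is default-constructed (in the copy constructor) or left unchanged (in the copy assignment operator).</p>
<p>Having one copying function call the other makes no sense: one initializes a new object, the other assigns to an existing one. Put the shared code in a separate member function and call it from both.</p>
<p>As an aside, the author writes with a good deal of humor throughout the book and likes to personify the compiler. In the passage below, the compiler sulks because you did not do what it said:</p>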
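<p>A sketch of derived-class copying functions that copy the base part explicitly. The class names loosely follow the book's customer example; the members shown are assumptions made for illustration:</p>

```cpp
#include <cassert>
#include <string>

class Customer {
public:
    explicit Customer(const std::string& name) : name_(name) {}
    const std::string& name() const { return name_; }
private:
    std::string name_;
};

class PriorityCustomer : public Customer {
public:
    PriorityCustomer(const std::string& name, int priority)
        : Customer(name), priority_(priority) {}

    PriorityCustomer(const PriorityCustomer& rhs)
        : Customer(rhs),              // copy the base class part
          priority_(rhs.priority_) {}

    PriorityCustomer& operator=(const PriorityCustomer& rhs) {
        Customer::operator=(rhs);     // assign the base class part
        priority_ = rhs.priority_;
        return *this;
    }

    int priority() const { return priority_; }

private:
    int priority_;
};

inline void copy_all_parts_example() {
    PriorityCustomer a("Ann", 3);

    PriorityCustomer b(a);            // copy construction
    assert(b.name() == "Ann");
    assert(b.priority() == 3);

    PriorityCustomer c("Bob", 1);
    c = a;                            // copy assignment
    assert(c.name() == "Ann");
    assert(c.priority() == 3);
}
```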
<blockquote>
<p>When you declare your own copying functions, you are indicating to compilers that there is something about the default implementations you don’t like. Compilers seem to take offense at this, and they retaliate in a curious fashion: they don’t tell you when your implementations are almost certainly wrong.</p>
<p>That’s their revenge for your writing the copying functions yourself. You reject the copying functions they’d write, so they don’t tell you if your code is incomplete</p>
</blockquote>
<h2 id="resource-management">Resource Management</h2>
<blockquote>
<p>A resource is something that, once you’re done using it, you need to return to the system. If you don’t, bad things happen.</p>
</blockquote>
<h3 id="item-13-use-objects-to-manage-resources">Item 13: Use objects to manage resources.</h3>
<blockquote>
<p>✦ To prevent resource leaks, use RAII objects that acquire resources in their constructors and release them in their destructors.</p>
<p>✦ Two commonly useful RAII classes are tr1::shared_ptr and auto_ptr. tr1::shared_ptr is usually the better choice, because its behavior when copied is intuitive. Copying an auto_ptr sets it to null.</p>
</blockquote>
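<p>Suppose we have a class that models a resource, with a factory function that returns a resource object; the caller of that function, <code>f()</code>, is responsible for releasing the object:</p>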
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Investment</span> <span class="p">{</span> <span class="p">...</span> <span class="p">};</span> <span class="c1">// resource class
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">Investment</span><span class="o">*</span> <span class="nf">createInvestment</span><span class="p">();</span> <span class="c1">// factory function
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">f</span><span class="p">()</span> <span class="p">{</span> 
</span></span><span class="line"><span class="cl">	<span class="n">Investment</span> <span class="o">*</span><span class="n">pInv</span> <span class="o">=</span> <span class="n">createInvestment</span><span class="p">();</span> 
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl">	<span class="k">delete</span> <span class="n">pInv</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
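</div><p>Things do not always go as planned, though. While <code>f</code> runs, a return statement, an exception, or something similar may keep control from ever reaching the <code>delete</code> statement, leaking the object's memory and any resources it holds. Relying on manual discipline alone is time-consuming and error-prone.</p>
<p>We can therefore hand the resource to an object: the resource is acquired when the object is created and released when the object is destroyed. This is the RAII idiom. A smart pointer can manage the resource:</p>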
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">f</span><span class="p">()</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="n">std</span><span class="o">::</span><span class="n">unique_ptr</span><span class="o">&lt;</span><span class="n">Investment</span><span class="o">&gt;</span> <span class="n">pInv</span><span class="p">(</span><span class="n">createInvestment</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
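</div><p>The code above illustrates the two key points of using objects to manage resources:</p>
<ul>
<li>As soon as a resource is acquired, hand it over to a managing object.</li>
<li>The managing object relies on its destructor to make sure the resource is released. If releasing the resource can throw, see <a href=".md#item-8-prevent-exceptions-from-leaving-destructors."> &gt; Item 8 Prevent exceptions from leaving destructors.</a></li>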
</ul>
<h3 id="item-14-think-carefully-about-copying-behavior-in-resource-managing-classes">Item 14: Think carefully about copying behavior in resource-managing classes.</h3>
<blockquote>
<p>✦ Copying an RAII object entails copying the resource it manages, so the copying behavior of the resource determines the copying behavior of the RAII object.</p>
<p>✦ Common RAII class copying behaviors are disallowing copying and performing reference counting, but other behaviors are possible.</p>
</blockquote>
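<p>When one resource-managing object is copied from, or constructed from, another, several behaviors are possible:</p>
<ul>
<li>Prohibit copying.</li>
<li>Reference counting, as in <code>shared_ptr</code>. A single resource is managed by several managing objects that share one reference count. One catch: when the count reaches zero, <code>shared_ptr</code> by default deletes the resource (running its destructor), but for a resource such as a mutex the right action is to unlock it. Fortunately, <code>shared_ptr</code> accepts a custom deleter, passed as an extra argument at construction.</li>
<li>Copy the resource. Some resources (memory, for example) can be copied, in which case the managing object can deep-copy them.</li>
<li>Transfer ownership.</li>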
</ul>
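<p>As a minimal sketch of the reference-counting option, assuming a hypothetical C-style <code>Mutex</code> type with free functions <code>lock</code> and <code>unlock</code> (stand-ins for a real threading API, not part of any library), a lock class might pass <code>unlock</code> as the deleter so that the last copy unlocks the mutex instead of deleting it:</p>

```cpp
#include <cassert>
#include <memory>

// Hypothetical C-style mutex API, standing in for a real threading library.
struct Mutex { bool locked = false; };
inline void lock(Mutex* m)   { assert(!m->locked); m->locked = true; }
inline void unlock(Mutex* m) { m->locked = false; }

// RAII lock whose copies share one mutex through reference counting.
class Lock {
public:
    // unlock is the custom deleter: the Mutex is unlocked, never deleted.
    explicit Lock(Mutex* m) : handle_(acquire(m), unlock) {}

private:
    // Lock first, so the deleter only ever runs on a locked mutex.
    static Mutex* acquire(Mutex* m) { lock(m); return m; }

    std::shared_ptr<Mutex> handle_;
};
```

<p>Copying a <code>Lock</code> only bumps the reference count; the mutex is unlocked once, when the last copy goes out of scope.</p>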
<h3 id="item-15-provide-access-to-raw-resources-in-resource-managing-classes">Item 15: Provide access to raw resources in resource-managing classes.</h3>
<blockquote>
<p>✦ APIs often require access to raw resources, so each RAII class should offer a way to get at the resource it manages.</p>
<p>✦ Access may be via explicit conversion or implicit conversion. In general, explicit conversion is safer, but implicit conversion is more convenient for clients.</p>
</blockquote>
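<p>A resource usually comes with a great many APIs, and a managing class cannot wrap them all, so the class must offer an explicit or implicit way to get at the raw resource.</p>
<p>Explicit access means providing a member function that returns the managed resource, or overloading <code>*</code> and <code>-&gt;</code> so that the resource can be reached directly through these operators.</p>
<p>Implicit access means providing a conversion function so that the managing object converts implicitly to the raw resource. Clients can then pass the managing object straight to the resource's APIs as if it were the resource itself, but implicit conversions also invite subtle, hard-to-spot bugs.</p>
<p>Some may feel that direct access to the resource breaks the managing class's encapsulation. I think the author answers this well: <strong>not every class exists to encapsulate; a resource-managing class exists to manage the acquisition and release of a resource</strong>.</p>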
<blockquote>
<p>RAII classes don’t exist to encapsulate something; they exist to ensure that a particular action — resource release — takes place.</p>
</blockquote>
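<p>The two access styles can be sketched as follows. Everything here is hypothetical and only for illustration: <code>FontHandle</code>, <code>acquireFont</code>, <code>releaseFont</code> and <code>fontSize</code> stand in for some C API that works on raw handles.</p>

```cpp
#include <cassert>

// Hypothetical raw resource and C-style API (illustrative only).
struct FontHandle { int id; };
inline FontHandle acquireFont(int id) { return FontHandle{id}; }
inline void releaseFont(FontHandle) {}
inline int fontSize(FontHandle h) { return h.id * 2; }  // takes the raw handle

// RAII wrapper that releases the handle in its destructor.
class Font {
public:
    explicit Font(FontHandle h) : h_(h) {}
    ~Font() { releaseFont(h_); }
    Font(const Font&) = delete;             // copying is prohibited in this sketch
    Font& operator=(const Font&) = delete;

    // Explicit access: the client has to ask for the raw handle.
    FontHandle get() const { return h_; }

    // Implicit access would look like the line below. It is more convenient,
    // but a stray copy of the raw handle can then outlive the Font.
    // operator FontHandle() const { return h_; }

private:
    FontHandle h_;
};
```

<p>With the explicit form, a call reads <code>fontSize(f.get())</code>, which makes every escape of the raw handle visible at the call site.</p>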
<h3 id="item-16-use-the-same-form-in-corresponding-uses-of-new-and-delete">Item 16: Use the same form in corresponding uses of new and delete.</h3>
<blockquote>
<p>✦ If you use [] in a new expression, you must use [] in the corresponding delete expression. If you don’t use [] in a new expression, you mustn’t use [] in the corresponding delete expression.</p>
</blockquote>
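<p>What <code>new</code> does: allocate memory, then call the constructor. What <code>delete</code> does: call the destructor, then free the memory.</p>
<p>When an array of objects is created, the implementation records the array length somewhere (often just in front of the array's memory). When releasing the array you must write <code>delete []</code> to tell the compiler that an array is being deleted; otherwise the behavior is undefined (most likely the compiler treats it as a single object).</p>
<p>So <code>new</code> must be paired with <code>delete</code>, and <code>new []</code> with <code>delete []</code>.</p>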
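<p>A small sketch of the matching forms, using a hypothetical <code>Tracked</code> type that counts live objects so the effect of each <code>delete</code> can be observed. The second function shows that the array form of <code>std::unique_ptr</code> selects <code>delete []</code> for you:</p>

```cpp
#include <cassert>
#include <memory>

// Hypothetical type that counts how many instances are currently alive.
struct Tracked {
    static int alive;
    Tracked() { ++alive; }
    ~Tracked() { --alive; }
};
int Tracked::alive = 0;

inline int aliveAfterMatchedForms() {
    Tracked* one = new Tracked;       // new   pairs with delete
    Tracked* many = new Tracked[3];   // new[] pairs with delete[]
    delete one;
    delete[] many;                    // runs all three destructors
    return Tracked::alive;
}

inline int aliveAfterArraySmartPointer() {
    {
        std::unique_ptr<Tracked[]> many(new Tracked[3]);
    }                                 // unique_ptr<T[]> calls delete[] here
    return Tracked::alive;
}
```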
<h3 id="item-17-store-newed-objects-in-smart-pointers-in-standalone-statements">Item 17: Store newed objects in smart pointers in standalone statements.</h3>
<blockquote>
<p>✦ Store newed objects in smart pointers in standalone statements. Failure to do this can lead to subtle resource leaks when exceptions are thrown.</p>
</blockquote>
<p>Even with smart pointers, memory can still leak unexpectedly:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">processWidget</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">tr1</span><span class="o">::</span><span class="n">shared_ptr</span><span class="o">&lt;</span><span class="n">Widget</span><span class="o">&gt;</span><span class="p">(</span><span class="k">new</span> <span class="n">Widget</span><span class="p">),</span> <span class="n">priority</span><span class="p">());</span>
</span></span></code></pre></td></tr></table>
</div>
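</div><p>The standard does not fix the order in which the compiler evaluates the arguments of this call, so it may evaluate <code>new Widget</code> first and then hit an exception in <code>priority()</code>. At that point the smart pointer's constructor has not been called yet, so the call ends abnormally with nothing owning the new object, and the memory leaks.</p>
<p>The fix is simple: construct the smart pointer in a statement of its own, then pass it to the function call.</p>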
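<p>The standalone-statement fix might look like the sketch below. It uses <code>std::shared_ptr</code> rather than the <code>tr1</code> version, and <code>Widget</code>, <code>priority</code> and <code>processWidget</code> are hypothetical stand-ins; <code>priority</code> always throws here so the failure path can be observed.</p>

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>

// Hypothetical Widget that counts live instances.
struct Widget {
    static int live;
    Widget() { ++live; }
    ~Widget() { --live; }
};
int Widget::live = 0;

// Stand-ins for the functions in the text; priority() fails on purpose.
inline int priority() { throw std::runtime_error("priority failed"); }
inline void processWidget(std::shared_ptr<Widget>, int) {}

inline int liveWidgetsAfterFailure() {
    try {
        std::shared_ptr<Widget> pw(new Widget);  // standalone statement: owned immediately
        processWidget(pw, priority());           // throws, but pw's destructor still runs
    } catch (const std::runtime_error&) {
        // the Widget has already been released by pw at this point
    }
    return Widget::live;
}
```

<p>Because ownership is established before any other expression runs, the exception cannot slip in between the <code>new</code> and the smart pointer's constructor.</p>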
<h2 id="designs-and-declarations">Designs and Declarations</h2>
<p>This chapter discusses how to design and declare C++ interfaces.</p>
<h3 id="item-18-make-interfaces-easy-to-use-correctly-and-hard-to-use-incorrectly">Item 18: Make interfaces easy to use correctly and hard to use incorrectly.</h3>
<blockquote>
<p>✦ Good interfaces are easy to use correctly and hard to use incorrectly. You should strive for these characteristics in all your interfaces.</p>
<p>✦ Ways to facilitate correct use include consistency in interfaces and behavioral compatibility with built-in types.</p>
<p>✦ Ways to prevent errors include creating new types, restricting operations on types, constraining object values, and eliminating client resource management responsibilities.</p>
<p>✦ tr1::shared_ptr supports custom deleters. This prevents the cross-DLL problem, can be used to automatically unlock mutexes (see Item 14), etc.</p>
</blockquote>
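<p>An ideal interface behaves like this: if a call compiles and runs normally, everything is happening as the caller intended; otherwise the caller gets feedback.</p>
<p>For example, take a date class whose constructor needs a year, a month and a day. Rather than accepting three <code>int</code> parameters, a better design defines separate types for year, month and day, which keeps clients from mixing up the month and the day. We can go further and define the 12 months as constants: not with an enum, but as 12 static functions in the month class, each returning the corresponding month. Static functions are used instead of static constants to avoid the non-local static initialization problem described in <a href=".md#item-4-make-sure-that-objects-are-initialized-before-they%E2%80%99re-used."> &gt; Item 4 Make sure that objects are initialized before they’re used.</a></p>
<p>Another way to keep clients from making mistakes is to strictly limit the operations a type supports, for example by making the return type of an arithmetic operator such as <code>operator *</code> <code>const</code>, or by <code>const</code>-qualifying functions wherever possible. The compiler can then catch slips like this one:</p>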
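<p>A minimal sketch of such a date interface, assuming illustrative wrapper types <code>Day</code>, <code>Month</code> and <code>Year</code> (only two months are written out; the rest would follow the same pattern):</p>

```cpp
#include <cassert>

// Thin wrapper types: a Day cannot be passed where a Year is expected.
struct Day  { explicit Day(int d)  : val(d) {} int val; };
struct Year { explicit Year(int y) : val(y) {} int val; };

class Month {
public:
    // Static functions instead of static objects, so there is no
    // non-local static whose initialization order could bite us.
    static Month Jan() { return Month(1); }
    static Month Mar() { return Month(3); }
    // the remaining months are omitted in this sketch
    int value() const { return val_; }

private:
    explicit Month(int m) : val_(m) {}  // private: clients cannot create month 13
    int val_;
};

class Date {
public:
    Date(const Month& m, const Day& d, const Year& y)
        : month_(m.value()), day_(d.val), year_(y.val) {}
    int month() const { return month_; }
    int day() const { return day_; }
    int year() const { return year_; }

private:
    int month_, day_, year_;
};
```

<p>With this design, <code>Date(Month::Mar(), Day(30), Year(1995))</code> compiles, while swapping the arguments or passing plain <code>int</code>s does not.</p>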
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">if</span> <span class="p">(</span><span class="n">obj1</span> <span class="o">+</span> <span class="n">obj2</span> <span class="o">=</span> <span class="n">obj3</span><span class="p">){</span> <span class="c1">// meant obj1 + obj2 == obj3,
</span></span></span><span class="line"><span class="cl">						<span class="c1">// but wrote an assignment to a temporary instead
</span></span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
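</div><p>The types we define should behave consistently with the built-in types; the rule above is really a special case of this one. Staying consistent with built-in types reduces how much clients have to remember and how likely they are to make mistakes.</p>
<p>An interface should not oblige the client to do any cleanup (such as releasing a resource). A factory function should therefore preferably not return a raw pointer and leave the wrapping to the client, but return a smart pointer directly. This also avoids the cross-DLL problem (the code that allocates and the code that frees live in different DLLs), because every resource is released by whoever allocated it.</p>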
<h3 id="item-19-treat-class-design-as-type-design">Item 19: Treat class design as type design.</h3>
<blockquote>
<p>✦ Class design is type design. Before defining a new type, be sure to consider all the issues discussed in this Item.</p>
</blockquote>
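<p>A good type has a natural syntax, intuitive semantics and an efficient implementation. When designing a class, answer these questions:</p>
<ul>
<li>How should objects of your type be created and destroyed? This determines how you write the constructors and the destructor, as well as any memory allocation and deallocation functions.</li>
<li>How does object initialization differ from object assignment? This determines the difference between your constructors and your assignment operators; do not confuse the two.</li>
<li>What happens when objects of your type are passed by value? Passing by value calls the copy constructor, and the result should be what clients expect.</li>
<li>What are the valid values of your type? Based on them, constructors, setters and other member functions can check that values are valid.</li>
<li>Does your type fit correctly into an inheritance hierarchy, both as a derived class and as a base class? As a derived class you implement the inherited virtual functions; as a base class you declare which functions are virtual.</li>
<li>What conversions does your type allow? If your type should convert implicitly to another type, either declare a non-explicit constructor in that type that accepts yours, or declare a conversion function to that type in yours. If only explicit conversion should be allowed, do not declare a conversion function or a non-explicit single-argument constructor; instead provide a named function that performs the conversion, or declare the corresponding constructor of the other type <code>explicit</code>.</li>
<li>Which functions and operators make sense for your type? This determines which operators and functions you implement.</li>
<li>Which compiler-generated functions should be disallowed? If you do not want the compiler to generate certain functions, explicitly declare them private.</li>
<li>Who should have access to the members of your type? This determines member access levels as well as friend functions and friend classes.</li>
<li>What is the "undeclared interface" of your type? That is, beyond the interface it shows, what other promises and guarantees does it make, for example about performance, exceptions and resource usage?</li>
<li>How general is your type? If you really want a whole family of types, you should define a class template.</li>
<li>Do you really need a class? If a few functions solve your problem, you do not actually need one.</li>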
</ul>
<h3 id="item-20-prefer-pass-by-reference-to-const-to-pass-by--value">Item 20: Prefer pass-by-reference-to-const to pass-by-value.</h3>
<blockquote>
<p>✦ Prefer pass-by-reference-to-const over pass-by-value. It’s typically more efficient and it avoids the slicing problem.</p>
<p>✦ The rule doesn’t apply to built-in types and STL iterator and function object types. For them, pass-by-value is usually appropriate.</p>
</blockquote>
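<p>By default, function parameters are passed by value: the parameter is copy-constructed from the argument, and its destructor is called when the function returns. This can waste a lot of time.</p>
<p>Passing by reference-to-<code>const</code> avoids this repeated work:</p>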
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">int</span> <span class="nf">foo</span><span class="p">(</span><span class="k">const</span> <span class="n">class_name</span><span class="o">&amp;</span> <span class="n">param</span><span class="p">);</span>
</span></span></code></pre></td></tr></table>
</div>
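</div><p>The <code>const</code> keyword guarantees that the caller's argument is not modified. Passing by reference also avoids slicing, so virtual functions are still dispatched on the argument's dynamic type.</p>
<p>In most compilers, references are implemented as pointers, so for built-in types pass-by-value may perform better than pass-by-reference. The same holds for STL iterators.</p>
<p>A class being small does not by itself make it suitable for pass-by-value. A small class can still have an expensive copy constructor. A class like vector, for instance, may hold little more than a pointer, yet its copy constructor may perform a deep copy of everything pointed to, which can be very costly.</p>
<p>Even a fast copy constructor does not mean a type is suited to pass-by-value. Some compilers treat built-in types and user-defined types differently: however small the latter are, the compiler may refuse to keep them in a register, which carries a hidden performance cost.</p>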
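<p>The slicing problem mentioned in this Item can be illustrated with a small sketch. <code>Window</code> and <code>ScrollWindow</code> are hypothetical classes chosen only for the demonstration:</p>

```cpp
#include <cassert>
#include <string>

class Window {
public:
    virtual ~Window() {}
    virtual std::string name() const { return "Window"; }
};

class ScrollWindow : public Window {
public:
    std::string name() const override { return "ScrollWindow"; }
};

// Pass-by-value: only the Window part of the argument is copied (slicing),
// so the call always resolves to Window::name.
inline std::string nameByValue(Window w) { return w.name(); }

// Pass-by-reference-to-const: no copy is made and the dynamic type survives.
inline std::string nameByRef(const Window& w) { return w.name(); }
```

<p>Passing a <code>ScrollWindow</code> to <code>nameByValue</code> yields <code>"Window"</code>, while <code>nameByRef</code> yields <code>"ScrollWindow"</code>.</p>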
<h3 id="item-21-dont-try-to-return-a-reference-when-you-must-return-an-object">Item 21: Don’t try to return a reference when you must return an object.</h3>
<blockquote>
<p>✦ Never return a pointer or reference to a local stack object, a reference to a heap-allocated object, or a pointer or reference to a local static object if there is a chance that more than one such object will be needed. (Item 4 provides an example of a design where returning a reference to a local static is reasonable, at least in single-threaded environments.)</p>
</blockquote>
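<p>Passing by reference is efficient, but that does not mean every function should pass or return by reference. A reference is only valid if the object it refers to actually exists. Suppose we implement a rational number class <code>Rational</code>. If <code>operator *</code> returned a reference, the object referred to cannot exist before <code>operator *</code> is called, so the function itself would have to create it.</p>
<p>A function can create an object in two ways: on the stack or on the heap. In the first case the referenced object is destroyed when the function returns; in the second the caller has to destroy it manually. Even if the caller remembers to do so, the following code still leaks memory:</p>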
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">Rational</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">product</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="n">product</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="n">y</span> <span class="o">*</span> <span class="n">z</span><span class="p">;</span> <span class="c1">// the temporary returned by x*y (on the heap) is never freed
</span></span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="k">delete</span> <span class="n">product</span><span class="p">;</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Next comes a clever-looking hack that uses a static variable to get around the leak:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">const</span> <span class="n">Rational</span><span class="o">&amp;</span> <span class="k">operator</span><span class="o">*</span><span class="p">(</span><span class="k">const</span> <span class="n">Rational</span><span class="o">&amp;</span> <span class="n">lhs</span><span class="p">,</span> <span class="k">const</span> <span class="n">Rational</span><span class="o">&amp;</span> <span class="n">rhs</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="k">static</span> <span class="n">Rational</span> <span class="n">result</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="n">result</span> <span class="o">=</span> <span class="p">...</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">result</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
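</div><p>This code "cleverly" sidesteps the memory leak, but besides the usual thread-safety problem of static variables, the expression <code>(a * b) == (c * d)</code> is always true, because both sides refer to the same static object!</p>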
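<p>The approach the Item's title points to is to return a new object by value. A minimal sketch, assuming a simple <code>Rational</code> with public accessors and an illustrative <code>operator==</code> that compares by cross-multiplication:</p>

```cpp
#include <cassert>

class Rational {
public:
    Rational(int numerator = 0, int denominator = 1)
        : n_(numerator), d_(denominator) {}
    int numerator() const { return n_; }
    int denominator() const { return d_; }

private:
    int n_, d_;
};

// Returns a fresh object by value: nothing dangles, nothing leaks,
// and two products never alias each other.
inline const Rational operator*(const Rational& lhs, const Rational& rhs) {
    return Rational(lhs.numerator() * rhs.numerator(),
                    lhs.denominator() * rhs.denominator());
}

// Illustrative equality: a/b == c/d exactly when a*d == c*b.
inline bool operator==(const Rational& a, const Rational& b) {
    return a.numerator() * b.denominator() == b.numerator() * a.denominator();
}
```

<p>The cost is one construction and destruction per call, which compilers can often optimize away; in exchange, <code>(a * b) == (c * d)</code> compares two distinct results.</p>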
<h3 id="item-22-declare-data-members-private">Item 22: Declare data members private.</h3>
<blockquote>
<p>✦ Declare data members private. It gives clients syntactically uniform access to data, affords fine-grained access control, allows invariants to be enforced, and offers class authors implementation flexibility.</p>
<p>✦ protected is no more encapsulated than public.</p>
</blockquote>
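<p>Why data members should not be declared public or protected:</p>
<ul>
<li>Syntactic consistency: when using the interface, clients never have to remember whether they are calling a function or accessing a data member directly.</li>
<li>Read/write control: when data members are read and written through functions, access to each member can be controlled individually.</li>
<li>Encapsulation: with access wrapped in a getter, the getter's implementation can change without client code having to change.</li>
<li>Easier to keep the data consistent: clients cannot modify data members directly and break the class's invariants.</li>
<li>Room for change: if the class is refactored later, it only has to keep offering the same interface; the data members themselves need not survive.</li>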
</ul>
<h3 id="item-23-prefer-non-member-non-friend-functions-to-member-functions">Item 23: Prefer non-member non-friend functions to member functions.</h3>
<blockquote>
<p>✦ Prefer non-member non-friend functions to member functions. Doing so increases encapsulation, packaging flexibility, and functional extensibility.</p>
</blockquote>
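<p>Let's start with encapsulation. The more encapsulated a class is, the less it exposes, and the more freedom we have to change it (because only the exposed part has to be maintained). That freedom is the reason we encapsulate in the first place.</p>
<p>The more encapsulated a data member is, the less of it is exposed. A rough measure of how exposed a data member is: count the member functions and friend functions that can access it.</p>
<p>So when a requirement can be met either by a member function or by a non-member, non-friend function, prefer the latter, because it does not reduce the encapsulation of the data members.</p>
<p>Suppose we implement a browser class <code>WebBrowser</code> with member functions that clear the history, the cookies, the downloaded files and so on. If we want to write a <code>clearAll</code> function, the principle above says it should not be a member function.</p>
<p>In other words, we can write a free function <code>clearAll</code>, or define a utility class with a static function <code>clearAll</code>, which is the more common approach in Java. In C++, the more idiomatic way is to put <code>clearAll</code> and <code>WebBrowser</code> in the same <code>namespace</code>:</p>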
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">namespace</span> <span class="n">WebBrowserStuff</span> <span class="p">{</span> 
</span></span><span class="line"><span class="cl">	<span class="k">class</span> <span class="nc">WebBrowser</span> <span class="p">{</span> <span class="p">...</span> <span class="p">};</span> 
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="nf">clearAll</span><span class="p">(</span><span class="n">WebBrowser</span><span class="o">&amp;</span> <span class="n">wb</span><span class="p">);</span> 
</span></span><span class="line"><span class="cl">	<span class="p">...</span> 
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Because a <code>namespace</code> can span multiple files, different utility functions in the style of <code>clearAll</code> can be declared in different header files.</p>
<h3 id="item-24-declare-non-member-functions-when-type-conversions-should-apply-to-all-parameters">Item 24: Declare non-member functions when type conversions should apply to all parameters.</h3>
<blockquote>
<p>✦ If you need type conversions on all parameters to a function (including the one that would otherwise be pointed to by the this pointer), the function must be a non-member.</p>
</blockquote>
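<p>In general, letting a class support implicit conversions is not a good idea, but there are exceptions. For a numeric class, for example, supporting implicit conversion from <code>int</code> is reasonable.</p>
<p>When we then implement multiplication, several options present themselves: a member function, a non-member function, or a friend function.</p>
<p>If we make it a member function and allow implicit conversion from <code>int</code>, then <code>Rational * int</code> compiles but <code>int * Rational</code> does not: the built-in <code>operator *</code> for <code>int</code> does not accept a <code>Rational</code>, and implicit conversion is never applied to the object a member function is called on. That is clearly inelegant and breaks the commutativity of multiplication.</p>
<p>One solution is to define a non-member function <code>const Rational operator*(const Rational&amp; lhs, const Rational&amp; rhs)</code>. When either argument is an <code>int</code>, the compiler converts it to <code>Rational</code> implicitly.</p>
<p>The requirement is met, but that raises a question: should the function be declared a friend? If it can be avoided, do not make it a friend, because friends reduce the class's encapsulation.</p>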
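<p>A minimal sketch of the non-member approach, assuming a <code>Rational</code> whose constructor is deliberately not <code>explicit</code> and whose public accessors make friendship unnecessary:</p>

```cpp
#include <cassert>

class Rational {
public:
    // Deliberately non-explicit: allows the implicit conversion int -> Rational.
    Rational(int numerator = 0, int denominator = 1)
        : n_(numerator), d_(denominator) {}
    int numerator() const { return n_; }
    int denominator() const { return d_; }

private:
    int n_, d_;
};

// Non-member and non-friend: it needs only the public interface, and
// implicit conversion applies to both the left and the right operand.
inline const Rational operator*(const Rational& lhs, const Rational& rhs) {
    return Rational(lhs.numerator() * rhs.numerator(),
                    lhs.denominator() * rhs.denominator());
}
```

<p>Both <code>half * 2</code> and <code>2 * half</code> now compile, because the <code>int</code> is converted to a <code>Rational</code> on whichever side it appears.</p>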
<h3 id="item-25-consider-support-for-a-non-throwing-swap">Item 25: Consider support for a non-throwing swap.</h3>
<blockquote>
<p>✦ Provide a swap member function when std::swap would be inefficient for your type. Make sure your swap doesn’t throw exceptions.</p>
<p>✦ If you offer a member swap, also offer a non-member swap that calls the member. For classes (not templates), specialize std::swap, too.</p>
<p>✦ When calling swap, employ a using declaration for std::swap, then call swap without namespace qualification.</p>
<p>✦ It’s fine to totally specialize std templates for user-defined types, but never try to add something completely new to std.</p>
</blockquote>
<p>Ever since it was introduced with the STL, <code>swap</code> has been a mainstay of exception-safe programming. A typical implementation is:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">namespace</span> <span class="n">std</span> <span class="p">{</span> 
</span></span><span class="line"><span class="cl">	<span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span> 
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="n">swap</span><span class="p">(</span><span class="n">T</span><span class="o">&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">T</span><span class="o">&amp;</span> <span class="n">b</span><span class="p">){</span> 
</span></span><span class="line"><span class="cl">		<span class="n">T</span> <span class="nf">temp</span><span class="p">(</span><span class="n">a</span><span class="p">);</span> 
</span></span><span class="line"><span class="cl">		<span class="n">a</span> <span class="o">=</span> <span class="n">b</span><span class="p">;</span> 
</span></span><span class="line"><span class="cl">		<span class="n">b</span> <span class="o">=</span> <span class="n">temp</span><span class="p">;</span> 
</span></span><span class="line"><span class="cl">	<span class="p">}</span> 
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
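</div><p>As long as a class supports copy construction and copy assignment, this function template can swap its objects. However, the default <code>swap</code> calls the copy constructor once and the copy assignment operator twice, and we may want a fancier swap tailored to our own class.</p>
<p>For a class whose data lives behind a pointer member, the deep copies made by the copy functions are unnecessary. A custom swap can do a shallow swap, that is, just exchange the pointers. Note that this can be done by specializing the template rather than writing a completely separate <code>swap</code> function.</p>
<p>A template specialization cannot access the private pointer either, though. One option is to declare the specialization a friend. The more conventional approach is to declare a public member function <code>swap</code> in the class and call it from the specialization. This is how the STL containers do it.</p>
<p>This approach does not work for class templates, however. A class template has a type parameter <code>T</code>, so specializing <code>swap</code> for it could only be a partial specialization, and C++ does not allow partial specialization of function templates:</p>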
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">namespace</span> <span class="n">std</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="n">swap</span><span class="o">&lt;</span><span class="n">Widget</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;&gt;</span><span class="p">(</span><span class="n">Widget</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="n">Widget</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;&amp;</span> <span class="n">b</span><span class="p">){</span>  <span class="c1">// partially specializing swap is not allowed
</span></span></span><span class="line"><span class="cl">		<span class="n">a</span><span class="p">.</span><span class="n">swap</span><span class="p">(</span><span class="n">b</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
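</div><p>One option is to overload <code>swap</code> instead (just drop the <code>&lt;Widget&lt;T&gt;&gt;</code>). Unfortunately, the C++ standard reserves the contents of namespace std for the standardization committee, and adding an overload counts as a modification, so this is not allowed.</p>
<p>Is every road blocked, then? Not quite. Remember that we do not have to overload or specialize <code>std::swap</code> at all: we can declare <code>swap</code> directly in <code>Widget</code>'s own namespace and use that. Thanks to argument-dependent lookup (ADL), the compiler automatically calls the <code>swap</code> in the namespace where <code>Widget</code> lives.</p>
<p>Does that solve everything? Unfortunately not. In the code below, which function is called when the swap is performed: <code>std::swap</code>, a specialization for <code>T</code>, or a <code>swap</code> for type <code>T</code> in some other namespace?</p>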
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span> 
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">doSomething</span><span class="p">(</span><span class="n">T</span><span class="o">&amp;</span> <span class="n">obj1</span><span class="p">,</span> <span class="n">T</span><span class="o">&amp;</span> <span class="n">obj2</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span> 
</span></span><span class="line"><span class="cl">	<span class="n">swap</span><span class="p">(</span><span class="n">obj1</span><span class="p">,</span> <span class="n">obj2</span><span class="p">);</span> 
</span></span><span class="line"><span class="cl">	<span class="p">...</span> 
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
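</div><p>What you probably want is this: if a <code>swap</code> specific to type <code>T</code> exists, call it; otherwise fall back to <code>std::swap</code>. Adding one line to <code>doSomething</code> achieves exactly that:</p>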
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="hl"><span class="lnt">3
</span></span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span> 
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">doSomething</span><span class="p">(</span><span class="n">T</span><span class="o">&amp;</span> <span class="n">obj1</span><span class="p">,</span> <span class="n">T</span><span class="o">&amp;</span> <span class="n">obj2</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line hl"><span class="cl">	<span class="k">using</span> <span class="n">std</span><span class="o">::</span><span class="n">swap</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span> 
</span></span><span class="line"><span class="cl">	<span class="n">swap</span><span class="p">(</span><span class="n">obj1</span><span class="p">,</span> <span class="n">obj2</span><span class="p">);</span> 
</span></span><span class="line"><span class="cl">	<span class="p">...</span> 
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
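</div><p>When <code>swap</code> is called, the compiler first looks at global scope and in the namespace of <code>T</code> for a <code>swap</code> that takes <code>T</code>. If none is found, it looks for a specialized <code>swap</code> in <code>std</code>; if there is still nothing, it uses the general <code>swap</code> implementation.</p>
<p>This Item covers a lot of ground, so here is a summary:</p>
<ul>
<li>If the performance of the general swap is acceptable, there is no need to write your own.</li>
<li>If you do write your own, the steps are:
<ul>
<li>Provide a swap member function.</li>
<li>Provide a non-member swap in the same namespace as the class, and have it call the swap member function.</li>
<li>If you are writing a class rather than a class template, also specialize <code>std::swap</code> for it.</li>
</ul>
</li>
<li>When calling swap, include a using declaration so that <code>std::swap</code> is visible.</li>
</ul>
<p>One last piece of advice: the swap member function should never throw. One of the most important uses of swap is to help classes offer the strong exception-safety guarantee. This constraint applies only to the member version; the non-member version is not bound by it.</p>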
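<p>The steps in the summary can be sketched as follows. <code>WidgetStuff</code>, <code>Widget</code> and <code>swapViaAdl</code> are hypothetical names, the class keeps its data behind a raw pointer only to make the shallow swap visible, and the <code>std::swap</code> specialization step is left out of this sketch:</p>

```cpp
#include <cassert>
#include <utility>
#include <vector>

namespace WidgetStuff {

class Widget {
public:
    explicit Widget(std::vector<int> data)
        : impl_(new std::vector<int>(std::move(data))) {}
    Widget(const Widget& rhs) : impl_(new std::vector<int>(*rhs.impl_)) {}  // deep copy
    Widget& operator=(const Widget& rhs) { *impl_ = *rhs.impl_; return *this; }
    ~Widget() { delete impl_; }

    // Step 1: a member swap that only exchanges the pointers and cannot throw.
    void swap(Widget& other) noexcept { std::swap(impl_, other.impl_); }

    const std::vector<int>& data() const { return *impl_; }

private:
    std::vector<int>* impl_;
};

// Step 2: a non-member swap in the class's own namespace, found through ADL.
inline void swap(Widget& a, Widget& b) noexcept { a.swap(b); }

}  // namespace WidgetStuff

// Calling side: make std::swap visible, then call swap unqualified.
template <typename T>
void swapViaAdl(T& a, T& b) {
    using std::swap;  // fallback for types without their own swap
    swap(a, b);       // picks WidgetStuff::swap for Widget
}
```

<p>After <code>swapViaAdl(a, b)</code> the two objects have exchanged their internal buffers without any element being copied, while the same helper still works for types such as <code>int</code> through the <code>std::swap</code> fallback.</p>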
<h2 id="implementations">Implementations</h2>
<h3 id="item-26-postpone-variable-definitions-as-long-as-possible">Item 26: Postpone variable definitions as long as possible.</h3>
<blockquote>
<p>✦ Postpone variable definitions as long as possible. It increases program clarity and improves program efficiency.</p>
</blockquote>
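<p>Constructing and destroying an object takes time, so postpone a variable's definition until just before the variable is actually needed. In the code below, the return value <code>ret</code> is defined before the error check. When the exception is thrown, the construction and destruction of <code>ret</code> are wasted work:</p>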
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">foo</span><span class="p">(</span><span class="n">string</span> <span class="n">s</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl">	<span class="n">string</span> <span class="n">ret</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span><span class="p">(</span><span class="n">s</span><span class="p">.</span><span class="n">size</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="k">throw</span> <span class="nf">logic_error</span><span class="p">(</span><span class="s">&#34;s is empty&#34;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">ret</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
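</div><p>In addition, the code above default-initializes <code>ret</code> to an empty string, which is also unnecessary, since assigning to it afterwards costs a copy assignment on top of the construction. It is better to initialize <code>ret</code> directly with the computed value.</p>
<p>So "as long as possible" means more than delaying the definition: delay it until the variable's initial value is known.</p>
<p>For an object used inside a loop there is a trade-off. Defining it outside the loop costs one construction, one destruction and one assignment per iteration; defining it inside costs one construction and one destruction per iteration. Defining it outside only pays off when assignment is known to be cheaper than a constructor-destructor pair and the code is performance-sensitive; otherwise the book prefers defining it inside the loop, which also keeps its scope small.</p>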
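<p>A sketch of the postponed definition, using a hypothetical function <code>shout</code> in place of <code>foo</code>: the error check comes first, and the result is defined only once its value is known.</p>

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical example function: appends "!" to a non-empty string.
inline std::string shout(const std::string& s) {
    if (s.empty()) {
        // Nothing has been constructed yet on the error path.
        throw std::logic_error("s is empty");
    }
    // Defined only when needed, and initialized with its final value
    // instead of being default-constructed and assigned afterwards.
    std::string ret(s + "!");
    return ret;
}
```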
<h3 id="item-27-minimize-casting">Item 27: Minimize casting.</h3>
<blockquote>
<p>✦ Avoid casts whenever practical, especially dynamic_casts in performance-sensitive code. If a design requires casting, try to develop a cast-free alternative.</p>
<p>✦ When casting is necessary, try to hide it inside a function. Clients can then call the function instead of putting casts in their own code.</p>
<p>✦ Prefer C++-style casts to old-style casts. They are easier to see, and they are more specific about what they do.</p>
</blockquote>
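<p>C++ supports the following forms of cast:</p>
<ul>
<li>C style: <code>(T) expression</code></li>
<li>Function style: <code>T(expression)</code></li>
<li>C++ style:
<ul>
<li><code>const_cast&lt;T&gt;(expression)</code>: removes the const qualification from an object; it is the only cast that can do this.</li>
<li><code>dynamic_cast&lt;T&gt;(expression)</code>: performs "safe downcasting", that is, it determines whether an object reached through a base class can safely be treated as a derived object. This cast can have a significant runtime cost.</li>
<li><code>reinterpret_cast&lt;T&gt;(expression)</code>: converts between two unrelated types by reinterpreting the bits as another object. It should not be used outside low-level code.</li>
<li><code>static_cast&lt;T&gt;(expression)</code>: forces an implicit conversion.<br>
The new C++-style casts are recommended: they are easier to spot in code, and the four casts are more narrowly specified, which makes errors easier to track down.</li>
</ul>
</li>
</ul>
<p>Memory layout can differ between compilers and platforms, so do not write low-level casts that depend on a particular layout.</p>
<p>If <code>static_cast</code> is given a derived class object and converts it to the base class type, it returns a copy of the base class part; if it is given a derived class pointer or reference, it returns a base class pointer or reference to the same object. So to call a non-const base class member function on an object, first convert to a base class reference or pointer and then make the call; otherwise the function modifies the copy and the changes never reach the object itself.</p>
<p><code>dynamic_cast</code> is not cheap, so avoid it where you can. Virtual functions and dynamic binding let you reach derived class behavior through a base class pointer without any cast.</p>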
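<p>The <code>static_cast</code> pitfall described above can be made concrete with a small sketch. <code>Window</code> and <code>SpecialWindow</code> are hypothetical classes, and the public counter exists only so the effect can be observed:</p>

```cpp
#include <cassert>

class Window {
public:
    virtual ~Window() {}
    virtual void onResize() { ++baseResizes; }
    int baseResizes = 0;  // public only to keep the sketch observable
};

class SpecialWindow : public Window {
public:
    // Wrong: the cast creates a temporary copy of the Window part,
    // and onResize() modifies that copy, not *this.
    void resizeWrong() { static_cast<Window>(*this).onResize(); }

    // Right: no cast at all, call the base class version on *this.
    void resizeRight() { Window::onResize(); }

    // Also acts on *this, because a reference does not copy. Note that the
    // call is still virtual, so it reaches Window::onResize only because
    // SpecialWindow does not override it in this sketch.
    void resizeViaRef() { static_cast<Window&>(*this).onResize(); }
};
```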
<h3 id="item-28-avoid-returning-handles-to-object-internals">Item 28: Avoid returning “handles” to object internals.</h3>
<blockquote>
<p>✦ Avoid returning handles (references, pointers, or iterators) to object internals. Not returning handles increases encapsulation, helps const member functions act const, and minimizes the creation of dangling handles.</p>
</blockquote>
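<p>A data member is only as encapsulated as the most accessible member function that returns a reference to it: if a public function returns a reference to a private member, that member's encapsulation is broken and it is effectively public.</p>
<p>If a data member is a pointer to storage outside the object and that pointer can also be obtained from outside, then even when the object is <code>const</code>, the data it points to can still be modified.</p>
<p>Pointers, references and iterators all share this problem; collectively they are called handles for accessing an object.</p>
<p>The two problems above point to a rule: a member function must not return a handle to a member variable or function that is less accessible than the function itself, unless this is deliberate and the return value is declared <code>const</code>.</p>
<p>Moreover, a member function that returns a reference to an internal member can lead to access after a temporary has been destroyed: when the member function is called on a temporary object, the temporary is destroyed at the end of the statement and the returned handle is left dangling. For example:</p>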
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">A</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">const</span> <span class="n">Data</span><span class="o">&amp;</span> <span class="n">get_data</span><span class="p">()</span> <span class="k">const</span> <span class="p">{</span> <span class="k">return</span> <span class="n">data_</span><span class="p">;</span> <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="k">private</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="n">Data</span> <span class="n">data_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">const</span> <span class="n">Data</span><span class="o">*</span> <span class="k">const</span> <span class="n">p_data</span> <span class="o">=</span> <span class="o">&amp;</span><span class="p">(</span><span class="n">A</span><span class="p">().</span><span class="n">get_data</span><span class="p">());</span> <span class="c1">// dangling once the temporary A is destroyed</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h3 id="item-29-strive-for-exception-safe-code">Item 29: Strive for exception-safe code.</h3>
<blockquote>
<p>✦ Exception-safe functions leak no resources and allow no data structures to become corrupted, even when exceptions are thrown. Such functions offer the basic, strong, or nothrow guarantees.</p>
<p>✦ The strong guarantee can often be implemented via copy-and-swap, but the strong guarantee is not practical for all functions.</p>
<p>✦ A function can usually offer a guarantee no stronger than the weakest guarantee of the functions it calls.</p>
</blockquote>
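<p>When an exception is thrown, an exception-safe function should ensure the following:</p>
<ul>
<li>No resources are leaked. Resources are not just memory; they also include things such as locks. RAII, described in <a href=".md#item-13-use-objects-to-manage-resources.">Item 13: Use objects to manage resources.</a>, achieves this.</li>
<li>No data structures are corrupted. Every invariant the data structure is supposed to maintain still holds.</li>
</ul>
<p>An exception-safe function offers one of the following three guarantees:</p>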
<ul>
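<li>The basic guarantee: if an exception is thrown, everything in the program remains in a valid state, but the exact state cannot be predicted.</li>
<li>The strong guarantee: if an exception is thrown, the state of the program is the same as before the call. Such functions are called atomic.</li>
<li>The nothrow guarantee: the function promises never to throw. All operations on built-in types offer this guarantee.</li>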
</ul>
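<p>Note that a declaration such as <code>void foo() noexcept;</code> does not mean the function is guaranteed never to throw. It means that an exception escaping the function is a fatal error (<code>std::terminate</code> is called). Such a function may in fact offer no exception-safety guarantee at all.</p>
<p>Whether a function is exception-safe depends on its implementation, not its declaration. The nothrow guarantee is hard to provide, especially when using C++ libraries, so in practice one of the two weaker guarantees is usually enough.</p>
<p>The strong guarantee is often implemented with the <code>copy and swap</code> technique: modify a copy of the object, and swap the copy with the original only if no exception was thrown.</p>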
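<p>A minimal sketch of copy and swap, assuming a hypothetical <code>Playlist</code> class (not from the book) that must add two songs atomically:</p>

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

class Playlist {
public:
    // Strong guarantee: either both songs are added or nothing changes.
    void add_pair(const std::string& first, const std::string& second) {
        std::vector<std::string> copy = songs_;   // 1. work on a copy
        copy.push_back(first);                    // may throw
        if (second.empty())
            throw std::invalid_argument("empty title");
        copy.push_back(second);                   // may throw
        songs_.swap(copy);                        // 2. commit; swapping vectors does not throw
    }
    std::size_t size() const { return songs_.size(); }
private:
    std::vector<std::string> songs_;
};
```

<p>If anything throws before the swap, only the local copy is affected and <code>songs_</code> is untouched. The price is a full copy on every call, which is one reason the strong guarantee is not always practical.</p>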
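<p>Once functions call each other, the strong guarantee quickly becomes hard to provide, even when every callee offers it. In the code below, <code>foo</code> calls <code>f1</code> and <code>f2</code>. If <code>f1</code> completes normally but <code>f2</code> throws and rolls back its own changes, <code>foo</code> would have to track what <code>f1</code> changed and undo it, which is clearly quite difficult.</p>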
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">foo</span><span class="p">(){</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl">	<span class="n">f1</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">	<span class="n">f2</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
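</div><p>The strong guarantee can cost a lot in resources and performance and does not suit every situation. When it is impractical, fall back to the basic guarantee.</p>
<p>The basic guarantee is not easy either. Consider again the example that calls two functions: if <code>f1</code> is not exception-safe, it may leak resources when it throws, and the caller <code>foo</code> has no way to find and release them. So a function that calls an exception-unsafe function cannot itself offer any exception-safety guarantee.</p>
<p>Likewise, a system is either exception-safe or it is not; there is no middle ground. A single exception-unsafe function is enough to make the whole system exception-unsafe.</p>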
<h3 id="item-30-understand-the-ins-and-outs-of-inlining">Item 30: Understand the ins and outs of inlining.</h3>
<blockquote>
<p>✦ Limit most inlining to small, frequently called functions. This facilitates debugging and binary upgradability, minimizes potential code bloat, and maximizes the chances of greater program speed.</p>
<p>✦ Don’t declare function templates inline just because they appear in header files.</p>
</blockquote>
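<p>Besides removing function-call overhead, inlining gives the compiler more room to optimize.</p>
<p>However, inlining also makes object code larger (every call site of an inline function is expanded), which can increase paging and lower the cache hit rate.</p>
<p><code>inline</code> is a request to the compiler, not a command. There are two ways to make the request: implicitly, by defining a member or friend function inside the class definition; or explicitly, by putting the <code>inline</code> keyword on the function definition.</p>
<p>The compiler has to expand inline calls in place at compile time, so inline functions must be defined in header files. The same holds for function templates, but that does not tie the two together: being a template does not imply being inline, and being inline does not imply being a template.</p>
<p>Library designers should weigh whether to declare an interface <code>inline</code>. If they do, any change to the function's implementation forces every caller to recompile, whereas changing an ordinary function only requires relinking.</p>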
<h3 id="item-31-minimize-compilation-dependencies-between-files">Item 31: Minimize compilation dependencies between files.</h3>
<blockquote>
<p>✦ The general idea behind minimizing compilation dependencies is to depend on declarations instead of definitions. Two approaches based on this idea are Handle classes and Interface classes.</p>
<p>✦ Library header files should exist in full and declaration-only forms. This applies regardless of whether templates are involved.</p>
</blockquote>
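<p>When the implementation of a class changes, every file that depends on it directly or indirectly is recompiled. The reason is that C++ does not separate interface from implementation very well.</p>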
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&#34;date.h&#34;</span><span class="cp">
</span></span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Person</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">const</span> <span class="n">Date</span><span class="o">&amp;</span> <span class="n">get_birthdate</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span> <span class="c1">// interface
</span></span></span><span class="line"><span class="cl"><span class="k">private</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="n">Date</span> <span class="n">birthdate_</span><span class="p">;</span> <span class="c1">// implementation detail
</span></span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
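</div><p>For example, <code>Person</code> has the interface <code>get_birthdate</code>, and its private data member <code>Date birthdate_</code> is an implementation detail. To compile <code>Person</code>, the compiler must see the definition of <code>Date</code>: it has to reserve enough space inside <code>Person</code> for the <code>Date</code> member, and without the definition it cannot know the size.</p>
<p><strong>Solution 1: handle classes</strong></p>
<p>Java does not have this problem. When a class is defined in Java, object-typed members are stored as references (pointers under the hood) rather than embedded in full.</p>
<p>C++ can imitate this with the “pimpl idiom” (pointer to implementation): split the <code>Person</code> defined in the header into two classes, the interface <code>Person</code> and the implementation <code>PersonImpl</code>. The former declares only the public interface plus a pointer to the implementation class; the latter defines the actual data members and implements the interface. That is:</p>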
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="c1">// person.h
</span></span></span><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&lt;memory&gt;</span><span class="cp">
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Date</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">PersonImpl</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Person</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">const</span> <span class="n">Date</span><span class="o">&amp;</span> <span class="n">get_birthdate</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span> <span class="c1">// interface
</span></span></span><span class="line"><span class="cl"><span class="k">private</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="n">std</span><span class="o">::</span><span class="n">shared_ptr</span><span class="o">&lt;</span><span class="n">PersonImpl</span><span class="o">&gt;</span> <span class="n">pImpl_</span><span class="p">;</span> 
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
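</div><p>Note that this relies on forward declarations.</p>
<p>The core idea of the pimpl idiom is to replace dependencies on definitions with dependencies on declarations. Three guidelines follow from it:</p>
<ul>
<li>Avoid using objects when pointers or references will do. Defining an object of a type requires the type's definition; pointers and references need only its declaration.</li>
<li>Depend on declarations rather than definitions whenever you can. Even a function declaration whose parameter or return type is the class itself, by value rather than by pointer or reference, needs only the class declaration.</li>
<li>Provide two header files for a class, one with declarations and one with definitions. Callers should include the declaration header instead of forward-declaring the class themselves.</li>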
</ul>
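<p>Put together, a pimpl-style <code>Person</code> might look like the single-file sketch below. The split into <code>person.h</code> and <code>person.cpp</code> is indicated only by comments, and the <code>Date</code> fields and the <code>Person</code> constructor are assumptions made for the example.</p>

```cpp
#include <memory>

// ---- person.h (what clients include): declarations only ----
struct Date;
class PersonImpl;

class Person {
public:
    Person(int year, int month, int day);
    const Date& get_birthdate() const;
private:
    std::shared_ptr<PersonImpl> pImpl_;   // a pointer's size is known without PersonImpl's definition
};

// ---- person.cpp: definitions that can change without touching person.h ----
struct Date { int year; int month; int day; };

class PersonImpl {
public:
    explicit PersonImpl(const Date& d) : birthdate_(d) {}
    const Date& get_birthdate() const { return birthdate_; }
private:
    Date birthdate_;
};

Person::Person(int year, int month, int day)
    : pImpl_(std::make_shared<PersonImpl>(Date{year, month, day})) {}

// The handle class simply forwards to the implementation class.
const Date& Person::get_birthdate() const { return pImpl_->get_birthdate(); }
```

<p>Clients that only include the header part never see <code>PersonImpl</code> or the layout of <code>Date</code>, so changes to either do not force them to recompile.</p>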
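<p><strong>Solution 2: interface classes</strong><br>
Besides the pimpl idiom, another approach is to make <code>Person</code> a special kind of abstract base class, an interface, whose job is to specify the public functions that derived classes must implement. An interface typically has no data members and no constructors, just a virtual destructor and a set of pure virtual functions.</p>
<p>Interfaces in C++ are less restricted than in Java: a C++ interface class is allowed to have data members.</p>
<p>The <code>Person</code> interface can be declared as:</p>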
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Date</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Person</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="o">~</span><span class="n">Person</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="k">const</span> <span class="n">Date</span><span class="o">&amp;</span> <span class="n">get_birthdate</span><span class="p">()</span> <span class="k">const</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
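</div><p>Note that clients can only use <code>Person</code> through references or pointers. With <code>Person</code> written this way, callers need not recompile when the implementation changes, only when the interface itself changes.</p>
<p>Clients of the interface need a way to create objects. A common approach is a static factory function that creates an object and returns a smart pointer to it. The factory can return objects of different derived classes depending on its arguments.</p>
<p>Note that because the factory is a static function and does not depend on any data member or non-static member function, the class that contains it is still an abstract class/interface.</p>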
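<p>An interface class with a static factory might be sketched as follows. The <code>name()</code> accessor and the concrete class <code>RealPerson</code> are hypothetical, chosen to keep the example short.</p>

```cpp
#include <memory>
#include <string>

class Person {
public:
    virtual ~Person() = default;
    virtual std::string name() const = 0;
    // Static factory: callers never name the concrete class.
    static std::unique_ptr<Person> create(const std::string& name);
};

// Would normally live in a .cpp file, invisible to clients.
class RealPerson : public Person {
public:
    explicit RealPerson(const std::string& name) : name_(name) {}
    std::string name() const override { return name_; }
private:
    std::string name_;
};

std::unique_ptr<Person> Person::create(const std::string& name) {
    return std::make_unique<RealPerson>(name);   // requires C++14
}
```

<p>A client writes <code>auto p = Person::create("Ada");</code> and works only with the <code>Person</code> interface; replacing <code>RealPerson</code> later does not affect client code.</p>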
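<p>Of course, both approaches reduce dependencies between header files at a price: larger objects and slightly slower execution.</p>
<p>With a handle class, every access goes through one extra pointer indirection. With an interface class, every function is virtual, so every call through the interface pays for a virtual call.</p>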
<h2 id="inheritance-and-object-oriented-design">Inheritance and Object-Oriented Design</h2>
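<p>This chapter covers the object-oriented side of C++: inheritance, derivation, and virtual functions. OOP in C++ follows the general ideas of OOP but differs from other languages in places. Only with a correct understanding of OOP in C++ can you turn what you mean into what the code actually does.</p>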
<h3 id="item-32-make-sure-public-inheritance-models-is-a">Item 32: Make sure public inheritance models “is-a.”</h3>
<blockquote>
<p>✦ Public inheritance means “is-a.” Everything that applies to base classes must also apply to derived classes, because every derived class object is a base class object.</p>
</blockquote>
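<p><strong>Public inheritance means “is-a”</strong>: every object of type <code>D</code> is also an object of type <code>B</code>. This is the most basic idea in OOP and must be remembered.</p>
<p>In C++, a derived-class object can be passed wherever a base-class object is expected, but only under public inheritance.</p>
<p>Intuition and imprecise language easily mislead about is-a. Everyone knows a penguin is a bird and that birds can fly, which makes it tempting to write the following:</p>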
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Bird</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="n">fly</span><span class="p">();</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Penguin</span><span class="o">:</span> <span class="k">public</span> <span class="n">Bird</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
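</div><p>In fact, penguins cannot fly. The root of the problem is that not all birds can fly; the everyday statement was imprecise. A better design derives a <code>FlyingBird</code> class and declares the virtual function <code>fly</code> there. As always, it depends on the requirements: if <code>fly</code> is never needed, there is no need to introduce the abstract class <code>FlyingBird</code>.</p>
<p>is-a also differs from the special-case/general-case relation in mathematics. Mathematically a square is a special rectangle, but it should not be modeled that way with public inheritance in C++. Public is-a means the derived class satisfies every property of the base class, and a square must keep its width and height equal, so some rectangle operations do not apply to a square.</p>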
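<p>The rectangle/square problem can be made concrete with a small sketch. The class layout and the helper <code>widen_keeps_height</code> are illustrative assumptions, not code from the book.</p>

```cpp
class Rectangle {
public:
    Rectangle(int w, int h) : w_(w), h_(h) {}
    virtual ~Rectangle() = default;
    virtual void set_width(int w) { w_ = w; }
    virtual void set_height(int h) { h_ = h; }
    int width() const { return w_; }
    int height() const { return h_; }
protected:
    int w_, h_;
};

class Square : public Rectangle {
public:
    explicit Square(int side) : Rectangle(side, side) {}
    void set_width(int w) override { w_ = h_ = w; }    // must keep both sides equal
    void set_height(int h) override { w_ = h_ = h; }
};

// A property every Rectangle is expected to have:
// widening it must not change its height.
bool widen_keeps_height(Rectangle& r) {
    const int old_height = r.height();
    r.set_width(r.width() + 10);
    return r.height() == old_height;
}
```

<p><code>widen_keeps_height</code> holds for a plain <code>Rectangle</code> but fails for a <code>Square</code>, so a <code>Square</code> cannot stand in wherever a <code>Rectangle</code> is expected.</p>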
<h3 id="item-33-avoid-hiding-inherited-names">Item 33: Avoid hiding inherited names.</h3>
<blockquote>
<p>✦ Names in derived classes hide names in base classes. Under public inheritance, this is never desirable.</p>
<p>✦ To make hidden names visible again, employ using declarations or forwarding functions.</p>
</blockquote>
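<p>Name hiding exists in class inheritance too: a member of a derived class hides base-class members of the same name. For data members this matches intuition; for member functions it does not.</p>
<p>First, hiding of member functions is decided by name alone. A member function in a derived class hides not only the base-class function with the same signature but also every base-class overload of that name. For example:</p>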
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="n">fun1</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="nf">fun1</span><span class="p">(</span><span class="kt">int</span> <span class="n">x</span><span class="p">);</span> <span class="c1">// overload
</span></span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Derived</span><span class="o">:</span> <span class="k">public</span> <span class="n">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="n">fun1</span><span class="p">();</span> <span class="c1">// hides every base-class member function named fun1
</span></span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="cm">/**********************************/</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">int</span> <span class="n">x</span><span class="o">=</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="n">Derived</span> <span class="n">d</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="n">d</span><span class="p">.</span><span class="n">fun1</span><span class="p">(</span><span class="n">x</span><span class="p">);</span> <span class="c1">// ill-formed: Base::fun1(int) is hidden
</span></span></span></code></pre></td></tr></table>
</div>
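</div><p>This default behavior is unintuitive and also conflicts with public inheritance meaning is-a. To keep the overloads visible in the derived class, add a using declaration to the derived class:</p>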
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="n">fun1</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="nf">fun1</span><span class="p">(</span><span class="kt">int</span> <span class="n">x</span><span class="p">);</span> <span class="c1">// overload
</span></span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Derived</span><span class="o">:</span> <span class="k">public</span> <span class="n">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">using</span> <span class="n">Base</span><span class="o">::</span><span class="n">fun1</span><span class="p">;</span> <span class="c1">// makes every base-class member named fun1 visible in Derived
</span></span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="nf">fun1</span><span class="p">();</span> 
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="cm">/**********************************/</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">int</span> <span class="n">x</span><span class="o">=</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="n">Derived</span> <span class="n">d</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="n">d</span><span class="p">.</span><span class="n">fun1</span><span class="p">(</span><span class="n">x</span><span class="p">);</span> <span class="c1">// OK: calls Base::fun1(int)
</span></span></span></code></pre></td></tr></table>
</div>
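</div><p>Inheriting only some of the base class's overloads violates is-a under <strong>public inheritance</strong>, but under <strong>private inheritance</strong> it is a reasonable request. If a privately derived class in the code above wants only the no-argument version of <code>fun1</code>, it can use a forwarding function:</p>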
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="n">fun1</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="nf">fun1</span><span class="p">(</span><span class="kt">int</span> <span class="n">x</span><span class="p">);</span> <span class="c1">// overload
</span></span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Derived</span><span class="o">:</span> <span class="k">private</span> <span class="n">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="n">fun1</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">	<span class="p">{</span><span class="n">Base</span><span class="o">::</span><span class="n">fun1</span><span class="p">();}</span> <span class="c1">// forwarding function
</span></span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="cm">/**********************************/</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">int</span> <span class="n">x</span><span class="o">=</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="n">Derived</span> <span class="n">d</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="n">d</span><span class="p">.</span><span class="n">fun1</span><span class="p">(</span><span class="n">x</span><span class="p">);</span> <span class="c1">// ill-formed: Base::fun1(int) is hidden
</span></span></span><span class="line"><span class="cl"><span class="n">d</span><span class="p">.</span><span class="n">fun1</span><span class="p">();</span> <span class="c1">// OK: calls Derived::fun1()
</span></span></span></code></pre></td></tr></table>
</div>
</div><h3 id="item-34-differentiate-between-inheritance-of-interface-and-inheritance-of-implementation">Item 34: Differentiate between inheritance of interface and inheritance of implementation.</h3>
<blockquote>
<p>✦ Inheritance of interface is different from inheritance of implementation. Under public inheritance, derived classes always inherit base class interfaces.</p>
<p>✦ Pure virtual functions specify inheritance of interface only.</p>
<p>✦ Simple (impure) virtual functions specify inheritance of interface plus inheritance of a default implementation.</p>
<p>✦ Non-virtual functions specify inheritance of interface plus inheritance of a mandatory implementation.</p>
</blockquote>
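<p>With inheritance in C++, distinguish between inheriting an interface and inheriting an implementation. The former means inheriting only the function's declaration, not a base-class implementation (which usually does not exist); the latter means inheriting both the declaration and the implementation, and then it also matters whether the function may be overridden.</p>
<p>If only the interface should be inherited, declare it as a pure virtual function in the base class (an interface really ought to be pure virtual). A little-known fact: a pure virtual function can still be given a definition in the base class, but it must be called with explicit qualification, e.g. <code>Base::fun()</code>.</p>
<p>If an implementation should be inherited and derived classes may override it, declare the function virtual in the base class. In practice, several derived classes often share the same implementation of a function, so it becomes the default in the base class. That plants a trap for the future: a later derived class for which the default is wrong may forget to override the function, and the mistake is not caught at compile time. A fix is:</p>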
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="n">fun</span><span class="p">()</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span><span class="c1">// now pure virtual
</span></span></span><span class="line"><span class="cl"><span class="k">protected</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="n">default_fun</span><span class="p">();</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Derived</span><span class="o">:</span> <span class="k">public</span> <span class="n">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="n">fun</span><span class="p">()</span> <span class="c1">// forwards to the default implementation
</span></span></span><span class="line"><span class="cl">	<span class="p">{</span><span class="n">default_fun</span><span class="p">();}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
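</div><p>The fix is to make the original function pure virtual and provide the default implementation under a different name. A derived class that wants the default overrides the function and forwards to the base-class default.</p>
<p>Some people dislike splitting declaration and implementation across two functions and instead give the pure virtual function a definition in the base class. That works, but it offers coarser control over the access level of the default implementation than the approach above.</p>
<p>If an implementation should be inherited and derived classes must not override it, declare the function non-virtual. (The <code>final</code> specifier applies only to virtual functions, where it forbids further overriding.)</p>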
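<p>The alternative just described, a pure virtual function that still carries a definition, can be sketched like this. The class names <code>UsesDefault</code> and <code>Custom</code> and the string return values are invented for the example.</p>

```cpp
#include <string>

class Base {
public:
    virtual ~Base() = default;
    virtual std::string fun() const = 0;   // pure: derived classes must override it
};

// A pure virtual function may still have a definition, given outside the class.
std::string Base::fun() const { return "default"; }

class UsesDefault : public Base {
public:
    // Opts in to the default explicitly with a qualified call.
    std::string fun() const override { return Base::fun(); }
};

class Custom : public Base {
public:
    std::string fun() const override { return "custom"; }
};
```

<p>A derived class that forgets to override <code>fun</code> stays abstract and cannot be instantiated, so the mistake surfaces at compile time. The default, however, necessarily has the same access level as <code>fun</code> itself.</p>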
<h3 id="item-35-consider-alternatives-to-virtual-functions">Item 35: Consider alternatives to virtual functions.</h3>
<blockquote>
<p>✦ Alternatives to virtual functions include the NVI idiom and various forms of the Strategy design pattern. The NVI idiom is itself an example of the Template Method design pattern.</p>
<p>✦ A disadvantage of moving functionality from a member function to a function outside the class is that the non-member function lacks access to the class’s non-public members.</p>
<p>✦ tr1::function objects act like generalized function pointers. Such objects support all callable entities compatible with a given target signature.</p>
</blockquote>
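<p>Virtual functions are used all the time, but there are several alternatives:</p>
<ul>
<li>The Template Method pattern via the non-virtual interface (NVI) idiom<br>
The non-virtual interface comes from the idea that virtual functions should be private. Template Method is a design pattern in which the base class defines the skeleton of an algorithm and lets subclasses redefine certain steps without changing its structure. Concretely, the base class exposes a non-virtual interface whose implementation calls a few private virtual functions, and derived classes change behavior by overriding those private virtual functions.</li>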
</ul>
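<p>The benefit is that the public interface can do work before and after calling the private functions, such as setting up context, logging, checking return values, or acquiring and releasing locks. The pattern is an instance of inversion of control: the high-level abstract class fixes the order of the steps, and the low-level derived classes supply the implementation of each step.</p>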
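<p>A minimal NVI sketch, assuming a hypothetical <code>Task</code> class whose <code>run()</code> records each step in a log so the ordering is visible:</p>

```cpp
#include <string>
#include <vector>

class Task {
public:
    virtual ~Task() = default;
    // Non-virtual public interface: fixes the order of the steps.
    std::vector<std::string> run() {
        std::vector<std::string> log;
        log.push_back("setup");       // runs before the customizable step
        log.push_back(do_run());      // the step derived classes customize
        log.push_back("teardown");    // runs after it
        return log;
    }
private:
    // Private virtual: derived classes may override it but cannot call it directly.
    virtual std::string do_run() { return "default work"; }
};

class CopyTask : public Task {
private:
    std::string do_run() override { return "copy files"; }
};
```

<p>Calling <code>run()</code> on a <code>CopyTask</code> yields the log <code>setup</code>, <code>copy files</code>, <code>teardown</code>: the derived class changes one step, and the base class keeps control of the sequence.</p>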
<ul>
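<li>The Strategy pattern via function pointers<br>
The Template Method approach above still uses virtual functions (private ones). A more flexible approach has the derived class pass a function pointer when constructing the base class, and the base class calls that function to implement the relevant method.</li>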
</ul>
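<p>The flexibility shows in that even different instances of the same derived class can use different functions.</p>
<p>The drawback is that, as a non-member, the function cannot access the class's non-public members. The remedy is to weaken the class's encapsulation, for example by declaring the function a friend or by providing public accessors.</p>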
<ul>
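<li>The Strategy pattern via <code>std::function</code><br>
Function pointers are not very flexible: the signature must match exactly, and only ordinary functions are supported. Replacing the function pointer with <code>std::function</code> accepts any callable object (function objects, lambdas, member functions, and so on) and allows implicit conversions of parameter and return types.</li>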
</ul>
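<p>The <code>std::function</code> variant of the Strategy pattern might look like the sketch below. <code>Character</code>, <code>flat_health</code>, and <code>Scaled</code> are hypothetical names loosely modeled on the book's game-character example.</p>

```cpp
#include <functional>
#include <utility>

class Character {
public:
    using HealthFn = std::function<int(const Character&)>;
    Character(int level, HealthFn fn) : level_(level), health_fn_(std::move(fn)) {}
    int health() const { return health_fn_(*this); }
    int level() const { return level_; }   // public accessor, since the strategy is not a member
private:
    int level_;
    HealthFn health_fn_;
};

// A plain function whose return type (short) is merely convertible to int.
short flat_health(const Character&) { return 100; }

// A function object that carries its own state.
struct Scaled {
    int factor;
    int operator()(const Character& c) const { return factor * c.level(); }
};
```

<p>Three objects of the same class can then use three different strategies: <code>Character a(3, flat_health);</code>, <code>Character b(3, Scaled{10});</code>, or a lambda such as <code>[](const Character&amp; c) { return c.level() + 1; }</code>.</p>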
<h3 id="item-36-never-redefine-an-inherited-non-virtual-function">Item 36: Never redefine an inherited non-virtual function.</h3>
<blockquote>
<p>✦ Never redefine an inherited non-virtual function.</p>
</blockquote>
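<p>Non-virtual functions are statically bound. If a derived class redefines a non-virtual function, the version that runs depends on the static type of the pointer or reference, not on the object: the same derived object calls the base version through a base-class pointer and the derived version through a derived-class pointer. This goes against object-oriented design:</p>
<ul>
<li>As noted earlier, a non-virtual function specifies a mandatory implementation for the class, one that derived classes should not change. If it needs to change, it should be virtual.</li>
<li>As noted earlier, public inheritance means is-a. If a derived class has to redefine the function, then a derived object is not a base object, which contradicts is-a.</li>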
</ul>
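<p>The difference between static and dynamic binding can be seen in a short sketch. The member names <code>who</code> and <code>vwho</code> are made up for illustration.</p>

```cpp
#include <string>

struct Base {
    virtual ~Base() = default;
    std::string who() const { return "Base"; }            // non-virtual
    virtual std::string vwho() const { return "Base"; }   // virtual
};

struct Derived : Base {
    std::string who() const { return "Derived"; }             // redefines (hides) Base::who
    std::string vwho() const override { return "Derived"; }   // overrides Base::vwho
};
```

<p>Given <code>Derived d;</code>, a <code>Base*</code> pointing at <code>d</code> calls <code>Base::who</code>, while a <code>Derived*</code> pointing at the same object calls <code>Derived::who</code>. The virtual <code>vwho</code> resolves to <code>Derived::vwho</code> through either pointer.</p>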
<h3 id="item-37-never-redefine-a-functions-inherited-default-parameter-value">Item 37: Never redefine a function’s inherited default parameter value.</h3>
<blockquote>
<p>✦ Never redefine an inherited default parameter value, because default parameter values are statically bound, while virtual functions — the only functions you should be redefining — are dynamically bound.</p>
</blockquote>
<p>Not redefining an inherited default parameter value sounds like an odd rule at first. It stems from a choice C++ makes to keep virtual functions efficient:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="n">show</span><span class="p">(</span><span class="n">string</span> <span class="n">str</span><span class="o">=</span><span class="s">&#34;Base&#34;</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="n">cout</span> <span class="o">&lt;&lt;</span> <span class="s">&#34;call Base::show &#34;</span><span class="o">&lt;&lt;</span> <span class="n">str</span> <span class="o">&lt;&lt;</span> <span class="n">endl</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Derived</span><span class="o">:</span> <span class="k">public</span> <span class="n">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="n">show</span><span class="p">(</span><span class="n">string</span> <span class="n">str</span><span class="o">=</span><span class="s">&#34;Derived&#34;</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="n">cout</span> <span class="o">&lt;&lt;</span> <span class="s">&#34;call Derived::show &#34;</span><span class="o">&lt;&lt;</span> <span class="n">str</span> <span class="o">&lt;&lt;</span> <span class="n">endl</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl"><span class="cm">/*********/</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">Derived</span> <span class="n">d</span> <span class="o">=</span> <span class="n">Derived</span><span class="p">();</span>
</span></span><span class="line"><span class="cl"><span class="n">Base</span> <span class="o">&amp;</span><span class="n">pd</span> <span class="o">=</span> <span class="n">d</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="n">pd</span><span class="p">.</span><span class="n">show</span><span class="p">();</span> <span class="c1">// output: call Derived::show Base
</span></span></span></code></pre></td></tr></table>
</div>
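</div><p>Virtual functions are dynamically bound, but their default parameter values are statically bound. In the code above, the call does reach the overriding <code>show</code> in the derived class, yet the default argument passed in comes from <code>Base</code>, the static type of <code>pd</code>. This design means default values never have to be determined at run time, which keeps the virtual function mechanism lean, but it also makes the behavior counterintuitive.</p>
<p>Copying the base class's parameter list, defaults included, into the overriding function is not a good fix either: if the default value changes in the base class later, every copy has to change with it. One alternative is the non-virtual interface (NVI) idiom mentioned earlier:</p>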
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="n">show</span><span class="p">(</span><span class="n">string</span> <span class="n">str</span><span class="o">=</span><span class="s">&#34;Base&#34;</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="n">do_show</span><span class="p">(</span><span class="n">str</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="k">private</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="n">do_show</span><span class="p">(</span><span class="n">string</span> <span class="n">str</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="n">cout</span> <span class="o">&lt;&lt;</span> <span class="s">&#34;call Base::do_show &#34;</span> <span class="o">&lt;&lt;</span> <span class="n">str</span> <span class="o">&lt;&lt;</span> <span class="n">endl</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Derived</span><span class="o">:</span> <span class="k">public</span> <span class="n">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">private</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="n">do_show</span><span class="p">(</span><span class="n">string</span> <span class="n">str</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="n">cout</span> <span class="o">&lt;&lt;</span> <span class="s">&#34;call Derived::do_show &#34;</span> <span class="o">&lt;&lt;</span> <span class="n">str</span> <span class="o">&lt;&lt;</span> <span class="n">endl</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl"><span class="cm">/*********/</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">Derived</span> <span class="n">d</span> <span class="o">=</span> <span class="n">Derived</span><span class="p">();</span>
</span></span><span class="line"><span class="cl"><span class="n">Base</span> <span class="o">&amp;</span><span class="n">pd</span> <span class="o">=</span> <span class="n">d</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="n">pd</span><span class="p">.</span><span class="n">show</span><span class="p">();</span> <span class="c1">// output: call Derived::do_show Base
</span></span></span></code></pre></td></tr></table>
</div>
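</div><p>Because a non-virtual function should never be redefined or hidden in a derived class, the default argument of <code>show</code> can only ever be <code>&#34;Base&#34;</code>.</p>
<h3 id="item-38-model-has-a-or-is-implemented-in-terms--of-through-composition">Item 38: Model “has-a” or “is-implemented-in-terms-of” through composition.</h3>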
<blockquote>
<p>✦ Composition has meanings completely different from that of public inheritance.</p>
<p>✦ In the application domain, composition means has-a. In the implementation domain, it means is-implemented-in-terms-of.</p>
</blockquote>
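<p>Composition is the relationship in which an object of one type contains objects of other types. Just as public inheritance means is-a, composition means either has-a or is-implemented-in-terms-of.</p>
<p>The two meanings correspond to two different domains: has-a is used in the application domain, when modeling things in the real world, while is-implemented-in-terms-of is used in the pure implementation domain, for artifacts such as locks and binary trees.</p>
<p>Telling has-a from is-a is easy, but telling is-implemented-in-terms-of from is-a takes more care. For example, when a set is implemented with a linked list, which relationship is it? If D is-a B, everything that is true of B must also be true of D. A linked list may contain duplicate values while a set may not, so the relationship is not is-a; the set is implemented in terms of the list.</p>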
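<p>The set-over-list case can be sketched as follows. This is a minimal illustration, not a production container: the class name <code>Set</code>, its member functions, and the member <code>rep_</code> are all hypothetical names chosen for the example.</p>

```cpp
#include <algorithm>
#include <cstddef>
#include <list>

// Hypothetical Set that is implemented in terms of std::list via composition.
template <typename T>
class Set {
public:
    bool contains(const T& item) const {
        return std::find(rep_.begin(), rep_.end(), item) != rep_.end();
    }
    void insert(const T& item) {
        // The list itself would accept duplicates; the Set interface forbids them.
        if (!contains(item)) rep_.push_back(item);
    }
    void remove(const T& item) { rep_.remove(item); }
    std::size_t size() const { return rep_.size(); }

private:
    std::list<T> rep_;  // has-a list: an implementation detail, not a base class
};
```

<p>Because the list is a private data member, none of its interface leaks out, so a caller can never insert a duplicate through list operations.</p>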
<h3 id="item-39-use-private-inheritance-judiciously">Item 39: Use private inheritance judiciously.</h3>
<blockquote>
<p>✦ Private inheritance means is-implemented-in-terms-of. It’s usually inferior to composition, but it makes sense when a derived class needs access to protected base class members or needs to redefine inherited virtual functions.</p>
<p>✦ Unlike composition, private inheritance can enable the empty base optimization. This can be important for library developers who strive to minimize object sizes.</p>
</blockquote>
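<p>Private inheritance has two consequences:</p>
<ul>
<li>the compiler will not implicitly convert a derived class object into a base class object;</li>
<li>members inherited from the base class become private in the derived class.</li>
</ul>
<p>These two properties mean that private inheritance expresses is-implemented-in-terms-of, the same as one of the meanings of composition. Prefer composition in general, and use private inheritance only when there is no way around it.</p>
<p>When is there no way around it? For example, when the derived class needs access to protected members of the base class, or needs to redefine its virtual functions.</p>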
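<p>A minimal sketch of the "needs to redefine a virtual function" case. The classes <code>Timer</code> and <code>Widget</code> and all of their members are hypothetical names used only for illustration.</p>

```cpp
#include <type_traits>

// Hypothetical helper class that exposes a virtual hook.
class Timer {
public:
    virtual ~Timer() = default;
    void fire() { on_tick(); }  // invokes the hook
protected:
    virtual void on_tick() {}
};

// Widget is implemented in terms of Timer; it is not a Timer.
class Widget : private Timer {
public:
    void simulate_tick() { fire(); }  // inherited members are usable inside Widget
    int ticks() const { return ticks_; }
private:
    void on_tick() override { ++ticks_; }  // the reason inheritance is needed
    int ticks_ = 0;
};

// Outside code cannot convert Widget* to Timer*: the base class is inaccessible.
static_assert(!std::is_convertible<Widget*, Timer*>::value,
              "private inheritance blocks the derived-to-base conversion");
```

<p>Composition alone would not work here, since a <code>Timer</code> data member gives <code>Widget</code> no way to override <code>on_tick</code>.</p>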
<h3 id="item-40-use-multiple-inheritance-judiciously">Item 40: Use multiple inheritance judiciously.</h3>
<blockquote>
<p>✦ Multiple inheritance is more complex than single inheritance. It can lead to new ambiguity issues and to the need for virtual inheritance.</p>
<p>✦ Virtual inheritance imposes costs in size, speed, and complexity of initialization and assignment. It’s most practical when virtual base classes have no data.</p>
<p>✦ Multiple inheritance does have legitimate uses. One scenario involves combining public inheritance from an Interface class with private inheritance from a class that helps with implementation.</p>
</blockquote>
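<p>Multiple inheritance makes it easy to inherit the same name from more than one base class, which is ambiguous. When resolving a call to an overloaded function, C++ first finds the best match and only afterwards checks its accessibility. As a result, the call stays ambiguous even if one of the two same-named functions is private.</p>
<p>To resolve the ambiguity, the call must state explicitly which base class's function is meant.</p>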
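<p>A small sketch of this ambiguity and of the explicit qualification that resolves it. The class names and the helper function <code>borrow</code> are hypothetical and exist only for the example.</p>

```cpp
#include <string>

class BorrowableItem {
public:
    std::string check_out() const { return "BorrowableItem::check_out"; }
};

class ElectronicGadget {
private:
    // Private, yet still found by name lookup, so it still causes ambiguity.
    std::string check_out() const { return "ElectronicGadget::check_out"; }
};

class Mp3Player : public BorrowableItem, public ElectronicGadget {};

std::string borrow(const Mp3Player& mp) {
    // mp.check_out();  // error: ambiguous; access is checked only after lookup
    return mp.BorrowableItem::check_out();  // explicit qualification resolves it
}
```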
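<p>Under multiple inheritance, the same base class may be inherited more than once along different paths. The most derived class can then hold a separate copy of that base class's data for each path, or share a single copy (virtual inheritance). A base class that is inherited virtually is called a virtual base class.</p>
<p>From the standpoint of correctness alone, every public inheritance could be made virtual. But virtual inheritance has costs: the compiler has to maintain extra information for virtual base classes, and at initialization the author of the most derived class must know which virtual base classes exist and initialize them explicitly.</p>
<p>So avoid virtual base classes unless they are really needed, and when they are needed, keep their data members to a minimum.</p>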
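<p>The initialization burden can be seen in a small diamond hierarchy. This is only a sketch; <code>File</code>, <code>InputFile</code>, <code>OutputFile</code>, <code>IOFile</code> and the <code>id</code> member are hypothetical names.</p>

```cpp
class File {
public:
    explicit File(int id) : id_(id) {}
    int id() const { return id_; }
private:
    int id_;
};

class InputFile : virtual public File {
public:
    InputFile() : File(1) {}  // used only when InputFile is the most derived class
};

class OutputFile : virtual public File {
public:
    OutputFile() : File(2) {}  // likewise ignored when constructing an IOFile
};

class IOFile : public InputFile, public OutputFile {
public:
    // The most derived class is responsible for initializing the virtual base.
    IOFile() : File(42) {}
};
```

<p>An <code>IOFile</code> holds a single shared <code>File</code> subobject, so <code>id()</code> is unambiguous, and its value comes from the <code>IOFile</code> constructor rather than from either intermediate class.</p>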
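<p>One legitimate use of multiple inheritance is to publicly inherit from an interface class while privately inheriting from a class that helps implement that interface. Private inheritance is needed here because the helper's virtual functions have to be redefined; otherwise composition would do.</p>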
<h2 id="templates-and-generic-programming">Templates and Generic Programming</h2>
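<p>Templates entered the programmer's world through containers. People later found that templates could do much more, and a new paradigm, generic programming, emerged. C++ templates were then shown to be Turing-complete, which gave rise to template metaprogramming: programs that run at compile time.</p>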
<h3 id="item-41-understand-implicit-interfaces-and-compile-time-polymorphism">Item 41: Understand implicit interfaces and compile-time polymorphism.</h3>
<blockquote>
<p>✦ Both classes and templates support interfaces and polymorphism.</p>
<p>✦ For classes, interfaces are explicit and centered on function signatures. Polymorphism occurs at runtime through virtual functions.</p>
<p>✦ For template parameters, interfaces are implicit and based on valid expressions. Polymorphism occurs during compilation through template instantiation and function overloading resolution.</p>
</blockquote>
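<p>In object-oriented programming, explicit interfaces and runtime polymorphism play the central role. Both still exist in generic programming, but implicit interfaces and compile-time polymorphism matter more.</p>
<ul>
<li>In generic programming, the implicit interface consists of every operation performed on, and every method called on, an object of type <code>T</code>.</li>
<li>Compile-time polymorphism means that instantiating the template with different types for <code>T</code> results in different functions being called.</li>
</ul>
<p>The expressions in a template determine the interface that type <code>T</code> must support. More precisely, <code>T</code> must support whatever interface makes those expressions valid. For example:</p>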
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">do_something</span><span class="p">(</span><span class="n">T</span> <span class="o">&amp;</span><span class="n">w</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span><span class="p">(</span><span class="n">w</span><span class="p">.</span><span class="n">size</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">10</span> <span class="o">&amp;&amp;</span> <span class="n">w</span><span class="o">!=</span><span class="n">xxx</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
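</div><p>This code does not require <code>size</code> to return an integral type. It only has to return an object of some type for which <code>operator&gt;</code> can be called with an argument of 10. Likewise, <code>T</code> need not overload <code>!=</code> itself: it is enough that <code>w</code> can be converted to some type <code>X</code>, <code>xxx</code> can be converted to some type <code>Y</code>, and <code>operator!=</code> is defined for <code>X</code> and <code>Y</code>.</p>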
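<p>To make the point about <code>size</code> concrete, here is a small sketch in which <code>size()</code> returns a class type rather than an integer. The names <code>is_large</code>, <code>Length</code> and <code>Rope</code> are hypothetical and do not come from the code above.</p>

```cpp
#include <vector>

// Hypothetical template: its implicit interface is "w.size() > 10 must compile".
template <typename T>
bool is_large(const T& w) {
    return w.size() > 10;
}

// Not an integer, but it provides operator>(int), which is all the template needs.
struct Length {
    int value;
    bool operator>(int rhs) const { return value > rhs; }
};

struct Rope {
    Length size() const { return Length{42}; }
};
```

<p>Both <code>is_large(Rope{})</code> and <code>is_large(std::vector&lt;int&gt;(3))</code> compile, even though the two <code>size</code> functions return unrelated types; each call selects its own instantiation at compile time.</p>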
<h3 id="item-42-understand-the-two-meanings-of-typename">Item 42: Understand the two meanings of typename.</h3>
<blockquote>
<p>✦ When declaring template parameters, class and typename are interchangeable.</p>
<p>✦ Use typename to identify nested dependent type names, except in base class lists or as a base class identifier in a member initialization list.</p>
</blockquote>
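<p>When declaring template parameters, <code>typename</code> and <code>class</code> are equivalent. Some programmers still distinguish them by convention, using <code>class</code> for parameters that accept only class types and <code>typename</code> for parameters that accept any type.</p>
<p>Inside a template the concrete type of a parameter is unknown, which easily leads to ambiguity:</p>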
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="hl"><span class="lnt">3
</span></span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">C</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">foo</span><span class="p">(</span><span class="k">const</span> <span class="n">C</span><span class="o">&amp;</span> <span class="n">container</span><span class="p">){</span>
</span></span><span class="line hl"><span class="cl">	<span class="n">C</span><span class="o">::</span><span class="n">const_iterator</span> <span class="o">*</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
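</div><p>The third line above is meant to declare <code>x</code> as a pointer to <code>C::const_iterator</code>. But if class <code>C</code> happened to have a static data member named <code>const_iterator</code>, and a global variable named <code>x</code> happened to exist, the line would instead multiply two expressions. The compiler has to allow for both readings, and by default it does not treat a nested dependent name such as <code>C::const_iterator</code> as a type. Prefixing the name with the <code>typename</code> keyword tells the compiler to treat it as a type name:</p>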
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">C</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">foo</span><span class="p">(</span><span class="k">const</span> <span class="n">C</span><span class="o">&amp;</span> <span class="n">container</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="k">typename</span> <span class="n">C</span><span class="o">::</span><span class="n">const_iterator</span> <span class="o">*</span><span class="n">x</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
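</div><p>Note: this <code>typename</code> cannot be replaced with <code>class</code>.</p>
<p>The rule that a nested dependent type name must be preceded by <code>typename</code> has one exception: <code>typename</code> must not be used in a list of base classes or as a base class identifier in a member initialization list:</p>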
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="c1">// Derived inherits from Nested, a class nested in Base&lt;T&gt;
</span></span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Derived</span><span class="o">:</span> <span class="k">public</span> <span class="n">Base</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;::</span><span class="n">Nested</span><span class="p">{</span> <span class="c1">// base class list: typename not allowed
</span></span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">explicit</span> <span class="n">Derived</span><span class="p">(</span><span class="kt">int</span> <span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="o">:</span><span class="n">Base</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;::</span><span class="n">Nested</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1">// member initialization list: typename not allowed
</span></span></span><span class="line"><span class="cl">	<span class="p">{</span> <span class="p">...</span> <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h3 id="item-43-know-how-to-access-names-in-templatized-base-classes">Item 43: Know how to access names in templatized base classes.</h3>
<blockquote>
<p>In derived class templates, refer to names in base class templates via a “this-&gt;” prefix, via using declarations, or via an explicit base class qualification.</p>
</blockquote>
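<p>When a class template derives from a base class template and calls a function inherited from it, the compiler refuses to resolve the unqualified name:</p>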
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="hl"><span class="lnt">28
</span></span><span class="lnt">29
</span><span class="lnt">30
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">BaseA</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="n">do_foo1</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="nf">do_foo2</span><span class="p">();</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1">// Base class template with two interfaces, foo1 and foo2
</span></span></span><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">BaseName</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="n">foo1</span><span class="p">(){</span>
</span></span><span class="line"><span class="cl">		<span class="n">BaseName</span> <span class="n">base</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="n">base</span><span class="p">.</span><span class="n">do_foo1</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="nf">foo2</span><span class="p">(){</span>
</span></span><span class="line"><span class="cl">		<span class="n">BaseName</span> <span class="n">base</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="n">base</span><span class="p">.</span><span class="n">do_foo2</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1">// Derived class template that calls foo2
</span></span></span><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">BaseName</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Derived</span><span class="o">:</span> <span class="k">public</span> <span class="n">Base</span><span class="o">&lt;</span><span class="n">BaseName</span><span class="o">&gt;</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="n">call_foo2</span><span class="p">(){</span>
</span></span><span class="line hl"><span class="cl">		<span class="n">foo2</span><span class="p">();</span> <span class="c1">// invalid
</span></span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
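</div><p>Line 28 cannot reach the inherited <code>foo2</code> because of template specialization: some specialization of the base class template might not provide <code>foo2</code> at all, so the compiler does not look for inherited names in a templatized base class and rejects the code.</p>
<p>There are three ways to fix this:</p>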
<ul>
<li>Prefix the call with <code>this-&gt;</code>:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="hl"><span class="lnt">5
</span></span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">BaseName</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Derived</span><span class="o">:</span> <span class="k">public</span> <span class="n">Base</span><span class="o">&lt;</span><span class="n">BaseName</span><span class="o">&gt;</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="n">call_foo2</span><span class="p">(){</span>
</span></span><span class="line hl"><span class="cl">		<span class="k">this</span><span class="o">-&gt;</span><span class="n">foo2</span><span class="p">();</span> <span class="c1">// valid
</span></span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>Bring the name into scope with a <code>using</code> declaration:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="hl"><span class="lnt">4
</span></span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">BaseName</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Derived</span><span class="o">:</span> <span class="k">public</span> <span class="n">Base</span><span class="o">&lt;</span><span class="n">BaseName</span><span class="o">&gt;</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line hl"><span class="cl">	<span class="k">using</span> <span class="n">Base</span><span class="o">&lt;</span><span class="n">BaseName</span><span class="o">&gt;::</span><span class="n">foo2</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="nf">call_foo2</span><span class="p">(){</span>
</span></span><span class="line"><span class="cl">		<span class="n">foo2</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
</div><ul>
<li>Qualify the call explicitly with the base class name:</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="hl"><span class="lnt">5
</span></span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">BaseName</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Derived</span><span class="o">:</span> <span class="k">public</span> <span class="n">Base</span><span class="o">&lt;</span><span class="n">BaseName</span><span class="o">&gt;</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="n">call_foo2</span><span class="p">(){</span>
</span></span><span class="line hl"><span class="cl">		<span class="n">Base</span><span class="o">&lt;</span><span class="n">BaseName</span><span class="o">&gt;::</span><span class="n">foo2</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
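</div><p>The third approach is the least desirable: if the function being called is virtual, explicit qualification turns off dynamic binding.</p>
<p>From the point of view of name visibility, all three do the same thing: they promise the compiler that the inherited function exists in every specialization of the base class template. If the promise turns out to be false, the error is caught at compile time, when the template is instantiated.</p>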
<h3 id="item-44-factor-parameter-independent-code-out-of-templates">Item 44: Factor parameter-independent code out of templates.</h3>
<blockquote>
<p>✦ Templates generate multiple classes and multiple functions, so any template code not dependent on a template parameter causes bloat.</p>
<p>✦ Bloat due to non-type template parameters can often be eliminated by replacing template parameters with function parameters or class data members.</p>
<p>✦ Bloat due to type parameters can be reduced by sharing implementations for instantiation types with identical binary representations.</p>
</blockquote>
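<p>Templates shrink source code, but instantiation can bloat the resulting executable, for example when a template is instantiated too many times. Suppose we want a square matrix that supports inversion:</p>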
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="p">,</span> <span class="n">size_t</span> <span class="n">n</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">SquareMatrix</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="n">invert</span><span class="p">();</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
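</div><p>This class template takes two parameters: a type parameter <code>T</code> for the element type and a non-type parameter <code>n</code> for the matrix size. This is a common pattern, but it makes the compiler generate a separate copy of <code>invert</code> for every distinct size <code>n</code>, even when the element type is the same. That is clearly unnecessary and calls for one more level of abstraction.</p>
<p>One approach is to factor out a base class template that takes only the type parameter <code>T</code> and provides a <code>void invert(size_t n)</code> method; the derived class forwards its non-type parameter to that method. All matrices with the same element type then share a single instantiation of the base.</p>
<p>The factored version is not necessarily better, though. The original version fixes the matrix size at compile time, which gives the compiler more room to optimize. On the other hand, the factored version produces a smaller executable, which shrinks the working set, improves locality, and raises the cache hit rate.</p>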
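<p>A sketch of this factoring follows. To keep it short, an in-place transpose stands in for the inversion, which is an assumption of the example and not the operation from the original class. The names <code>SquareMatrixBase</code>, <code>at</code> and <code>data_</code> are likewise hypothetical.</p>

```cpp
#include <cstddef>
#include <utility>

// Size-independent base: one instantiation per element type T.
template <typename T>
class SquareMatrixBase {
protected:
    // The size arrives as a runtime argument instead of a template parameter.
    static void transpose(T* data, std::size_t n) {
        for (std::size_t i = 0; i < n; ++i)
            for (std::size_t j = i + 1; j < n; ++j)
                std::swap(data[i * n + j], data[j * n + i]);
    }
};

// Private inheritance: the base only helps with the implementation.
template <typename T, std::size_t n>
class SquareMatrix : private SquareMatrixBase<T> {
public:
    T& at(std::size_t row, std::size_t col) { return data_[row * n + col]; }
    // Thin wrapper that forwards the compile-time size to the shared code.
    void transpose() { SquareMatrixBase<T>::transpose(data_, n); }
private:
    T data_[n * n] = {};  // row-major storage, zero-initialized
};
```

<p>With this layout, <code>SquareMatrix&lt;int, 2&gt;</code> and <code>SquareMatrix&lt;int, 5&gt;</code> each get only a tiny wrapper, while the loop itself is compiled once in <code>SquareMatrixBase&lt;int&gt;</code>.</p>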
<h3 id="item-45-use-member-function-templates-to-accept-all-compatible-types">Item 45: Use member function templates to accept “all compatible types.”</h3>
<blockquote>
<p>✦ Use member function templates to generate functions that accept all compatible types.</p>
<p>✦ If you declare member templates for generalized copy construction or generalized assignment, you’ll still need to declare the normal copy constructor and copy assignment operator, too.</p>
</blockquote>
<p>Suppose we want to implement a smart pointer class <code>SmartPointer</code> that can be constructed from any compatible type, starting with a raw pointer:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">SmartPointer</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="n">SmartPointer</span><span class="p">(</span><span class="n">T</span> <span class="o">*</span><span class="n">real_ptr</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Next, smart pointers of different types should be convertible to one another. This can be done with a generalized copy constructor:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">SmartPointer</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">U</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl">	<span class="n">SmartPointer</span><span class="p">(</span><span class="k">const</span> <span class="n">SmartPointer</span><span class="o">&lt;</span><span class="n">U</span><span class="o">&gt;&amp;</span> <span class="n">other</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
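</div><p>The code above declares a constructor template inside the class template, so a <code>SmartPointer&lt;T&gt;</code> can be constructed from any other instantiation <code>SmartPointer&lt;U&gt;</code>.</p>
<p>Next, the smart pointer should allow only the implicit conversions that raw pointers allow, for example from a derived-class pointer to a base-class pointer. We rely on C++'s built-in pointer conversions to enforce this:</p>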
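<p>A declaration like this is more permissive than we probably want. The sketch below is declaration-only (nothing is called at run time, so no definitions are needed) and uses hypothetical <code>Base</code> and <code>Derived</code> types to show that the constructor template accepts every <code>U</code>, including conversions that make no sense for raw pointers:</p>

```cpp
#include <string>
#include <type_traits>

// Declaration-only sketch of the class from the text.
template <typename T>
class SmartPointer {
public:
    SmartPointer(T* real_ptr);
    template <typename U>
    SmartPointer(const SmartPointer<U>& other);  // generalized copy constructor
};

struct Base {};            // hypothetical types for illustration
struct Derived : Base {};

// The conversion we want: derived to base.
static_assert(std::is_convertible<SmartPointer<Derived>, SmartPointer<Base>>::value,
              "derived-to-base is accepted");
// Conversions we do not want are accepted by the declaration as well.
static_assert(std::is_convertible<SmartPointer<Base>, SmartPointer<Derived>>::value,
              "base-to-derived is accepted too");
static_assert(std::is_convertible<SmartPointer<std::string>, SmartPointer<int>>::value,
              "so is a conversion between unrelated types");
```

<p>The type traits only inspect the declarations, which is why the unwanted conversions pass here; restricting them is the job of the constructor's definition.</p>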
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">SmartPointer</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">U</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl">	<span class="n">SmartPointer</span><span class="p">(</span><span class="k">const</span> <span class="n">SmartPointer</span><span class="o">&lt;</span><span class="n">U</span><span class="o">&gt;&amp;</span> <span class="n">other</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="o">:</span><span class="n">held_ptr</span><span class="p">(</span><span class="n">other</span><span class="p">.</span><span class="n">get</span><span class="p">())</span> <span class="p">{}</span>
</span></span><span class="line"><span class="cl">	<span class="n">T</span><span class="o">*</span> <span class="nf">get</span><span class="p">()</span> <span class="k">const</span> <span class="p">{</span><span class="k">return</span> <span class="n">held_ptr</span><span class="p">;}</span>
</span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="k">private</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="n">T</span><span class="o">*</span> <span class="n">held_ptr</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
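</div><p>Note that declaring a constructor template does not stop the compiler from generating the ordinary copy constructor. When an object is copy-constructed from another object of the same type, the compiler-generated copy constructor is used, so the following code prints nothing:</p>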
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&lt;iostream&gt;</span><span class="cp">  
</span></span></span><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>  
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">SmartPointer</span><span class="p">{</span>  
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>  
</span></span><span class="line"><span class="cl">    <span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">U</span><span class="o">&gt;</span>  
</span></span><span class="line"><span class="cl">    <span class="n">SmartPointer</span><span class="p">(</span><span class="k">const</span> <span class="n">SmartPointer</span><span class="o">&lt;</span><span class="n">U</span><span class="o">&gt;&amp;</span> <span class="n">other</span><span class="p">)</span>  
</span></span><span class="line"><span class="cl">            <span class="o">:</span><span class="n">held_ptr</span><span class="p">(</span><span class="n">other</span><span class="p">.</span><span class="n">get</span><span class="p">())</span> <span class="p">{</span>  
</span></span><span class="line"><span class="cl">                <span class="n">std</span><span class="o">::</span><span class="n">cout</span> <span class="o">&lt;&lt;</span> <span class="s">&#34;Enter template copy constructor</span><span class="se">\n</span><span class="s">&#34;</span><span class="p">;</span>  
</span></span><span class="line"><span class="cl">            <span class="p">};</span>  
</span></span><span class="line"><span class="cl">    <span class="n">SmartPointer</span><span class="p">(</span><span class="n">T</span><span class="o">*</span> <span class="n">p</span><span class="p">)</span>  
</span></span><span class="line"><span class="cl">            <span class="o">:</span><span class="n">held_ptr</span><span class="p">(</span><span class="n">p</span><span class="p">)</span> <span class="p">{};</span>  
</span></span><span class="line"><span class="cl">    <span class="n">T</span><span class="o">*</span> <span class="nf">get</span><span class="p">()</span> <span class="k">const</span> <span class="p">{</span><span class="k">return</span> <span class="n">held_ptr</span><span class="p">;};</span>  
</span></span><span class="line"><span class="cl"><span class="k">private</span><span class="o">:</span>  
</span></span><span class="line"><span class="cl">    <span class="n">T</span><span class="o">*</span> <span class="n">held_ptr</span><span class="p">;</span>  
</span></span><span class="line"><span class="cl"><span class="p">};</span>  
</span></span><span class="line"><span class="cl">  
</span></span><span class="line"><span class="cl"><span class="kt">int</span> <span class="nf">main</span><span class="p">(){</span>  
</span></span><span class="line"><span class="cl">    <span class="n">SmartPointer</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="n">pint</span> <span class="o">=</span> <span class="p">{</span><span class="k">new</span> <span class="kt">int</span><span class="p">};</span>  
</span></span><span class="line"><span class="cl">    <span class="n">SmartPointer</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="n">pint2</span> <span class="o">=</span> <span class="p">{</span><span class="n">pint</span><span class="p">};</span> <span class="c1">// calls the compiler-generated copy constructor, not the constructor template
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h3 id="item-46-define-non-member-functions-inside-templates-when-type-conversions-are-desired">Item 46: Define non-member functions inside templates when type conversions are desired.</h3>
<blockquote>
<p>✦ When writing a class template that offers functions related to the template that support implicit type conversions on all parameters, define those functions as friends inside the class template.</p>
</blockquote>
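<p>In <a href=".md#item-24-declare-non-member-functions-when-type-conversions-should-apply-to-all-parameters.">Item 24</a>, we used a non-member function to implement commutative multiplication, with implicit type conversions applied to both operands. When the same technique is applied to a template, things change subtly:</p>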
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Rational</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="n">Rational</span><span class="p">(</span><span class="k">const</span> <span class="n">T</span><span class="o">&amp;</span> <span class="n">numerator</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="k">const</span> <span class="n">T</span><span class="o">&amp;</span> <span class="n">denominator</span><span class="o">=</span><span class="mi">1</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">	<span class="k">const</span> <span class="n">T</span> <span class="nf">numerator</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="k">const</span> <span class="n">T</span> <span class="nf">denominator</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">const</span> <span class="n">Rational</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span> <span class="k">operator</span><span class="o">*</span><span class="p">(</span><span class="k">const</span> <span class="n">Rational</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;&amp;</span> <span class="n">lhs</span><span class="p">,</span> <span class="k">const</span> <span class="n">Rational</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;&amp;</span> <span class="n">rhs</span><span class="p">){...}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="cm">/*******************/</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">Rational</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="n">half</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="n">Rational</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="n">res</span> <span class="o">=</span> <span class="n">half</span><span class="o">*</span><span class="mi">2</span><span class="p">;</span> <span class="c1">// error! won&#39;t compile!
</span></span></span></code></pre></td></tr></table>
</div>
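</div><p>Why does this fail? Before the compiler can call <code>operator*</code>, it has to instantiate it, and for that it has to deduce <code>T</code>. The two arguments have different types (<code>Rational&lt;int&gt;</code> and <code>int</code>), and <strong>implicit type conversions</strong> are never considered during template argument deduction, so <code>T</code> cannot be deduced from the second argument and no callable function is found.</p>
<p>The fix is to declare <code>operator*</code> as a friend inside the class template, so that the function is declared with a concrete <code>T</code> whenever the class is instantiated. Note that the friend must also be defined inside the class body. The friend is not a function template, whereas a function defined outside the class is, so <strong>the two are different functions</strong> and the outside template does not supply the friend's definition:</p>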
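<p>One way to confirm that deduction, rather than the conversion itself, is the obstacle: when <code>T</code> is named explicitly, nothing is left to deduce and the <code>int</code> argument is converted as usual. The sketch below is a simplified, <code>constexpr</code> variant of <code>Rational</code> written for illustration; the member names and the default denominator of <code>1</code> are assumptions.</p>

```cpp
// Simplified Rational for illustration; constexpr so the result can be
// checked at compile time.
template <typename T>
class Rational {
public:
    constexpr Rational(const T& numerator = 0, const T& denominator = 1)
        : n_(numerator), d_(denominator) {}
    constexpr T numerator() const { return n_; }
    constexpr T denominator() const { return d_; }
private:
    T n_;
    T d_;
};

template <typename T>
constexpr Rational<T> operator*(const Rational<T>& lhs, const Rational<T>& rhs) {
    return Rational<T>(lhs.numerator() * rhs.numerator(),
                       lhs.denominator() * rhs.denominator());
}

// `half * 2` would not compile: T cannot be deduced from the int argument.
// Naming T explicitly skips deduction, so 2 converts to Rational<int>(2, 1).
constexpr Rational<int> half_times_two() {
    return operator*<int>(Rational<int>(1, 2), 2);
}

static_assert(half_times_two().numerator() == 2, "1 * 2");
static_assert(half_times_two().denominator() == 2, "2 * 1");
```

<p>Writing <code>operator*&lt;int&gt;(half, 2)</code> at every call site is not a practical interface, which is why a different fix is needed.</p>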
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Rational</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="n">Rational</span><span class="p">(</span><span class="k">const</span> <span class="n">T</span><span class="o">&amp;</span> <span class="n">numerator</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="k">const</span> <span class="n">T</span><span class="o">&amp;</span> <span class="n">denominator</span><span class="o">=</span><span class="mi">1</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">	<span class="k">const</span> <span class="n">T</span> <span class="nf">numerator</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="k">const</span> <span class="n">T</span> <span class="nf">denominator</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">	<span class="k">friend</span> <span class="k">const</span> <span class="n">Rational</span> <span class="k">operator</span><span class="o">*</span><span class="p">(</span><span class="k">const</span> <span class="n">Rational</span><span class="o">&amp;</span> <span class="n">lhs</span><span class="p">,</span> <span class="k">const</span> <span class="n">Rational</span><span class="o">&amp;</span> <span class="n">rhs</span><span class="p">);</span> <span class="c1">// inside the class template, Rational is shorthand for Rational&lt;T&gt;
</span></span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">const</span> <span class="n">Rational</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span> <span class="k">operator</span><span class="o">*</span><span class="p">(</span><span class="k">const</span> <span class="n">Rational</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;&amp;</span> <span class="n">lhs</span><span class="p">,</span> <span class="k">const</span> <span class="n">Rational</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;&amp;</span> <span class="n">rhs</span><span class="p">){...}</span> <span class="c1">// this defines a function template; the friend declared in the class is not a template, so this does not define it
</span></span></span></code></pre></td></tr></table>
</div>
</div><h3 id="item-47-use-traits-classes-for-information-about-types">Item 47: Use traits classes for information about types.</h3>
<blockquote>
<p>✦ Traits classes make information about types available during compilation. They’re implemented using templates and template specializations.</p>
<p>✦ In conjunction with overloading, traits classes make it possible to perform compile-time if&hellip;else tests on types.</p>
</blockquote>
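<p>Let's try to implement the <code>advance</code> function template, which moves a pointer or an iterator by a given distance. Some C++ iterators support random access, while others can only be stepped one element at a time. For performance, <code>advance</code> should handle these two kinds separately, which means we need to know the iterator's category. Because pointers must be supported as well, this information cannot be stored inside the iterator type. What can we do?</p>
<p>Fortunately, there are <code>traits</code>. <code>traits</code> is neither a C++ keyword nor a predefined interface; it is the name of a technique. Since the type information cannot live inside the type itself, the standard approach is to put it in a template and in specializations of that template. For iterators, the standard library's template is named <code>iterator_traits</code>:</p>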
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">iterT</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">iterator_traits</span><span class="p">;</span>
</span></span></code></pre></td></tr></table>
</div>
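</div><p>By convention, <code>traits</code> are implemented as structs. The struct declares a typedef named <code>iterator_category</code> that names a different tag type for each kind of <code>iterT</code>, and this is how iterator categories are told apart.</p>
<p>Concretely, <code>iterator_traits</code> works in two parts. First, every user-defined iterator must contain a nested typedef named <code>iterator_category</code> that names one of the standard library's iterator category tags:</p>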
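<p>To see what this typedef looks like from the caller's side, the sketch below queries the standard library's <code>std::iterator_traits</code> for a few common iterator types. The helper <code>has_category</code> is a hypothetical name introduced only for this illustration.</p>

```cpp
#include <forward_list>
#include <iterator>
#include <list>
#include <type_traits>
#include <vector>

// Hypothetical helper: true when IterT's category is exactly Tag.
template <typename IterT, typename Tag>
struct has_category
    : std::is_same<typename std::iterator_traits<IterT>::iterator_category, Tag> {};

static_assert(has_category<std::vector<int>::iterator,
                           std::random_access_iterator_tag>::value,
              "vector iterators support random access");
static_assert(has_category<std::list<int>::iterator,
                           std::bidirectional_iterator_tag>::value,
              "list iterators are bidirectional");
static_assert(has_category<std::forward_list<int>::iterator,
                           std::forward_iterator_tag>::value,
              "forward_list iterators only move forward");
static_assert(has_category<int*, std::random_access_iterator_tag>::value,
              "built-in pointers are covered by the traits as well");
```

<p>All of these checks are resolved during compilation; no iterator object is ever created.</p>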
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="p">...</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">deque</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">class</span> <span class="nc">iterator</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">		<span class="k">typedef</span> <span class="n">random_access_iterator_tag</span> <span class="n">iterator_category</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="p">...</span>
</span></span><span class="line"><span class="cl">	<span class="p">};</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="cm">/******* iterator category tags *******/</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">input_iterator_tag</span> <span class="p">{};</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">output_iterator_tag</span> <span class="p">{};</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">forward_iterator_tag</span> <span class="o">:</span> <span class="k">public</span> <span class="n">input_iterator_tag</span> <span class="p">{};</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">bidirectional_iterator_tag</span> <span class="o">:</span> <span class="k">public</span> <span class="n">forward_iterator_tag</span> <span class="p">{};</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">random_access_iterator_tag</span> <span class="o">:</span> <span class="k">public</span> <span class="n">bidirectional_iterator_tag</span> <span class="p">{};</span>
</span></span></code></pre></td></tr></table>
</div>
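</div><p>For example, the code above declares the iterator of a double-ended queue as a random access iterator. All <code>iterator_traits</code> then has to do is re-export the tag nested in <code>iterT</code> as its own <code>iterator_category</code>:</p>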
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">iterT</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">iterator_traits</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="k">typedef</span> <span class="k">typename</span> <span class="n">iterT</span><span class="o">::</span><span class="n">iterator_category</span> <span class="n">iterator_category</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
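</div><p>That covers user-defined iterator types. Second, we add support for built-in pointers. A pointer is a random access iterator, so we provide a partial template specialization for pointer types:</p>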
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">iterator_traits</span><span class="o">&lt;</span><span class="n">T</span><span class="o">*&gt;</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="k">typedef</span> <span class="n">random_access_iterator_tag</span> <span class="n">iterator_category</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>With the traits class in place, we can implement <code>advance</code> by branching on the iterator category:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">iterT</span><span class="p">,</span> <span class="k">typename</span> <span class="n">distT</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">advance</span><span class="p">(</span><span class="n">iterT</span><span class="o">&amp;</span> <span class="n">iter</span><span class="p">,</span> <span class="n">distT</span> <span class="n">d</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span><span class="p">(</span><span class="k">typeid</span><span class="p">(</span><span class="k">typename</span> <span class="n">iterator_traits</span><span class="o">&lt;</span><span class="n">iterT</span><span class="o">&gt;::</span><span class="n">iterator_category</span><span class="p">)</span> <span class="o">==</span> <span class="k">typeid</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">random_access_iterator_tag</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
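</div><p>Unfortunately, the code above can fail to compile, a problem discussed in the next item. Beyond that, the <code>if</code> statement is evaluated at run time even though its outcome is already known at compile time, which wastes run-time work.</p>
<p>A compile-time conditional? Hmm&hellip; that sounds awkward. But don't forget function overloading! Overloading selects different code depending on the argument types, which is exactly what we need. So we can overload the implementation of <code>advance</code> on the iterator category:</p>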
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">iterT</span><span class="p">,</span> <span class="k">typename</span> <span class="n">distT</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">do_advance</span><span class="p">(</span><span class="n">iterT</span><span class="o">&amp;</span> <span class="n">iter</span><span class="p">,</span> <span class="n">distT</span> <span class="n">d</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">random_access_iterator_tag</span><span class="p">){</span>
</span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">iterT</span><span class="p">,</span> <span class="k">typename</span> <span class="n">distT</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">do_advance</span><span class="p">(</span><span class="n">iterT</span><span class="o">&amp;</span> <span class="n">iter</span><span class="p">,</span> <span class="n">distT</span> <span class="n">d</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">input_iterator_tag</span><span class="p">){</span>
</span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span></code></pre></td></tr></table>
</div>
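</div><p>With these overloads in place, <code>advance</code> only has to call them. Note that a function may have unnamed parameters, but a call must still pass an actual argument object. Fortunately, the tags that identify the categories are empty structs, so we can simply construct a temporary tag object and pass it as the argument:</p>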
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">iterT</span><span class="p">,</span> <span class="k">typename</span> <span class="n">distT</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">advance</span><span class="p">(</span><span class="n">iterT</span><span class="o">&amp;</span> <span class="n">iter</span><span class="p">,</span> <span class="n">distT</span> <span class="n">d</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="n">do_advance</span><span class="p">(</span><span class="n">iter</span><span class="p">,</span> <span class="n">d</span><span class="p">,</span> <span class="k">typename</span> <span class="n">std</span><span class="o">::</span><span class="n">iterator_traits</span><span class="o">&lt;</span><span class="n">iterT</span><span class="o">&gt;::</span><span class="n">iterator_category</span><span class="p">());</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
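</div><p>To summarize, using a traits class takes two steps:</p>
<ul>
<li>Create a set of overloaded “worker” functions that differ in a traits parameter.</li>
<li>Create a “master” function that calls the “worker” functions.</li>
</ul>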
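<p>Putting the two steps together gives something like the sketch below. It is a minimal illustration rather than the standard library's implementation: <code>my_advance</code>, <code>CountingIter</code> and the two demo functions are hypothetical names, and the functions are marked <code>constexpr</code> (which assumes C++14 or later) only so that the results can be checked during compilation.</p>

```cpp
#include <cstddef>
#include <iterator>

// Worker for random access iterators: a single constant-time jump.
template <typename IterT, typename DistT>
constexpr void do_advance(IterT& iter, DistT d, std::random_access_iterator_tag) {
    iter += d;
}

// Worker for input iterators; forward and bidirectional iterators also land
// here because their tags derive from std::input_iterator_tag.
template <typename IterT, typename DistT>
constexpr void do_advance(IterT& iter, DistT d, std::input_iterator_tag) {
    while (d-- > 0) ++iter;
}

// Master: constructs a temporary tag object and lets overload resolution pick.
template <typename IterT, typename DistT>
constexpr void my_advance(IterT& iter, DistT d) {
    do_advance(iter, d, typename std::iterator_traits<IterT>::iterator_category());
}

// Hypothetical forward-only iterator that counts upward; it has no operator+=.
struct CountingIter {
    typedef std::forward_iterator_tag iterator_category;
    typedef int value_type;
    typedef std::ptrdiff_t difference_type;
    typedef const int* pointer;
    typedef int reference;

    int value;
    constexpr int operator*() const { return value; }
    constexpr CountingIter& operator++() { ++value; return *this; }
};

// A pointer dispatches to the random access worker.
constexpr int advance_pointer_demo() {
    const int data[5] = {10, 20, 30, 40, 50};
    const int* p = data;
    my_advance(p, 3);
    return *p;
}

// CountingIter dispatches to the input worker, which steps four times.
constexpr int advance_counting_demo() {
    CountingIter it{0};
    my_advance(it, 4);
    return *it;
}
```

<p>Because the worker is chosen by overload resolution, the <code>iter += d</code> body is never instantiated for <code>CountingIter</code>, so its missing <code>operator+=</code> causes no error.</p>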
<h3 id="item-48-be-aware-of-template-metaprogramming">Item 48: Be aware of template metaprogramming.</h3>
<blockquote>
<p>✦ Template metaprogramming can shift work from runtime to compile-time, thus enabling earlier error detection and higher runtime performance.</p>
<p>✦ TMP can be used to generate custom code based on combinations of policy choices, and it can also be used to avoid generating code inappropriate for particular types.</p>
</blockquote>
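<p>Template metaprogramming (TMP) is the practice of writing C++ programs that execute during compilation. The output of such a program, namely the code instantiated from templates, is then compiled as usual.</p>
<p>TMP was discovered (⚠️ not invented) in the 1990s. It offers two benefits:</p>
<ul>
<li>It makes some things feasible that would otherwise be hard or impossible.</li>
<li>It shifts some work from run time to compile time.<br>
Thanks to the second benefit, some errors that would otherwise surface at run time are caught during compilation, and the resulting executables can be smaller and run faster.</li>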
</ul>
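<p>As a small illustration of work moving to compile time, the sketch below computes an integer power through recursive template instantiation. <code>Pow</code> is a hypothetical name chosen for this example, and <code>static constexpr</code> assumes C++11 or later.</p>

```cpp
// Pow<Base, Exp>::value is Base raised to the power Exp, computed by the
// compiler through recursive instantiation.
template <unsigned Base, unsigned Exp>
struct Pow {
    static constexpr unsigned long value = Base * Pow<Base, Exp - 1>::value;
};

// The partial specialization for Exp == 0 ends the recursion.
template <unsigned Base>
struct Pow<Base, 0> {
    static constexpr unsigned long value = 1;
};

// The result is an ordinary constant expression.
static_assert(Pow<2, 10>::value == 1024, "2^10");
static_assert(Pow<3, 4>::value == 81, "3^4");
```

<p>No multiplication happens at run time: by the time the program starts, <code>Pow&lt;2, 10&gt;::value</code> is already the constant 1024.</p>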
<p>前面我们在不使用函数重载实现 <code>advance</code> 的过程中，曾提到以下代码存在编译错误：</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">iterT</span><span class="p">,</span> <span class="k">typename</span> <span class="n">distT</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">advance</span><span class="p">(</span><span class="n">iterT</span><span class="o">&amp;</span> <span class="n">iter</span><span class="p">,</span> <span class="n">distT</span> <span class="n">d</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span><span class="p">(</span><span class="k">typeid</span><span class="p">(</span><span class="k">typename</span> <span class="n">std</span><span class="o">::</span><span class="n">iterator_traits</span><span class="o">&lt;</span><span class="n">iterT</span><span class="o">&gt;::</span><span class="n">iterator_category</span><span class="p">)</span> <span class="o">==</span> <span class="k">typeid</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">random_access_iterator_tag</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">	<span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
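</div><p>The reason is that even when we pass an iterator that does not support random access, the function template is still instantiated in full, including the statement <code>iter += d;</code>. For such an iterator the condition of the enclosing <code>if</code> branch is always <code>false</code>, but the compiler still has to type-check the statement. Iterators without random access do not provide <code>operator+=</code>, so compilation fails.</p>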
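<p>For comparison, here is a sketch that goes beyond the book and assumes C++17: <code>if constexpr</code> discards the untaken branch when the template is instantiated, so the single-function form compiles for every iterator category. The name <code>advance_cxx17</code> is hypothetical, and the fallback branch is simplified to forward movement only.</p>

```cpp
#include <cassert>
#include <forward_list>
#include <iterator>
#include <type_traits>
#include <vector>

template <typename IterT, typename DistT>
void advance_cxx17(IterT& iter, DistT d) {
    using category = typename std::iterator_traits<IterT>::iterator_category;
    if constexpr (std::is_base_of_v<std::random_access_iterator_tag, category>) {
        iter += d;            // instantiated only for random-access iterators
    } else {
        assert(d >= 0);       // simplification: this fallback only moves forward
        while (d--) ++iter;   // supported by every iterator category
    }
}
```

<p>With a plain <code>if</code>, instantiating this template for a <code>std::forward_list</code> iterator would be rejected because of <code>iter += d;</code>. With <code>if constexpr</code>, that statement is never instantiated for it.</p>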
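<p>TMP is Turing-complete. The previous item showed how to express conditional control flow in TMP. Loops are expressed through recursion, but unlike ordinary C++ recursion, which calls functions recursively, TMP recursion is recursive template instantiation. A factorial computed with TMP looks like this:</p>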
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="kt">unsigned</span> <span class="n">n</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">Factorial</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="k">enum</span> <span class="p">{</span><span class="n">value</span> <span class="o">=</span> <span class="n">n</span><span class="o">*</span><span class="n">Factorial</span><span class="o">&lt;</span><span class="n">n</span><span class="o">-</span><span class="mi">1</span><span class="o">&gt;::</span><span class="n">value</span><span class="p">};</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">Factorial</span><span class="o">&lt;</span><span class="mi">0</span><span class="o">&gt;</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="k">enum</span> <span class="p">{</span><span class="n">value</span> <span class="o">=</span> <span class="mi">1</span><span class="p">};</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
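</div><p>The author then lists several areas where TMP is used in practice, which I won't record here. In short, TMP has domains where it excels, but given how counterintuitive it is, and how weak the supporting toolchain was when the book was written, it should be used with care.</p>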
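<p>As a quick check that the <code>Factorial</code> template really is evaluated by the compiler, the sketch below repeats it and verifies the results with <code>static_assert</code>. The <code>constexpr</code> function is an addition that assumes C++11 or later; it is not part of the book's example.</p>

```cpp
#include <cassert>

template <unsigned n>
struct Factorial {
    enum { value = n * Factorial<n - 1>::value };
};

template <>
struct Factorial<0> {
    enum { value = 1 };   // base case ends the recursive instantiation
};

// Checked during compilation: a wrong value is a compile error,
// not a runtime failure.
static_assert(Factorial<5>::value == 120, "5! should be 120");
static_assert(Factorial<10>::value == 3628800, "10! should be 3628800");

// Assumption: C++11 or later. The same computation as a constexpr function.
constexpr unsigned long long factorial(unsigned n) {
    return n == 0 ? 1ULL : n * factorial(n - 1);
}
static_assert(factorial(10) == 3628800ULL, "constexpr form agrees");
```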
<h2 id="customizing-new-and-delete">Customizing new and delete</h2>
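<p>Many languages now offer automatic garbage collection, which can make C++'s manual approach look dated. Yet many systems developers choose C++ precisely because it lets them manage memory by hand. Doing that well requires understanding how C++'s memory allocation and deallocation routines behave, which is the focus of this chapter.</p>
<p>Memory management is considerably harder in a multithreaded environment, because the heap and the new-handler are modifiable global resources and are therefore prone to race conditions.</p>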
<h3 id="item-49-understand-the-behavior-of-the-new-handler">Item 49: Understand the behavior of the new-handler.</h3>
<blockquote>
<p>✦ set_new_handler allows you to specify a function to be called when memory allocation requests cannot be satisfied.</p>
<p>✦ Nothrow new is of limited utility, because it applies only to memory allocation; associated constructor calls may still throw exceptions.</p>
</blockquote>
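<p>If <code>operator new</code> cannot allocate enough memory, it throws an exception (older compilers returned <code>NULL</code>). Before doing so, however, it calls an error-handling function known as the <code>new-handler</code>. The standard library provides <code>set_new_handler</code> to install one:</p>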
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">namespace</span> <span class="n">std</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="k">typedef</span> <span class="nf">void</span> <span class="p">(</span><span class="o">*</span><span class="n">new_handler</span><span class="p">)();</span>
</span></span><span class="line"><span class="cl">	<span class="n">new_handler</span> <span class="nf">set_new_handler</span><span class="p">(</span><span class="n">new_handler</span> <span class="n">p</span><span class="p">)</span> <span class="k">noexcept</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
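</div><p>As shown above, <code>new_handler</code> is a pointer to a function that takes no parameters and returns nothing. <code>set_new_handler</code> takes such a pointer and returns the previously installed handler. For example:</p>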
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">out_of_mem</span><span class="p">(){</span>
</span></span><span class="line"><span class="cl">	<span class="n">std</span><span class="o">::</span><span class="n">cerr</span> <span class="o">&lt;&lt;</span> <span class="s">&#34;Out of Mem</span><span class="se">\n</span><span class="s">&#34;</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="n">std</span><span class="o">::</span><span class="n">abort</span><span class="p">();</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">int</span> <span class="nf">main</span><span class="p">(){</span>
</span></span><span class="line"><span class="cl">	<span class="n">std</span><span class="o">::</span><span class="n">set_new_handler</span><span class="p">(</span><span class="n">out_of_mem</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">	<span class="kt">int</span> <span class="o">*</span><span class="n">p</span> <span class="o">=</span> <span class="k">new</span> <span class="kt">int</span><span class="p">[</span><span class="mi">100000000L</span><span class="p">];</span> <span class="c1">// if fail, call out_of_mem and then abort
</span></span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
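</div><p>When <code>new</code> cannot allocate enough memory, it calls the new-handler repeatedly until enough memory becomes available or the handler ends the loop. A well-behaved new-handler must therefore do one of the following:</p>
<ul>
<li>Make more memory available.</li>
<li>Install a different new-handler.</li>
<li>Uninstall the current new-handler. This restores the default behavior of a failed <code>new</code>, which is to throw an exception.</li>
<li>Throw an exception.</li>
<li>Not return, terminating the program.</li>
</ul>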
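<p>As an illustration of the first and third options, here is a minimal sketch of a handler that releases an emergency reserve on its first call and uninstalls itself on the next one. The names <code>g_reserve</code>, <code>install_reserve</code> and <code>release_reserve_handler</code> are hypothetical, and the sketch ignores thread safety.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <new>

// Hypothetical emergency reserve, set aside at startup so that it can be
// handed back to the heap when an allocation fails.
static void* g_reserve = nullptr;

// First call: free the reserve so the retried allocation may succeed.
// Later calls: nothing is left to free, so uninstall the handler; a
// conforming operator new then sees a null handler and throws std::bad_alloc.
void release_reserve_handler() {
    if (g_reserve != nullptr) {
        std::free(g_reserve);
        g_reserve = nullptr;
    } else {
        std::set_new_handler(nullptr);
    }
}

// Allocates the reserve and installs the handler.
void install_reserve(std::size_t bytes) {
    g_reserve = std::malloc(bytes);
    std::set_new_handler(release_reserve_handler);
}
```

<p>A program would call <code>install_reserve</code> once at startup. Because the handler always makes progress (it either frees memory or removes itself), the retry loop inside <code>operator new</code> cannot spin forever.</p>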
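<p>Customizing the new-handler per class looks simple enough: swap in the class's own new-handler right before each <code>new</code>. Let's put that idea into practice.</p>
<p>First, since we will be replacing the existing new-handler, the class needs a variable that records its own handler. It also needs an interface for setting that handler, one that stores the new handler and returns the old one. Finally, the class has to overload <code>operator new</code>:</p>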
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Widget</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">static</span> <span class="n">std</span><span class="o">::</span><span class="n">new_handler</span> <span class="n">set_new_handler</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">new_handler</span> <span class="n">p</span><span class="p">)</span> <span class="k">noexcept</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="k">static</span> <span class="kt">void</span><span class="o">*</span> <span class="k">operator</span> <span class="nf">new</span> <span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">size_t</span> <span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">private</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">static</span> <span class="n">std</span><span class="o">::</span><span class="n">new_handler</span> <span class="n">current_handler</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
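</div><p>Note that static data members must be defined and initialized outside the class definition. What should the overloaded <code>new</code> do? The following:</p>
<ul>
<li>Call <code>std::set_new_handler</code> to install the class's handler as the global new-handler.</li>
<li>Call the global <code>operator new</code> to allocate the memory. If that fails, the original new-handler must be restored while the exception propagates. To guarantee the restore happens, the new-handler should be managed by a resource-managing class.</li>
<li>If the global <code>operator new</code> succeeds, return the pointer it produced. Restoring the new-handler is left to the destructor of the resource-managing object.</li>
</ul>
<p>Let's start with the RAII resource-managing class:</p>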
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">NewHandlerHolder</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">explicit</span> <span class="n">NewHandlerHolder</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">new_handler</span> <span class="n">nh</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">	<span class="o">:</span><span class="n">handler</span><span class="p">(</span><span class="n">nh</span><span class="p">){};</span>
</span></span><span class="line"><span class="cl">	<span class="o">~</span><span class="n">NewHandlerHolder</span><span class="p">(){</span>
</span></span><span class="line"><span class="cl">		<span class="n">std</span><span class="o">::</span><span class="n">set_new_handler</span><span class="p">(</span><span class="n">handler</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">private</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="n">std</span><span class="o">::</span><span class="n">new_handler</span> <span class="n">handler</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="n">NewHandlerHolder</span><span class="p">(</span><span class="k">const</span> <span class="n">NewHandlerHolder</span><span class="o">&amp;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">	<span class="n">NewHandlerHolder</span><span class="o">&amp;</span> <span class="k">operator</span><span class="o">=</span><span class="p">(</span><span class="k">const</span> <span class="n">NewHandlerHolder</span><span class="o">&amp;</span><span class="p">);</span> <span class="c1">// copy construction and assignment are disabled
</span></span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>With that in place, <code>operator new</code> for <code>Widget</code> can be written as:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span><span class="o">*</span> <span class="n">Widget</span><span class="o">::</span><span class="k">operator</span> <span class="k">new</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">size_t</span> <span class="n">size</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="n">NewHandlerHolder</span> <span class="n">h</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">set_new_handler</span><span class="p">(</span><span class="n">current_handler</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="o">::</span><span class="k">operator</span> <span class="k">new</span><span class="p">(</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
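</div><p>At this point I can't help but marvel at how damn elegant this is! But hold your breath, we're not done yet. Next we turn it into a class template used as a mixin-style base class. Concretely, by inheriting from the base class, a derived class gains the <code>set_new_handler</code> and <code>operator new</code> members, while the template ensures that each class inherits its own distinct copy of the static member.</p>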
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">NewHandlerSupport</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">static</span> <span class="n">std</span><span class="o">::</span><span class="n">new_handler</span> <span class="n">set_new_handler</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">new_handler</span> <span class="n">p</span><span class="p">)</span> <span class="k">noexcept</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="k">static</span> <span class="kt">void</span><span class="o">*</span> <span class="k">operator</span> <span class="nf">new</span> <span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">size_t</span> <span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="k">private</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">static</span> <span class="n">std</span><span class="o">::</span><span class="n">new_handler</span> <span class="n">currentHandler</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="n">std</span><span class="o">::</span><span class="n">new_handler</span> <span class="n">NewHandlerSupport</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;::</span><span class="n">set_new_handler</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">new_handler</span> <span class="n">p</span><span class="p">)</span> <span class="k">noexcept</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="n">std</span><span class="o">::</span><span class="n">new_handler</span> <span class="n">oldHandler</span> <span class="o">=</span> <span class="n">currentHandler</span><span class="p">;</span> 
</span></span><span class="line"><span class="cl">	<span class="n">currentHandler</span> <span class="o">=</span> <span class="n">p</span><span class="p">;</span> 
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">oldHandler</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span> 
</span></span><span class="line"><span class="cl"><span class="kt">void</span><span class="o">*</span> <span class="n">NewHandlerSupport</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;::</span><span class="k">operator</span> <span class="k">new</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">size_t</span> <span class="n">size</span><span class="p">){</span> 
</span></span><span class="line"><span class="cl">	<span class="n">NewHandlerHolder</span> <span class="nf">h</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">set_new_handler</span><span class="p">(</span><span class="n">currentHandler</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="o">::</span><span class="k">operator</span> <span class="k">new</span><span class="p">(</span><span class="n">size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">T</span><span class="o">&gt;</span> 
</span></span><span class="line"><span class="cl"><span class="n">std</span><span class="o">::</span><span class="n">new_handler</span> <span class="n">NewHandlerSupport</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;::</span><span class="n">currentHandler</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>With the class template in place, implementing <code>Widget</code> becomes much simpler:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Widget</span><span class="o">:</span> <span class="k">public</span> <span class="n">NewHandlerSupport</span><span class="o">&lt;</span><span class="n">Widget</span><span class="o">&gt;</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="p">...</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
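</div><p>Note that the class template never uses its type parameter <code>T</code>. The parameter exists only so that the compiler generates a separate copy of the code for each class name, which keeps their static members apart. Also, <code>Widget</code> inherits from a base class template instantiated with <code>Widget</code> itself. This is legitimate, and the technique has a name as odd as its behavior: the curiously recurring template pattern (CRTP).</p>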
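<p>To show that each client class really gets its own handler slot, the sketch below assembles the pieces into one self-contained unit. <code>Gadget</code> and <code>widget_out_of_mem</code> are hypothetical additions for the demonstration, and copying is disabled with <code>= delete</code> (C++11) instead of private declarations.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <new>

// RAII holder: restores the saved global new-handler on scope exit.
class NewHandlerHolder {
public:
    explicit NewHandlerHolder(std::new_handler nh) : handler(nh) {}
    ~NewHandlerHolder() { std::set_new_handler(handler); }
    NewHandlerHolder(const NewHandlerHolder&) = delete;
    NewHandlerHolder& operator=(const NewHandlerHolder&) = delete;
private:
    std::new_handler handler;
};

template <typename T>
class NewHandlerSupport {
public:
    static std::new_handler set_new_handler(std::new_handler p) noexcept {
        std::new_handler old = currentHandler;
        currentHandler = p;
        return old;
    }
    static void* operator new(std::size_t size) {
        // Install the class handler; the holder restores the previous one.
        NewHandlerHolder h(std::set_new_handler(currentHandler));
        return ::operator new(size);
    }
private:
    static std::new_handler currentHandler;  // one copy per instantiation
};

template <typename T>
std::new_handler NewHandlerSupport<T>::currentHandler = nullptr;

// Two unrelated client classes; Gadget is a hypothetical second client.
class Widget : public NewHandlerSupport<Widget> { public: int x = 0; };
class Gadget : public NewHandlerSupport<Gadget> { public: int y = 0; };

// Hypothetical handler used only for the demonstration.
void widget_out_of_mem() { throw std::bad_alloc(); }
```

<p>Setting a handler through <code>Widget::set_new_handler</code> leaves <code>Gadget</code>'s handler untouched, and the global handler is the same before and after a <code>new Widget</code> expression.</p>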
<h3 id="item-50-understand-when-it-makes-sense-to-replace-new-and-delete">Item 50: Understand when it makes sense to replace new and delete.</h3>
<blockquote>
<p>✦ There are many valid reasons for writing custom versions of new and delete, including improving performance, debugging heap usage errors, and collecting heap usage information.</p>
</blockquote>
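<p>Why would anyone replace the compiler-provided <code>new</code> and <code>delete</code> operators? Broadly, there are three reasons:</p>
<ul>
<li>To detect usage errors. Failing to free memory obtained from <code>new</code>, or freeing it more than once, is an error. If <code>new</code> and <code>delete</code> maintain a table of allocated blocks, both problems can be detected. Custom versions can also guard against overruns: write a signature at the end of each block in <code>new</code>, then check in <code>delete</code> whether it is still intact to tell whether something wrote past the end.</li>
<li>To improve efficiency. The default implementation has to serve every kind of program and every request size, and it has to deal with concerns such as fragmentation. A custom implementation tailored to one program can avoid that overhead.</li>
<li>To collect usage statistics. Data gathered during development reveals how the program uses dynamic memory, so it can be optimized accordingly.</li>
</ul>
<p>Here is an example of a <code>new</code> that helps detect writes beyond the bounds of a block:</p>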
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">static</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">signature</span> <span class="o">=</span> <span class="mh">0xDEADBEEF</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="k">typedef</span> <span class="kt">unsigned</span> <span class="kt">char</span> <span class="n">Byte</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span><span class="o">*</span> <span class="k">operator</span> <span class="nf">new</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">size_t</span> <span class="n">size</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="k">using</span> <span class="k">namespace</span> <span class="n">std</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="n">size_t</span> <span class="n">realSize</span> <span class="o">=</span> <span class="n">size</span><span class="o">+</span><span class="mi">2</span><span class="o">*</span><span class="k">sizeof</span><span class="p">(</span><span class="kt">int</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">	<span class="kt">void</span> <span class="o">*</span><span class="n">pMem</span> <span class="o">=</span> <span class="n">malloc</span><span class="p">(</span><span class="n">realSize</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span><span class="p">(</span><span class="o">!</span><span class="n">pMem</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">		<span class="k">throw</span> <span class="n">bad_alloc</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">	<span class="o">*</span><span class="p">(</span><span class="k">static_cast</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">*&gt;</span><span class="p">(</span><span class="n">pMem</span><span class="p">))</span> <span class="o">=</span> <span class="n">signature</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="o">*</span><span class="p">(</span><span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">*</span> <span class="o">&gt;</span><span class="p">(</span><span class="k">static_cast</span><span class="o">&lt;</span><span class="n">Byte</span><span class="o">*</span> <span class="o">&gt;</span><span class="p">(</span><span class="n">pMem</span><span class="p">)</span><span class="o">+</span><span class="n">realSize</span><span class="o">-</span><span class="k">sizeof</span><span class="p">(</span><span class="kt">int</span><span class="p">)))</span> <span class="o">=</span> <span class="n">signature</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="k">static_cast</span><span class="o">&lt;</span><span class="n">Byte</span><span class="o">*&gt;</span><span class="p">(</span><span class="n">pMem</span><span class="p">)</span> <span class="o">+</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">int</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
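</div><p>The code above places an extra signature at each end of the allocated block so that out-of-bounds writes can be detected. It has several problems, though. For one, it does not follow the C++ conventions for <code>new</code>: when allocation fails, it should call the new-handler in a loop. For another, it ignores alignment.</p>
<p>C++ requires the pointer returned by <code>new</code> to be suitably aligned. The pointer we get from <code>malloc</code> is suitably aligned too, but the pointer we return is offset from it by the size of an <code>int</code>, so it may no longer be.</p>
<p>There are many small but genuinely important details like alignment, which gives a sense of how hard it is to write a flawless custom <code>new</code>.</p>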
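<p>One way to address the alignment problem is to pad the front of the block by a multiple of the strictest fundamental alignment. The sketch below is an illustration under stated assumptions, not the book's code: the names in namespace <code>sketch</code> are hypothetical, it uses plain functions instead of replacing the global operators, and it still omits the new-handler loop.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <new>

namespace sketch {

const std::uint32_t kSignature = 0xDEADBEEF;
const std::size_t kAlign  = alignof(std::max_align_t);
const std::size_t kHeader = sizeof(std::size_t) + sizeof(std::uint32_t);
// Front padding: the header size rounded up to a multiple of kAlign,
// so the pointer handed to the caller keeps malloc's alignment.
const std::size_t kFront  = (kHeader + kAlign - 1) / kAlign * kAlign;

// Layout: [size ... front signature][user bytes][back signature]
void* checked_alloc(std::size_t size) {
    unsigned char* raw = static_cast<unsigned char*>(
        std::malloc(kFront + size + sizeof(kSignature)));
    if (!raw) throw std::bad_alloc();
    std::memcpy(raw, &size, sizeof(size));  // remember the user size
    std::memcpy(raw + kFront - sizeof(kSignature), &kSignature, sizeof(kSignature));
    std::memcpy(raw + kFront + size, &kSignature, sizeof(kSignature));
    return raw + kFront;
}

// Returns true when both guard words are still intact.
bool signatures_intact(const void* user) {
    const unsigned char* p = static_cast<const unsigned char*>(user);
    std::size_t size;
    std::memcpy(&size, p - kFront, sizeof(size));
    std::uint32_t front, back;
    std::memcpy(&front, p - sizeof(front), sizeof(front));
    std::memcpy(&back, p + size, sizeof(back));
    return front == kSignature && back == kSignature;
}

void checked_free(void* user) {
    if (user) std::free(static_cast<unsigned char*>(user) - kFront);
}

} // namespace sketch
```

<p>The guard words are written with <code>std::memcpy</code> because the back signature sits at an arbitrary offset and may not be aligned for a direct integer store.</p>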
<h3 id="item-51-adhere-to-convention-when-writing-new-and-delete">Item 51: Adhere to convention when writing new and delete.</h3>
<blockquote>
<p>✦ operator new should contain an infinite loop trying to allocate memory, should call the new-handler if it can’t satisfy a memory request, and should handle requests for zero bytes. Class-specific versions should handle requests for larger blocks than expected.</p>
<p>✦ operator delete should do nothing if passed a pointer that is null. Class-specific versions should handle blocks that are larger than expected.</p>
</blockquote>
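<p>This item covers the conventions to follow when writing custom versions of <code>new</code> and <code>delete</code>.</p>
<p>First, the requirements for <code>new</code>:</p>
<ul>
<li>Return the right value.</li>
<li>Call the new-handler function in a loop when memory is insufficient.</li>
<li>Handle requests for zero bytes correctly.</li>
<li>Avoid hiding the normal forms of <code>new</code>. This requirement is discussed in the next item.</li>
</ul>
<p>The return value sounds simple: if memory is available, return a pointer to it; otherwise, throw an exception. It is not quite that direct, though. When memory is insufficient, <code>new</code> has to call the new-handler and retry the allocation, over and over, and it throws <code>std::bad_alloc</code> only once the new-handler pointer is null. The C++ standard also requires that a request for zero bytes still return a legitimate pointer. The following sketch shows how such a <code>new</code> behaves:</p>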
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span> <span class="o">*</span><span class="k">operator</span> <span class="nf">new</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">size_t</span> <span class="n">size</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="k">using</span> <span class="k">namespace</span> <span class="n">std</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span><span class="p">(</span><span class="n">size</span> <span class="o">==</span> <span class="mi">0</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="n">size</span> <span class="o">=</span> <span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl">	<span class="k">while</span><span class="p">(</span><span class="nb">true</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">		<span class="kt">void</span> <span class="o">*</span><span class="n">p</span> <span class="o">=</span> <span class="n">malloc</span><span class="p">(</span><span class="n">size</span><span class="p">);</span> <span class="c1">// attempt to allocate size bytes
</span></span></span><span class="line"><span class="cl">		<span class="k">if</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">			<span class="k">return</span> <span class="n">p</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">		<span class="n">new_handler</span> <span class="n">global_handler</span> <span class="o">=</span> <span class="n">set_new_handler</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span> <span class="c1">// fetch the current new-handler
</span></span></span><span class="line"><span class="cl">		<span class="n">set_new_handler</span><span class="p">(</span><span class="n">global_handler</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">		<span class="k">if</span><span class="p">(</span><span class="n">global_handler</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">			<span class="p">(</span><span class="o">*</span><span class="n">global_handler</span><span class="p">)();</span>
</span></span><span class="line"><span class="cl">		<span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">			<span class="k">throw</span> <span class="n">bad_alloc</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">		<span class="p">}</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
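</div><p>A request for 0 bytes can be handled as if it were a request for one byte. When the book was written there was no function for reading the current <code>new-handler</code>, so the only way to obtain the handler's function pointer was to set it to <code>null</code> and then restore it (C++11 later added <code>std::get_new_handler</code>). In a multithreaded environment, a lock may be needed to prevent races.</p>
<p>A <code>new</code> written for a particular class is usually tuned for objects of exactly that class's size; it is not meant for other classes or for classes derived from it. However, if a derived class does not declare its own <code>new</code>, creating a derived object with <code>new</code> calls the base class's <code>new</code>. To guard against this, first check <code>size == sizeof(Base)</code> and, if they differ, forward the request to the global <code>new</code>.</p>
<p><code>delete</code> is much simpler. The one thing to watch for is that <code>delete</code> must handle a NULL pointer.</p>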
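<p>As a rough, runnable sketch of the loop described above (not the book's code): the function name <code>allocate_like_operator_new</code> is hypothetical, <code>std::malloc</code> stands in for "attempt to allocate size bytes", and <code>std::get_new_handler</code> assumes a C++11 or later compiler.</p>

```cpp
#include <cstddef>
#include <cstdlib>
#include <new>

// Hypothetical helper for illustration: follows the operator new conventions
// (zero-size handling, retry loop, new-handler call), backed by std::malloc.
void* allocate_like_operator_new(std::size_t size) {
    if (size == 0) {
        size = 1;  // treat a zero-byte request as a one-byte request
    }
    while (true) {
        if (void* mem = std::malloc(size)) {
            return mem;  // allocation succeeded
        }
        // Allocation failed: look up the current new-handler.
        // C++11 provides std::get_new_handler(); before that, the only way was
        // set_new_handler(0) followed by restoring the returned pointer.
        std::new_handler handler = std::get_new_handler();
        if (handler) {
            handler();  // may free memory, install another handler, or throw
        } else {
            throw std::bad_alloc();  // no handler installed: give up
        }
    }
}
```

<p>Memory obtained this way would be released with <code>std::free</code>, since the sketch allocates with <code>std::malloc</code>.</p>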
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span> <span class="k">operator</span> <span class="nf">delete</span><span class="p">(</span><span class="kt">void</span> <span class="o">*</span><span class="n">rawMemory</span><span class="p">)</span> <span class="k">noexcept</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">	<span class="k">if</span><span class="p">(</span><span class="n">rawMemory</span> <span class="o">==</span> <span class="k">nullptr</span><span class="p">)</span> <span class="c1">// deleting a null pointer must be a no-op</span>
</span></span><span class="line"><span class="cl">		<span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="c1">// release the memory pointed to by rawMemory (pseudocode)</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The member-function version is simple too: as with <code>new</code> above, just remember to check whether the size passed in matches the size of the base class.</p>
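<p>A minimal sketch of the size check in member <code>operator new</code> and <code>operator delete</code>. The classes and the two counters are hypothetical, and the "class-specific" path simply forwards to the global functions as a stand-in for a real fixed-size allocator.</p>

```cpp
#include <cstddef>
#include <new>

// Hypothetical classes for illustration only.
class Base {
public:
    static int customCalls;  // requests served by the Base-sized path
    static int customFrees;  // releases handled by the Base-sized path

    static void* operator new(std::size_t size) {
        if (size != sizeof(Base)) {
            return ::operator new(size);  // wrong size: defer to global new
        }
        ++customCalls;
        return ::operator new(size);  // stand-in for a Base-sized allocator
    }

    static void operator delete(void* rawMemory, std::size_t size) noexcept {
        if (rawMemory == nullptr) return;  // deleting null is a no-op
        if (size != sizeof(Base)) {
            ::operator delete(rawMemory);  // wrong size: defer to global delete
            return;
        }
        ++customFrees;
        ::operator delete(rawMemory);  // stand-in for the Base-sized deallocator
    }

    // A virtual destructor ensures the correct size reaches operator delete
    // even when a Derived object is deleted through a Base pointer.
    virtual ~Base() = default;

    int value = 0;
};

int Base::customCalls = 0;
int Base::customFrees = 0;

class Derived : public Base {  // inherits Base's operator new and delete
public:
    int extra[4] = {};
};
```

<p>With this in place, <code>new Base</code> takes the class-specific path, while <code>new Derived</code> is forwarded to the global <code>new</code> because its size differs from <code>sizeof(Base)</code>.</p>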
<h3 id="item-52-write-placement-delete-if-you-write-placement-new">Item 52: Write placement delete if you write placement new.</h3>
<blockquote>
<p>✦ When you write a placement version of operator new, be sure to write the corresponding placement version of operator delete. If you don’t, your program may experience subtle, intermittent memory leaks.</p>
<p>✦ When you declare placement versions of new and delete, be sure not to unintentionally hide the normal versions of those functions.</p>
</blockquote>
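<p>A statement such as <code>Widget *pw = new Widget;</code> performs two operations: it first calls <code>operator new</code> to obtain memory of the right size, then calls the constructor to initialize it. If the constructor throws, construction never completes and the user never receives <code>pw</code>, so the user has no way to release that memory. To prevent a leak, releasing it is the compiler's responsibility.</p>
<p>To release the memory, the compiler must know which <code>delete</code> pairs with the <code>new</code> that allocated it. For the usual <code>new</code> that takes only <code>size_t size</code>, the matching <code>delete</code> is the usual one as well. But a <code>new</code> can also take more than one parameter; such a <code>new</code> is called a “placement new”.</p>
<p>In the narrow sense, placement new refers to <code>void* operator new(std::size_t, void *pMemory)</code>, which takes an additional pointer and constructs the object at the location it points to. In the broad sense, any new whose parameter list is more than just <code>size_t size</code> can be called placement new. The narrow sense is more common, and context usually makes clear which one is meant.</p>
<p>If construction fails after a placement new, the runtime system looks for a placement delete whose extra parameters match in number and type, and uses it to release the memory. If none is found, the memory leaks.</p>
<p>If the new expression completes normally, a later <code>delete</code> on the pointer calls the non-placement version. So when you define a custom placement new, provide both a placement delete (for failed construction) and a normal delete (for ordinary destruction).</p>
<p>Because of name hiding, a member placement new declared in a class hides the normal new. It also hides the three forms of new that exist at global scope:</p>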
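<p>To make the pairing concrete, here is a small sketch in which a constructor can be made to throw. <code>AllocLog</code>, <code>Widget</code>, and <code>tryCreate</code> are hypothetical names chosen for illustration; the extra <code>AllocLog&amp;</code> parameter is what makes these placement forms.</p>

```cpp
#include <cstddef>
#include <new>
#include <stdexcept>

// Hypothetical bookkeeping object passed as the extra placement argument.
struct AllocLog {
    int allocations = 0;
    int placementDeletes = 0;
};

class Widget {
public:
    explicit Widget(bool fail) {
        if (fail) throw std::runtime_error("construction failed");
    }

    // Placement new: takes an extra AllocLog& argument.
    static void* operator new(std::size_t size, AllocLog& log) {
        ++log.allocations;
        return ::operator new(size);
    }
    // Matching placement delete: called automatically only when the
    // constructor throws after the placement new above has succeeded.
    static void operator delete(void* mem, AllocLog& log) noexcept {
        ++log.placementDeletes;
        ::operator delete(mem);
    }
    // Normal delete: used by an ordinary `delete pw;`.
    static void operator delete(void* mem) noexcept {
        ::operator delete(mem);
    }
};

// Returns true if construction threw, false if the object was created
// and then destroyed normally.
bool tryCreate(AllocLog& log, bool fail) {
    try {
        Widget* pw = new (log) Widget(fail);
        delete pw;  // normal path: calls the non-placement operator delete
        return false;
    } catch (const std::runtime_error&) {
        return true;  // memory was already released by the placement delete
    }
}
```

<p>If the placement <code>operator delete</code> were removed from <code>Widget</code>, the failing case would still compile, but nothing would release the memory.</p>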
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="kt">void</span><span class="o">*</span> <span class="k">operator</span> <span class="nf">new</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">size_t</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span><span class="o">*</span> <span class="k">operator</span> <span class="nf">new</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">size_t</span><span class="p">,</span> <span class="kt">void</span><span class="o">*</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span><span class="o">*</span> <span class="k">operator</span> <span class="nf">new</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">size_t</span><span class="p">,</span> <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">nothrow_t</span><span class="o">&amp;</span><span class="p">);</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>If you want them to remain usable after being hidden, remember to make the matching delete forms available too.</p>
<h2 id="miscellany">Miscellany</h2>
<p>This chapter is a grab bag of odds and ends. Victory is in sight!</p>
<h3 id="item-53-pay-attention-to-compiler-warnings">Item 53: Pay attention to compiler warnings.</h3>
<blockquote>
<p>✦ Take compiler warnings seriously, and strive to compile warning-free at the maximum warning level supported by your compilers.</p>
<p>✦ Don’t become dependent on compiler warnings, because different compilers warn about different things. Porting to a new compiler may eliminate warning messages you’ve come to rely on.</p>
</blockquote>
<p>In most cases, as long as the compiler reports no errors, the program will run. But a warning may be pointing at a fatal mistake, for example:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-cpp" data-lang="cpp"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="n">f</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Derived</span><span class="o">:</span> <span class="k">public</span> <span class="n">Base</span><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">	<span class="k">virtual</span> <span class="kt">void</span> <span class="n">f</span><span class="p">();</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
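</div><p>The code above is meant to redefine the base-class virtual function <code>f</code> in the derived class, but it omits the <code>const</code> qualifier. To the compiler, this means <code>Derived::f</code> hides <code>Base::f</code> rather than overriding it.</p>
<p>Compilers can warn about this kind of mistake, and the warning makes it easy to track down.</p>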
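<p>The effect of the missing <code>const</code> can be observed at run time. The sketch below is a variation on the snippet above: the functions return a string (an assumption made so that the behavior is checkable) and <code>callThroughBase</code> is a hypothetical helper.</p>

```cpp
#include <string>

class Base {
public:
    virtual ~Base() = default;
    virtual std::string f() const { return "Base::f"; }
};

class Derived : public Base {
public:
    // Missing const: this declares a new function that hides Base::f
    // instead of overriding it. Declaring it as `std::string f() override`
    // would turn the mistake into a compile error.
    virtual std::string f() { return "Derived::f"; }
};

// Calls f through a reference to const Base, as polymorphic code typically would.
std::string callThroughBase(const Base& b) {
    return b.f();
}
```

<p>Calling <code>callThroughBase</code> with a <code>Derived</code> object still runs <code>Base::f</code>, because <code>Derived</code> never overrides the <code>const</code> version.</p>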
<h3 id="item-54-familiarize-yourself-with-the-standard-library-including-tr1">Item 54: Familiarize yourself with the standard library, including TR1.</h3>
<p>TR1 is a historical document whose practical relevance today is unclear, so I am skipping this item.</p>
<h3 id="item-55-familiarize-yourself-with-boost">Item 55: Familiarize yourself with Boost.</h3>
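<p>I have run out of steam, so I am wrapping up here in a hurry and will come back to fill this gap later!</p>
<h2 id="终章">Epilogue</h2>
<p>From 2024-04-17 to 2024-05-28, this book took far more time than I expected. Anyway, I got a lot out of it. While reading I kept exclaiming at how meticulous, elegant, and brilliant it is. The author likes to plant hints early on; later everything clicks and the knowledge connects.</p>
<p>As my first in-depth C++ book, it was quite good. Its weakness is that it covers very little modern C++; fortunately the author also wrote "Effective Modern C++", which is next on my list!</p>
<p>Finished! 🎉</p>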
]]></content:encoded>
    </item>
    <item>
      <title>My Four Years of University</title>
      <link>https://www.zhouxin.space/thoughts/reflections-on-my-university-years/</link>
      <pubDate>Thu, 11 Apr 2024 14:29:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/thoughts/reflections-on-my-university-years/</guid>
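      <description>&lt;h2 id=&#34;前言&#34;&gt;Preface&lt;/h2&gt;
&lt;p&gt;Time really flies. In the blink of an eye I am about to finish my undergraduate degree, and I now know firsthand what the old saying means: a human life between heaven and earth passes like a white colt glimpsed through a crack in the wall, gone in an instant. I recently got this blog into decent shape, so while the enthusiasm lasts, here is a post looking back on my four years of university.&lt;/p&gt;</description>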
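      <content:encoded><![CDATA[<h2 id="前言">Preface</h2>
<p>Time really flies. In the blink of an eye I am about to finish my undergraduate degree, and I now know firsthand what the old saying means: a human life between heaven and earth passes like a white colt glimpsed through a crack in the wall, gone in an instant. I recently got this blog into decent shape, so while the enthusiasm lasts, here is a post looking back on my four years of university.</p>
<p>Whenever I think back on these four years, countless threads come to mind, yet I can never quite untangle them. I have written far too little; from now on I need to write more and reflect on myself more.</p>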
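<h2 id="入学前">Before Enrollment</h2>
<p>I did poorly on my first gaokao and chose to repeat the year. The repeat school was a very pure place: everyone was single-minded, every day was just study, and the days went by quickly. My score the second time was 385, which landed in an awkward spot: just enough for an upper-middle 211 university, but I was dead set on studying computer science and worried I would not get into the CS major.</p>
<p>While I was torn between picking a better school (Soochow University) and going to NJUPT (南邮) for computer science, my mother noticed HDU (杭电, Hangzhou Dianzi University). I have to say HDU is very good at publicity: once I learned that its computer science program was rated B+ and saw its dazzling ACM record, I decided the very next day to go to HDU for computer science.</p>
<p>This shows that for some crucial life decisions, once I have a reason that convinces me, I stop gathering information and considering alternatives; one feasible solution is enough. It matches how I later chose my target school for the postgraduate entrance exam and how I made my last-minute decision about a direct PhD. Why? I tend to think I am a doer: I prefer getting started over endlessly setting goals and making plans.</p>
<p>The plan I set for myself before enrolling was: win an ACM medal and join a big tech company after graduation. Before filling in my application I had already learned that HDU has good employment outcomes but a very low rate of recommendation to graduate school (保研), so the advice was to go straight into industry. Funny enough, looking back, I achieved neither goal. Partly I overestimated my own initiative; partly my horizons were too narrow, and I did not know that university offered many projects and competitions with a better payoff.</p>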
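<h2 id="大一利见大人">Freshman Year: 利见大人 (It Is Favorable to Meet the Great Person)</h2>
<p>At the start of freshman year I was studying algorithms to get into the ACM team while preparing for three student-club interviews. I did not manage to stay on the ACM team, but more than half a year of solving problems noticeably improved my coding ability, and the algorithms I mastered went well beyond what the curriculum required.</p>
<p>Meanwhile, I joined the school's student union and met one of my two benefactors during undergrad: Han. She was the head of my department when I joined the student union as a freshman, and when I had just arrived she shared a lot of experience on how to compete (“卷”) at university: how to choose courses, how to push GPA, how to win scholarships. As a newcomer, I would have been classified in 《金榜题名之后》 as the intuition-dependent type: lacking big-picture planning, carrying over the high-school habit of studying hard, and knowing only to attend class and study well. Even so, studying hard does not guarantee a good GPA, because grading varies a lot between instructors. Scholarships work the same way: they are awarded on a comprehensive evaluation whose criteria go far beyond GPA. All of this experience came from students a year ahead, and getting it early is what lets you adapt to university quickly and start the GPA grind in the right way.</p>
<p>Han also connected me with my other benefactor: my innovation-practice advisor, Li Ping. Everything I achieved in four years of university, except mathematical modeling, was done under his guidance. Even for modeling, I found my teammates in his lab. Below is the email I sent when applying to join his innovation-practice course; it summarizes my record from the first semester of freshman year:</p>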
<blockquote>
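<p>Dear Professor Li,</p>
<p>&lt;personal information&gt;. After reading the relevant introduction, I am quite interested in your machine-learning work and hope to become a member of your group.</p>
<p>Below are my self-introduction and relevant experience:</p>
<p>My GPA for the first semester of freshman year is 4.4247, ranking 35/395, and I received a second-class scholarship.</p>
<p>Strong English. I passed CET-4 with a score of 580 in my first semester, earned a full grade in the first-semester intensive reading course, and scored 104/120 in English on the gaokao.</p>
<p>Fairly strong mathematics. In my first semester, my midterm and final exam scores in both advanced mathematics and linear algebra were all at least 90.</p>
<p>Solid programming ability. Since joining the university ACM training team as a freshman I have written code every day; my total is now over 23,000 lines, and I have mastered a number of common algorithms.</p>
<p>Time management. In my first semester, besides coursework, I also had the day-to-day work of the student union, ACM team training, and preparation for several other competitions. These trained my time management well.</p>
<p>Ability to learn. During the summer after the gaokao and my first month at university, I taught myself C and algorithms, passed the ACM freshman selection contest, and joined the training team. Over the freshman winter break, I learned mathematical modeling in five days and took part in the MCM/ICM (美国大学生数学建模竞赛与交叉学科建模竞赛), where I proposed the solution, implemented it, and wrote the paper almost entirely on my own, and received an H award.</p>
<p>I sincerely hope to join your group. Would I meet your requirements?</p>
<p>Wishing you all the best!</p>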
</blockquote>
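<p>In the first session of the innovation-practice course, Professor Li laid out the plan for the next two years: 新苗, 大创, the Challenge Cup (挑战杯), 计设, the service outsourcing competition (服务外包), and so on, all in good order. I thought right then: I picked the right advisor!</p>
<p>Throughout freshman year I was rarely in my dorm during the day. Student union + ACM + a full course load basically filled my time. Hard work does pay off, though: it earned me coding ability and a social circle in the student union.</p>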
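<h2 id="大二终日乾乾">Sophomore Year: 终日乾乾 (Striving All Day Long)</h2>
<p>Sophomore year was my most high-spirited year, and most of my results came out of it: 新苗, mathematical modeling, a provincial award (省奖), service outsourcing, and student union department head.</p>
<p>Sophomore year was really busy: a nearly full timetable, student union department duties, and several competitions on top. I can no longer recall what I was thinking then, but I don't think I felt suffocated; after all, these were all things I enjoyed, and every one of them even turned out well. I have always felt my potential is unlimited, but I really need external conditions to push me. I am interested in many things, especially computing, and I feel joy and accomplishment when digging into them, yet that does not mean they pull me into action on their own; I am still more easily led by other, simpler pleasures.</p>
<p>For most of sophomore year I worked alongside two friends.</p>
<p>One is my “other self in this world” (世另我; his blog: <a href="https://www.albresky.cn/">Albresky&rsquo;s Blog</a>). We come from the same area and the same high school, but only discovered that when we were teamed up in an elective. Even more amazingly, we both have a geeky streak: as kids we loved flashing ROMs and tinkering, and we were both fascinated by computers. After meeting in the second semester of freshman year we quickly became good friends, and I talked him into signing up for innovation practice too. Later I moved over to his group, and we did projects and competitions together for three years. We are so alike that each of us can quickly grasp what the other means and anticipate what the other is thinking.</p>
<p>The other is a woman who held first place in GPA year after year; I met her in the innovation-practice course. “君子终日乾乾，夕惕若厉，无咎” (the noble one strives all day long and stays vigilant at night as if in danger, and so is without blame) describes her perfectly: she ground away for four years without slacking. I later worked with her and also teamed up with her (plus a roommate of mine) for mathematical modeling, and found that her cleverness comes from care and steadiness: whatever she is responsible for is thoroughly covered, down to many small details. That is something my other self and I should learn from; compared with hers, our cleverness is more like quick-wittedness. Her horizons are also a bit broader than ours.</p>
<p>Working with the two of them was really comfortable: one quickly understands your idea and puts it into practice with you, while the other covers your backs, patiently picking out the holes in your ideas one by one. Honestly, with a team like this, my only worry was that the advisor would not push hard enough.</p>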
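<h2 id="大三亢龙有悔">Junior Year: 亢龙有悔 (The Overreaching Dragon Has Regrets)</h2>
<p>Once the external push is gone, I easily become lazy. The first half of junior year was exactly that: I had left the student union, most competitions and projects had wrapped up, and with almost no classes that semester I had a great deal of free time. Rationally, I should have spent it on an internship or on learning some technology, but left to my own devices at the time (there was not even really a “decision”), what I did was play Honor of Kings (王者) in the dorm.</p>
<p>By the second half of junior year, with no internship and no engineering skills but some competitions and projects to my name, I naturally ended up on the path of the postgraduate entrance exam. Choosing a target school went exactly like choosing my gaokao application: I saw a post about 科软 on Zhihu and quickly decided that was where I would apply. This time what attracted me was an average annual salary of 30万+ RMB, two years of internship with no advisor, and the 985 title (I did not yet know about C9 or 华五), a perfect fit for my pre-enrollment wish to join a big tech company.</p>
<p>I will save my memories of exam preparation for the senior-year section. Judging by the outcome, the wasted half-year did no real harm, and choosing 科软 was also the right call; people say everything is the best arrangement. But I feel Lady Luck will not always be on my side, and leaving my fate to luck is too unsettling, especially when the matter could have been in my own hands. I need to break the habit of making major decisions in a hurry; for major decisions I need to gather much, much more information.</p>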
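<h2 id="大四或跃在渊">Senior Year: 或跃在渊 (Leaping Up from the Deep)</h2>
<p>The most important event of senior year: I made it, and was admitted to USTC (中科大)! The joy is hard to put into words. The moment I saw my score I actually shouted, which for someone who rarely shows emotion means I was excited to the extreme.</p>
<p>That said, while I was preparing I would have called it a natural outcome: I have always felt my gaokao was shackled by Jiangsu's 08 scheme (08 方案), whereas all four subjects of the postgraduate exam are ones I like and am good at. If I could not get in, who could?</p>
<p>The preparation itself was not dull; I rather enjoyed it, since it was a return to the familiar cycle of studying, drilling problems, and working through papers.</p>
<p>After the exam I was very afraid of failing, and worried about being judged arrogant: I had an ominous feeling that I had once again been defeated by mathematics. In truth I am not that good at math. Looking back, the worry was unnecessary: even the lower bound of my score estimate was enough to reach the second-round exam (复试).</p>
<p>The actual score was higher than the upper bound of my estimate, which reveals another issue: I tend to magnify the uncertain factors in an event and to assume that the unfavorable random ones will always happen, so I must plan for the worst case. I do not think my low estimate came from modesty or lack of confidence: wherever I could claim points I gave them without hesitation. The estimate was low because of the uncertainty from the many answers I had forgotten. In fact, even now I do not consider myself a modest person.</p>
<p>After getting in, my desire to learn technology and join a big tech company became extremely strong. I go to the library every day, sit there all day, and study efficiently. Lately I have been building the blog, writing posts, learning technology, and doing labs, and it is very fulfilling. On one hand, the post-admission passion is still there; on the other, I crammed a week of Java development before the scores came out, and once I actually touched engineering work, such a handy tool satisfied every fantasy I had about project development and sparked my interest in learning.</p>
<p>There was also a small episode in the first semester of senior year: I nearly went for a direct PhD at my own university. The day after the counselor sent out the notice, I decided to apply; on the day of the direct-PhD defense, an advisor contacted me; the next day I settled on him as my PhD advisor; that same day a senior student talked me out of it, and I tossed and turned all night; on the third day I talked with the friend who ranked first in GPA and wavered further; on the fourth day I talked with Li Ping and gave up the direct PhD.</p>
<p>The way I decided on the direct PhD again reflects how hasty and careless I am with key life decisions, but this time I finally started gathering information and questioning whether the decision was right. I also owe thanks to the senior who talked me out of it: he took me around to talk with everyone else in the lab, which is what made me waver.</p>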
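<h2 id="总结">Summary</h2>
<p>It has not been an easy road. Looking back on these four years, now that the final verdict is about to be delivered, I can say without any guilt: I spent my four years of university fully, and I did not waste them.</p>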
]]></content:encoded>
    </item>
    <item>
      <title>Automatic Build and Deployment of a Hugo Blog with Webhooks</title>
      <link>https://www.zhouxin.space/logs/hugo-auto-deployment-based-on-webhook/</link>
      <pubDate>Wed, 10 Apr 2024 10:35:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/logs/hugo-auto-deployment-based-on-webhook/</guid>
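      <description>&lt;h1 id=&#34;博客发布流程&#34;&gt;Blog Publishing Workflow&lt;/h1&gt;
&lt;p&gt;My publishing workflow is described in &lt;a href=&#34;https://www.zhouxin.space/logs/blog-setup-logs/#%E5%8D%9A%E5%AE%A2%E5%8F%91%E5%B8%83%E5%B7%A5%E4%BD%9C%E6%B5%81&#34;&gt;Blog Setup Log &amp;gt; Blog Publishing Workflow&lt;/a&gt;. Its last two steps still had to be done by hand: logging in to the server to pull from the repo, then building with the hugo command.&lt;/p&gt;
&lt;p&gt;After asking GPT, I learned that GitHub provides a Webhook service. Combined with a webhook listener on the server, every push to the repo can trigger an automatic pull and build of the blog on the server.&lt;/p&gt;</description>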
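      <content:encoded><![CDATA[<h1 id="博客发布流程">Blog Publishing Workflow</h1>
<p>My publishing workflow is described in <a href="/logs/blog-setup-logs/#%E5%8D%9A%E5%AE%A2%E5%8F%91%E5%B8%83%E5%B7%A5%E4%BD%9C%E6%B5%81">Blog Setup Log &gt; Blog Publishing Workflow</a>. Its last two steps still had to be done by hand: logging in to the server to pull from the repo, then building with the hugo command.</p>
<p>After asking GPT, I learned that GitHub provides a Webhook service. Combined with a webhook listener on the server, every push to the repo can trigger an automatic pull and build of the blog on the server.</p>
<h1 id="步骤">Steps</h1>
<h2 id="前置条件">Prerequisites</h2>
<p>First, create a repo to hold the blog's files and add the server's public key to the SSH keys of your GitHub account. For details, see <a href="/logs/blog-setup-logs/">Blog Setup Log</a>.</p>
<h2 id="在服务器上设置-webhook-监听器">Setting Up a Webhook Listener on the Server</h2>
<p>The server needs a listener that receives push events from GitHub. You can write one yourself with Flask, or use the ready-made webhook tool:</p>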
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl">sudo apt install webhook -y
</span></span></code></pre></td></tr></table>
</div>
</div><p>Create a configuration file for webhook, <code>hooks.json</code>:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">[</span>
</span></span><span class="line"><span class="cl">  <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;id&#34;</span><span class="p">:</span> <span class="s2">&#34;redeploy-blog&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;execute-command&#34;</span><span class="p">:</span> <span class="s2">&#34;/path/to/your/script.sh&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;command-working-directory&#34;</span><span class="p">:</span> <span class="s2">&#34;/path/to/your/hugo/blog&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;pass-arguments-to-command&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">      <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;source&#34;</span><span class="p">:</span> <span class="s2">&#34;payload&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;head_commit.id&#34;</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">],</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;trigger-rule&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;and&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">        <span class="p">{</span>
</span></span><span class="line"><span class="cl">          <span class="nt">&#34;match&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;payload-hash-sha1&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;secret&#34;</span><span class="p">:</span> <span class="s2">&#34;your_webhook_secret&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;parameter&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">              <span class="nt">&#34;source&#34;</span><span class="p">:</span> <span class="s2">&#34;header&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">              <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;X-Hub-Signature&#34;</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">          <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">]</span>
</span></span></code></pre></td></tr></table>
</div>
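</div><p>In <code>hooks.json</code>, <code>execute-command</code> is the script to run when a matching request arrives, and <code>command-working-directory</code> is the script's working directory, which can be set to the directory where the blog is deployed. The <code>secret</code> field must also be changed to a value of your own; it is used to verify that a request really was sent by GitHub.</p>
<p>Next, create the script to run, <code>script.sh</code>. It enters the target directory, pulls the latest changes, switches to the main branch, and runs the build command:</p>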
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl"><span class="cp">#!/bin/bash
</span></span></span><span class="line"><span class="cl"><span class="c1"># cd path/to/blog</span>
</span></span><span class="line"><span class="cl">git pull --all
</span></span><span class="line"><span class="cl">git switch main
</span></span><span class="line"><span class="cl">hugo
</span></span></code></pre></td></tr></table>
</div>
</div><p>Make <code>script.sh</code> executable:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl">chmod +x /path/to/your/script.sh
</span></span></code></pre></td></tr></table>
</div>
</div><p>Open the chosen port on the server (9000 by default) and run webhook:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl">webhook -hooks hooks.json -verbose --port <span class="m">9000</span>
</span></span></code></pre></td></tr></table>
</div>
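</div><p>The status messages it prints include the URL webhook is listening on, which you will need to enter in GitHub later.</p>
<h2 id="设置仓库-webhook">Setting Up the Repository Webhook</h2>
<p>In the corresponding repo on GitHub:</p>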
<ul>
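<li>Go to &ldquo;Settings&rdquo; &gt; &ldquo;Webhooks&rdquo; &gt; &ldquo;Add webhook&rdquo;</li>
<li>In <code>Payload URL</code>, enter the URL the server's webhook is listening on, replacing the <code>{id}</code> placeholder in the printed URL with the hook <code>id</code> from <code>hooks.json</code> (<code>redeploy-blog</code> in the example above); for <code>Content type</code>, choose <code>application/json</code>; for <code>Secret</code>, enter the same value as in <code>hooks.json</code></li>
<li>Once the webhook is added, GitHub sends a ping to the server; you can check the delivery status on both the server and the GitHub Webhooks page. If delivery fails, check whether the server port is open, whether the server receives a GET request when you open the URL directly in a browser, and, if the URL uses https, whether a reverse proxy has been configured first.</li>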
</ul>
<h2 id="使用-systemd-管理-webhook">Managing webhook with systemd</h2>
<p>First, create a <code>systemd</code> unit file on the server:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl">sudo vim /etc/systemd/system/webhook.service
</span></span></code></pre></td></tr></table>
</div>
</div><p>Then paste in the following, remembering to change the port number in the command:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-ini" data-lang="ini"><span class="line"><span class="cl"><span class="k">[Unit]</span>
</span></span><span class="line"><span class="cl"><span class="na">Description</span><span class="o">=</span><span class="s">GitHub Webhook</span>
</span></span><span class="line"><span class="cl"><span class="na">After</span><span class="o">=</span><span class="s">network.target</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">[Service]</span>
</span></span><span class="line"><span class="cl"><span class="na">User</span><span class="o">=</span><span class="s">your_username</span>
</span></span><span class="line"><span class="cl"><span class="na">WorkingDirectory</span><span class="o">=</span><span class="s">/path/to/your/hugo/blog</span>
</span></span><span class="line"><span class="cl"><span class="na">ExecStart</span><span class="o">=</span><span class="s">/usr/bin/webhook -hooks /path/to/your/hooks.json -verbose --port xxxx</span>
</span></span><span class="line"><span class="cl"><span class="na">Restart</span><span class="o">=</span><span class="s">always</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">[Install]</span>
</span></span><span class="line"><span class="cl"><span class="na">WantedBy</span><span class="o">=</span><span class="s">multi-user.target</span>
</span></span></code></pre></td></tr></table>
</div>
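</div><p>Replace <code>your_username</code> with the user that runs webhook, and <code>/path/to/your/hugo/blog</code> and <code>/path/to/your/hooks.json</code> with the actual paths.</p>
<p>Enable the service so that it starts automatically on every boot, and start it right away:</p>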
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl">sudo systemctl <span class="nb">enable</span> webhook.service
</span></span><span class="line"><span class="cl">sudo systemctl start webhook.service
</span></span></code></pre></td></tr></table>
</div>
</div><p>Check the service status with:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl">sudo systemctl status webhook.service
</span></span></code></pre></td></tr></table>
</div>
</div>]]></content:encoded>
    </item>
    <item>
      <title>Blog Setup Log</title>
      <link>https://www.zhouxin.space/logs/blog-setup-logs/</link>
      <pubDate>Tue, 02 Apr 2024 14:07:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/logs/blog-setup-logs/</guid>
      <description>&lt;h1 id=&#34;博客发布工作流&#34;&gt;Blog Publishing Workflow&lt;/h1&gt;
&lt;p&gt;This post describes how I set up a personal blog on an Alibaba Cloud server. My current publishing workflow is:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Edit posts in Obsidian&lt;/li&gt;
&lt;li&gt;Format the content with the &lt;a href=&#34;https://github.com/platers/obsidian-linter&#34;&gt;obsidian-linter&lt;/a&gt; plugin&lt;/li&gt;
&lt;li&gt;Convert the documents with the &lt;a href=&#34;https://github.com/ObsidianPublisher/obsidian-github-publisher&#34;&gt;obsidian-github-publisher&lt;/a&gt; plugin and merge them into &lt;a href=&#34;https://github.com/LittleHeroZZZX/zhouxin-space&#34;&gt;my repo&lt;/a&gt; through a PR&lt;/li&gt;
&lt;li&gt;Pull the content on the server with &lt;a href=&#34;https://git-scm.com/&#34;&gt;Git&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;Generate the static site with &lt;a href=&#34;https://gohugo.io/&#34;&gt;Hugo&lt;/a&gt; and serve it with Hugo&#39;s built-in server&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;This post focuses on the last two steps: how to set up a Hugo-based blog and how to customize its configuration.&lt;/p&gt;</description>
      <content:encoded><![CDATA[<h1 id="博客发布工作流">Blog Publishing Workflow</h1>
<p>This post describes how I set up a personal blog on an Alibaba Cloud server. My current publishing workflow is:</p>
<ul>
<li>Edit posts in Obsidian</li>
<li>Format the content with the <a href="https://github.com/platers/obsidian-linter">obsidian-linter</a> plugin</li>
<li>Convert the documents with the <a href="https://github.com/ObsidianPublisher/obsidian-github-publisher">obsidian-github-publisher</a> plugin and merge them into <a href="https://github.com/LittleHeroZZZX/zhouxin-space">my repo</a> through a PR</li>
<li>Pull the content on the server with <a href="https://git-scm.com/">Git</a></li>
<li>Generate the static site with <a href="https://gohugo.io/">Hugo</a> and serve it with Hugo's built-in server</li>
</ul>
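<p>This post focuses on the last two steps: how to set up a Hugo-based blog and how to customize its configuration.</p>
<h1 id="搭建过程">Setup</h1>
<h2 id="安装-hugo">Installing Hugo</h2>
<p>Installation is straightforward; see <a href="https://hugo.opendocs.io/installation/">Installation | Hugo documentation (in Chinese)</a>. Do not install with <code>apt</code>, though: the packaged version is too old, and many commands and themes are incompatible with it. Installing with the <code>snap</code> package manager is recommended.</p>
<h2 id="新建项目">Creating a New Site</h2>
<p>Run <code>hugo new site &lt;your_site_name&gt; --format yaml</code> to create a site named <code>&lt;your_site_name&gt;</code>. Hugo creates a folder of that name and initializes its directory structure, which includes the following directories:</p>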
<ul>
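<li>archetypes: templates for new Markdown content files</li>
<li>assets: files to be processed by Hugo Pipes, such as SCSS or JavaScript files</li>
<li>content: the site's content files, i.e. the Markdown file for each post</li>
<li>data: data files (for example JSON, YAML, or TOML) that templates can read</li>
<li>layouts: template files for the site's pages</li>
<li>static: static files such as images, which are copied into the <code>public</code> directory at build time</li>
<li>i18n: translation files</li>
<li>themes: themes</li>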
</ul>
<h2 id="应用主题">Applying a Theme</h2>
<p>Taking the <a href="https://github.com/adityatelange/hugo-PaperMod">hugo-PaperMod</a> theme as an example, install it as a Git submodule; for other installation methods, see <a href="https://adityatelange.github.io/hugo-PaperMod/posts/papermod/papermod-installation/">Install / Update PaperMod | PaperMod</a>.</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl"><span class="nb">cd</span> &lt;your_site_name&gt;
</span></span><span class="line"><span class="cl">git init
</span></span><span class="line"><span class="cl">git submodule add --depth<span class="o">=</span><span class="m">1</span> https://github.com/adityatelange/hugo-PaperMod.git themes/PaperMod
</span></span><span class="line"><span class="cl">git submodule update --init --recursive <span class="c1"># needed when you reclone your repo (submodules may not get cloned automatically)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Edit the configuration file <code>hugo.yaml</code> and add or change the <code>theme</code> field to:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-yaml" data-lang="yaml"><span class="line"><span class="cl"><span class="nt">theme</span><span class="p">:</span><span class="w"> </span><span class="p">[</span><span class="s2">&#34;PaperMod&#34;</span><span class="p">]</span><span class="w">
</span></span></span></code></pre></td></tr></table>
</div>
</div><h2 id="构建部署">Building and Deploying</h2>
<p>Run the following command, replacing <code>&lt;your_ip/domain&gt;</code> with a public IP, a domain name, or <code>127.0.0.1</code> (reachable only from the local machine):</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl">hugo server --bind<span class="o">=</span><span class="s2">&#34;0.0.0.0&#34;</span> --baseURL<span class="o">=</span><span class="s2">&#34;http://&lt;your_ip/domain&gt;&#34;</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Then open <code>http://&lt;your_ip/domain&gt;:1313</code> and you should see the blog! 🎉</p>
<h1 id="自定义配置">Custom Configuration</h1>
<h2 id="添加-archive">Adding an Archive</h2>
<p>The archive is the blog's timeline view, which organizes posts by date.<br>
Create <code>archive.md</code> under <code>&lt;your_site_name&gt;/content/</code> and add the following template content:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-markdown" data-lang="markdown"><span class="line"><span class="cl">---
</span></span><span class="line"><span class="cl">title: &#34;Archive&#34;
</span></span><span class="line"><span class="cl">layout: &#34;archives&#34;
</span></span><span class="line"><span class="cl"><span class="gh"># url: &#34;/archives&#34;
</span></span></span><span class="line"><span class="cl">summary: &#34;archives&#34;
</span></span><span class="line"><span class="cl">---
</span></span></code></pre></td></tr></table>
</div>
</div><p>Edit the site configuration file and add an archive menu entry; clicking the archive button in the menu bar then opens the timeline:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-yaml" data-lang="yaml"><span class="line"><span class="cl"><span class="nt">menu</span><span class="p">:</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">main</span><span class="p">:</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">    </span>- <span class="nt">name</span><span class="p">:</span><span class="w"> </span><span class="l">📦归档</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">      </span><span class="nt">url</span><span class="p">:</span><span class="w"> </span><span class="l">/archive</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">      </span><span class="nt">weight</span><span class="p">:</span><span class="w"> </span><span class="m">3</span><span class="w"> </span><span class="c"># custom weight; menu entries are sorted by ascending weight</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="nt">defaultContentLanguage</span><span class="p">:</span><span class="w"> </span><span class="l">zh </span><span class="w"> </span><span class="c"># set the default language to Chinese so the archive page is shown in Chinese</span><span class="w">
</span></span></span></code></pre></td></tr></table>
</div>
</div><p>Note that only posts with a <code>date</code> field are archived. If the archive page is empty, check that field; if post titles are missing, check that the <code>title</code> field is set.</p>
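<p>When the archive page looks incomplete, it can help to list the posts whose front matter lacks a given field. The following is a minimal sketch, assuming YAML front matter delimited by <code>---</code> lines at the top of each file; the function name <code>missing_front_matter_key</code> is hypothetical and not part of Hugo or PaperMod.</p>

```shell
#!/bin/sh
# Print every Markdown file under a content directory whose front matter
# lacks the given key (for example "date" or "title").
# Assumption: YAML front matter delimited by '---' lines at the top of the file.
missing_front_matter_key() {
    key="$1"
    dir="$2"
    find "$dir" -type f -name '*.md' | sort | while IFS= read -r file; do
        # Keep only the lines between the first and second '---'.
        if ! awk '/^---[[:space:]]*$/ { n++; next } n == 1 { print } n >= 2 { exit }' "$file" \
            | grep -q "^${key}:"; then
            printf '%s\n' "$file"
        fi
    done
}
```

<p>For example, <code>missing_front_matter_key date content</code> lists the files without a <code>date</code> field. Pages such as <code>archive.md</code> and <code>search.md</code> will show up in that list as well, which is expected because their templates carry no date.</p>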
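<h2 id="添加搜索">Adding Search</h2>
<p>Search is also officially supported by PaperMod and indexes post content, titles, keywords, and so on.<br>
In the <code>&lt;your_web_site&gt;/content/</code> directory, create <code>search.md</code> with the following template:</p>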
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-markdown" data-lang="markdown"><span class="line"><span class="cl">---
</span></span><span class="line"><span class="cl">title: &#34;Search&#34; # in any language you want
</span></span><span class="line"><span class="cl">layout: &#34;search&#34; # necessary for search
</span></span><span class="line"><span class="cl"><span class="gh"># url: &#34;/archive&#34;
</span></span></span><span class="line"><span class="cl"><span class="gh"># description: &#34;Description for Search&#34;
</span></span></span><span class="line"><span class="cl">summary: &#34;search&#34;
</span></span><span class="line"><span class="cl">placeholder: &#34;支持搜索标题、博文、Tags等&#34;
</span></span><span class="line"><span class="cl">---
</span></span></code></pre></td></tr></table>
</div>
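</div><p>The <code>placeholder</code> field in the template is the hint text shown in the search box and can be changed freely.</p>
<p>Edit the site configuration file to add the settings that search requires and a menu button for it:</p>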
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-yaml" data-lang="yaml"><span class="line"><span class="cl"><span class="c"># settings required to enable search</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="nt">outputs</span><span class="p">:</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">home</span><span class="p">:</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">    </span>- <span class="l">HTML</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">    </span>- <span class="l">RSS</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">    </span>- <span class="l">JSON</span><span class="w"> </span><span class="c"># necessary for search</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="c"># menu button for search</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="nt">menu</span><span class="p">:</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">main</span><span class="p">:</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">    </span>- <span class="nt">name</span><span class="p">:</span><span class="w"> </span><span class="l">🔍搜索</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">      </span><span class="nt">url</span><span class="p">:</span><span class="w"> </span><span class="l">/search</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">      </span><span class="nt">weight</span><span class="p">:</span><span class="w"> </span><span class="m">1</span><span class="w">
</span></span></span></code></pre></td></tr></table>
</div>
</div><h2 id="添加-tags">Adding Tags</h2>
<p>The tags page is already implemented by PaperMod; just add a menu button that points to <code>/tags</code>:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-yaml" data-lang="yaml"><span class="line"><span class="cl"><span class="nt">menu</span><span class="p">:</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">main</span><span class="p">:</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">    </span>- <span class="nt">name</span><span class="p">:</span><span class="w"> </span><span class="l">🏷️标签</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">      </span><span class="nt">url</span><span class="p">:</span><span class="w"> </span><span class="l">/tags</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">      </span><span class="nt">weight</span><span class="p">:</span><span class="w"> </span><span class="m">2</span><span class="w">
</span></span></span></code></pre></td></tr></table>
</div>
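</div><h2 id="访问量统计">Visit Statistics</h2>
<p><a href="https://busuanzi.ibruce.info/">不蒜子 - 极简网页计数器</a> (Busuanzi, a minimal page-view counter) is used to count visits to the site and to individual posts.<br>
First add the following to <code>&lt;your_web_site&gt;/layouts/partials/extend_head.html</code>:</p>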
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-html" data-lang="html"><span class="line"><span class="cl"><span class="c">&lt;!-- busuanzi --&gt;</span>
</span></span><span class="line"><span class="cl">{{- if .Site.Params.busuanzi.enable -}}
</span></span><span class="line"><span class="cl"><span class="p">&lt;</span><span class="nt">script</span> <span class="na">async</span> <span class="na">src</span><span class="o">=</span><span class="s">&#34;//busuanzi.ibruce.info/busuanzi/2.3/busuanzi.pure.mini.js&#34;</span><span class="p">&gt;&lt;/</span><span class="nt">script</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="p">&lt;</span><span class="nt">meta</span> <span class="na">name</span><span class="o">=</span><span class="s">&#34;referrer&#34;</span> <span class="na">content</span><span class="o">=</span><span class="s">&#34;no-referrer-when-downgrade&#34;</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="p">&lt;</span><span class="nt">style</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">    <span class="p">:</span><span class="nd">root</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nv">--footer-height</span><span class="p">:</span> <span class="mi">80</span><span class="kt">px</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">&lt;/</span><span class="nt">style</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">{{- end -}}
</span></span></code></pre></td></tr></table>
</div>
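</div><p>The contents of <code>extend_head.html</code> are included in <code>&lt;head&gt;</code>, so this file is the place to load the JavaScript that Busuanzi needs. It also defines (in fact overrides) the PaperMod variable that specifies the footer height, so that a multi-line footer does not push the content past one screen.</p>
<p>Then open the configuration file <code>&lt;your_web_site&gt;/hugo.yaml</code> and add the following fields to enable the counter:</p>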
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-yaml" data-lang="yaml"><span class="line"><span class="cl"><span class="nt">params</span><span class="p">:</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">busuanzi</span><span class="p">:</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">    </span><span class="nt">enable</span><span class="p">:</span><span class="w"> </span><span class="kc">true</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="nt">hideFooter</span><span class="p">:</span><span class="w"> </span><span class="kc">true</span><span class="w">  </span><span class="c"># disable the default footer</span><span class="w">
</span></span></span></code></pre></td></tr></table>
</div>
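</div><p>PaperMod's default footer styling handles multi-line footers poorly, so the configuration above also disables the theme's default footer. The footer is rewritten in <code>&lt;your_web_site&gt;/layouts/partials/extend_footer.html</code> by adding the following:</p>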
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-html" data-lang="html"><span class="line"><span class="cl">{{- if not (.Param &#34;hideCustumFooter&#34;) }}
</span></span><span class="line"><span class="cl"><span class="p">&lt;</span><span class="nt">footer</span> <span class="na">class</span><span class="o">=</span><span class="s">&#34;footer&#34;</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">  {{- if site.Copyright }}
</span></span><span class="line"><span class="cl">  <span class="p">&lt;</span><span class="nt">span</span><span class="p">&gt;</span>{{ site.Copyright | markdownify }}<span class="p">&lt;/</span><span class="nt">span</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">  {{- else }}
</span></span><span class="line"><span class="cl">  <span class="p">&lt;</span><span class="nt">span</span><span class="p">&gt;</span><span class="ni">&amp;copy;</span> {{ now.Year }} <span class="p">&lt;</span><span class="nt">a</span> <span class="na">href</span><span class="o">=</span><span class="s">&#34;{{ &#34;</span><span class="err">&#34;</span> <span class="err">|</span> <span class="na">absLangURL</span> <span class="err">}}&#34;</span><span class="p">&gt;</span>{{ site.Title }}<span class="p">&lt;/</span><span class="nt">a</span><span class="p">&gt;&lt;/</span><span class="nt">span</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">  {{- end }}
</span></span><span class="line"><span class="cl">  <span class="p">&lt;</span><span class="nt">span</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">    Powered by
</span></span><span class="line"><span class="cl">    <span class="p">&lt;</span><span class="nt">a</span> <span class="na">href</span><span class="o">=</span><span class="s">&#34;https://gohugo.io/&#34;</span> <span class="na">rel</span><span class="o">=</span><span class="s">&#34;noopener noreferrer&#34;</span> <span class="na">target</span><span class="o">=</span><span class="s">&#34;_blank&#34;</span><span class="p">&gt;</span>Hugo<span class="p">&lt;/</span><span class="nt">a</span><span class="p">&gt;</span> <span class="err">&amp;</span>
</span></span><span class="line"><span class="cl">    <span class="p">&lt;</span><span class="nt">a</span> <span class="na">href</span><span class="o">=</span><span class="s">&#34;https://github.com/adityatelange/hugo-PaperMod/&#34;</span> <span class="na">rel</span><span class="o">=</span><span class="s">&#34;noopener&#34;</span> <span class="na">target</span><span class="o">=</span><span class="s">&#34;_blank&#34;</span><span class="p">&gt;</span>PaperMod<span class="p">&lt;/</span><span class="nt">a</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">  <span class="p">&lt;/</span><span class="nt">span</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">  {{ if .Site.Params.busuanzi.enable -}}
</span></span><span class="line"><span class="cl">  <span class="p">&lt;</span><span class="nt">div</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">  <span class="p">&lt;</span><span class="nt">span</span> <span class="na">id</span><span class="o">=</span><span class="s">&#34;busuanzi_container_site_pv&#34;</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">    本站总访问量<span class="p">&lt;</span><span class="nt">span</span> <span class="na">id</span><span class="o">=</span><span class="s">&#34;busuanzi_value_site_pv&#34;</span><span class="p">&gt;&lt;/</span><span class="nt">span</span><span class="p">&gt;</span>次
</span></span><span class="line"><span class="cl">  <span class="p">&lt;/</span><span class="nt">span</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl"><span class="p">&lt;/</span><span class="nt">div</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">  {{- end -}}
</span></span><span class="line"><span class="cl"><span class="p">&lt;/</span><span class="nt">footer</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">{{- end }}
</span></span></code></pre></td></tr></table>
</div>
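</div><p>The code above adds, in order, a copyright notice, a "powered by" line, and the visit counter to the footer. Adjust it to taste, or add items such as an ICP filing number.</p>
<p>Next comes the per-post view count. PaperMod renders every post from <code>single.html</code>, so that is the file to change. To avoid touching the theme's own files, copy <code>&lt;your_web_site&gt;/themes/PaperMod/layouts/_default/single.html</code> to <code>&lt;your_web_site&gt;/layouts/_default/single.html</code> and edit the copy (files in the site directory take precedence over those in the theme directory).</p>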
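<p>The copy step can be scripted. The sketch below assumes the theme is installed under <code>themes/PaperMod</code> in the site root; <code>override_theme_layout</code> is a hypothetical helper name, not a Hugo command. It refuses to overwrite a site copy that already exists, so earlier customizations are not lost.</p>

```shell
#!/bin/sh
# Copy a layout file from the theme into the site's own layouts/ directory,
# keeping the same relative path, so that the site copy takes precedence.
# Assumption: the theme lives in themes/PaperMod under the site root.
override_theme_layout() {
    site="$1"
    rel="$2"    # path relative to layouts/, e.g. _default/single.html
    src="$site/themes/PaperMod/layouts/$rel"
    dst="$site/layouts/$rel"
    [ -f "$src" ] || { echo "not found: $src" >&2; return 1; }
    if [ -e "$dst" ]; then
        # Never clobber a copy that may already have been customized.
        echo "already exists, left untouched: $dst" >&2
        return 0
    fi
    mkdir -p "$(dirname "$dst")"
    cp "$src" "$dst"
}
```

<p>Run from the site root, the call would be <code>override_theme_layout . _default/single.html</code>.</p>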
<p>Find <code>&lt;div class=&quot;post-meta&quot;&gt;</code>, the <code>div</code> that holds all of a post's metadata, and add a <code>div</code> for the view count inside it. The modified code is:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-html" data-lang="html"><span class="line"><span class="cl"><span class="p">&lt;</span><span class="nt">div</span> <span class="na">class</span><span class="o">=</span><span class="s">&#34;post-meta&#34;</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">      {{- partial &#34;post_meta.html&#34; . -}}
</span></span><span class="line"><span class="cl">      {{- partial &#34;translation_list.html&#34; . -}}
</span></span><span class="line"><span class="cl">      {{- partial &#34;edit_post.html&#34; . -}}
</span></span><span class="line"><span class="cl">      {{- partial &#34;post_canonical.html&#34; . -}}
</span></span><span class="line"><span class="cl">      {{ if .Site.Params.busuanzi.enable -}}
</span></span><span class="line"><span class="cl">      <span class="p">&lt;</span><span class="nt">div</span> <span class="na">class</span><span class="o">=</span><span class="s">&#34;meta-item&#34;</span><span class="p">&gt;</span><span class="ni">&amp;nbsp;</span>·<span class="ni">&amp;nbsp;</span>
</span></span><span class="line"><span class="cl">        阅读量 <span class="p">&lt;</span><span class="nt">span</span> <span class="na">id</span><span class="o">=</span><span class="s">&#34;busuanzi_value_page_pv&#34;</span><span class="p">&gt;</span><span class="p">&lt;/</span><span class="nt">span</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">      <span class="p">&lt;/</span><span class="nt">div</span><span class="p">&gt;</span>
</span></span><span class="line"><span class="cl">      {{- end }}
</span></span><span class="line"><span class="cl">    <span class="p">&lt;/</span><span class="nt">div</span><span class="p">&gt;</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The view count now appears on each post page.</p>
<h1 id="参考文档">References</h1>
]]></content:encoded>
    </item>
    <item>
      <title>Blog Changelog</title>
      <link>https://www.zhouxin.space/logs/blog-logs/</link>
      <pubDate>Tue, 02 Apr 2024 12:26:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/logs/blog-logs/</guid>
      <description>&lt;h2 id=&#34;2024-10-17&#34;&gt;2024-10-17&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Friend links page launched&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;2024-07-08&#34;&gt;2024-07-08&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Added a sidebar table of contents&lt;/li&gt;
&lt;li&gt;The TOP button now shows reading progress&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;2024-06-20&#34;&gt;2024-06-20&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Added Google Analytics&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;2024-06-07&#34;&gt;2024-06-07&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Improved formula rendering&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;2024-04-19&#34;&gt;2024-04-19&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Restyled code blocks&lt;/li&gt;
&lt;li&gt;Added the giscus comment system&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;2024-04-12&#34;&gt;2024-04-12&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Changed directory names and slugs to English&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;2024-04-08&#34;&gt;2024-04-08&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;ICP filing completed; domain &lt;a href=&#34;https://www.zhouxin.space/&#34;&gt;zhouxin.space&lt;/a&gt; enabled&lt;/li&gt;
&lt;li&gt;Added site-wide visit statistics and per-post view counts&lt;/li&gt;
&lt;li&gt;Customized the 404 page&lt;/li&gt;
&lt;li&gt;Configured HTTPS&lt;/li&gt;
&lt;li&gt;Set up an nginx reverse proxy&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&#34;2024-04-03&#34;&gt;2024-04-03&lt;/h2&gt;
&lt;p&gt;Site launched&lt;/p&gt;</description>
      <content:encoded><![CDATA[<h2 id="2024-10-17">2024-10-17</h2>
<ul>
<li>Friend links page launched</li>
</ul>
<h2 id="2024-07-08">2024-07-08</h2>
<ul>
<li>Added a sidebar table of contents</li>
<li>The TOP button now shows reading progress</li>
</ul>
<h2 id="2024-06-20">2024-06-20</h2>
<ul>
<li>Added Google Analytics</li>
</ul>
<h2 id="2024-06-07">2024-06-07</h2>
<ul>
<li>Improved formula rendering</li>
</ul>
<h2 id="2024-04-19">2024-04-19</h2>
<ul>
<li>Restyled code blocks</li>
<li>Added the giscus comment system</li>
</ul>
<h2 id="2024-04-12">2024-04-12</h2>
<ul>
<li>Changed directory names and slugs to English</li>
</ul>
<h2 id="2024-04-08">2024-04-08</h2>
<ul>
<li>ICP filing completed; domain <a href="https://www.zhouxin.space/">zhouxin.space</a> enabled</li>
<li>Added site-wide visit statistics and per-post view counts</li>
<li>Customized the 404 page</li>
<li>Configured HTTPS</li>
<li>Set up an nginx reverse proxy</li>
</ul>
<h2 id="2024-04-03">2024-04-03</h2>
<p>Site launched</p>
<ul>
<li>Built with Hugo and the PaperMod theme</li>
<li>Added search, tags, and archive to the menu</li>
<li>Images are stored on Alibaba Cloud OSS</li>
<li>Changed the font to LXGW WenKai</li>
</ul>
    </item>
    <item>
      <title>Building an Image Host on Alibaba Cloud OSS</title>
      <link>https://www.zhouxin.space/logs/image-hosting-based-on-aliyun-oss/</link>
      <pubDate>Mon, 01 Apr 2024 23:10:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/logs/image-hosting-based-on-aliyun-oss/</guid>
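      <description>&lt;h1 id=&#34;概述&#34;&gt;Overview&lt;/h1&gt;
&lt;p&gt;I have recently been looking into publishing Obsidian notes with Hugo. For images and other attachments there are two options: store them on the blog server, or store them on an image host. The server only has a 30 GB disk, which attachments could fill up, so I chose to build an image host on Alibaba Cloud OSS; this also leaves open the option of speeding up access with Alibaba Cloud CDN later.&lt;br&gt;
This post draws on and consolidates several online tutorials &lt;sup id=&#34;fnref:1&#34;&gt;&lt;a href=&#34;#fn:1&#34; class=&#34;footnote-ref&#34; role=&#34;doc-noteref&#34;&gt;1&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;</description>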
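      <content:encoded><![CDATA[<h1 id="概述">Overview</h1>
<p>I have recently been looking into publishing Obsidian notes with Hugo. For images and other attachments there are two options: store them on the blog server, or store them on an image host. The server only has a 30 GB disk, which attachments could fill up, so I chose to build an image host on Alibaba Cloud OSS; this also leaves open the option of speeding up access with Alibaba Cloud CDN later.<br>
This post draws on and consolidates several online tutorials <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup>.</p>
<h1 id="图床搭建">Setting Up the Image Host</h1>
<p>Setting up the image host involves five steps: purchasing Alibaba Cloud OSS storage, creating a Bucket, binding a domain (optional), configuring security policies, and configuring the image-host plugin.<br>
The cost has two parts: storage (9 yuan/year for 40 GB) and traffic (0.5 yuan/GB). Under normal use the traffic cost is negligible.</p>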
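<p>As a rough worked example of that pricing, the sketch below adds the flat storage fee to a per-GB traffic charge. The prices are simply the ones quoted in this post and may no longer be current; <code>estimate_yearly_cost</code> is a hypothetical helper name.</p>

```shell
#!/bin/sh
# Estimate the yearly cost in yuan from the prices quoted in this post:
# a flat 9 yuan/year for the 40 GB storage plan plus 0.5 yuan per GB of traffic.
# Assumption: these prices are taken from the post and may have changed.
estimate_yearly_cost() {
    traffic_gb="$1"
    LC_ALL=C awk -v gb="$traffic_gb" 'BEGIN { printf "%.2f\n", 9 + 0.5 * gb }'
}
```

<p>With 10 GB of traffic in a year, <code>estimate_yearly_cost 10</code> gives 9 + 0.5 × 10 = 14 yuan, which shows why storage dominates the bill for a small personal blog.</p>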
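<h2 id="购买-oss-空间和创建-bucket">Purchasing OSS Storage and Creating a Bucket</h2>
<p>Search for OSS on the Alibaba Cloud website to find the purchase page, then buy with the following settings:<br>
<img alt="Screenshot of the OSS purchase settings" loading="lazy" src="https://pics-zhouxin.oss-cn-hangzhou.aliyuncs.com/OSS%E8%B4%AD%E4%B9%B0%E9%85%8D%E7%BD%AE%E6%88%AA%E5%9B%BE.png"><br>
After the purchase, create a Bucket on the OSS management page with the following settings:<br>
<img alt="Bucket creation settings" loading="lazy" src="https://pics-zhouxin.oss-cn-hangzhou.aliyuncs.com/Bucket%E5%88%9B%E5%BB%BA%E9%85%8D%E7%BD%AE.png"></p>
<h2 id="绑定域名">Binding a Domain</h2>
<p>#todo</p>
<h2 id="数据安全---防盗链">Data Security: Hotlink Protection</h2>
<p>To keep third parties from hotlinking images and running up unexpected traffic charges, use OSS hotlink protection so that only requests whose <code>Referer</code> is on the whitelist are served. The setting is under <code>Bucket控制台-数据安全-防盗链</code> (Bucket console, Data Security, Hotlink Protection); list the domains that are allowed access in the whitelist, for example <code>*.aliyun.com</code> and <code>blog.example.com</code>.<br>
Whether to allow an empty Referer is a judgment call: if it is disallowed, images on OSS will not load in applications such as Obsidian and Typora.</p>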
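<p>To make the whitelist idea concrete, here is a small illustration that checks a Referer host against wildcard patterns with shell glob matching. It is not how OSS implements the check, only a sketch of the concept; <code>referer_allowed</code> is a hypothetical name, and the exact wildcard rules of OSS should be taken from its documentation.</p>

```shell
#!/bin/sh
# Illustration only: decide whether a Referer host matches a whitelist of
# wildcard patterns, using shell glob matching. This is NOT the OSS
# implementation; it only mimics the idea of a Referer whitelist.
referer_allowed() {
    host="$1"
    shift
    for pattern in "$@"; do
        # An unquoted $pattern is treated as a glob pattern by case.
        case "$host" in
            $pattern) return 0 ;;
        esac
    done
    return 1
}
```

<p>With the whitelist from the example above, <code>referer_allowed oss.aliyun.com '*.aliyun.com' blog.example.com</code> succeeds, while a host outside the list is rejected. An empty Referer matches no pattern here, which corresponds to the case where empty Referers are disallowed.</p>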
<h2 id="配置图床插件">Configuring the Image-Host Plugin</h2>
<p>First create a sub-user for PicGo in Alibaba Cloud and grant it full management permission on OSS.<br>
To create the sub-user, go to <code>RAM访问控制-身份管理-用户-创建用户</code> (RAM, Identities, Users, Create User). Any login name works; tick <code>OpenAPI 调用访问</code> (OpenAPI access). Creation yields an <code>AccessKey ID</code> and <code>AccessKey Secret</code> pair; keep them safe, as they are needed later.<br>
Then, on the user management page, add the <code>AliyunOSSFullAccess</code> permission to the new user.<br>
<img alt="Pasted image 20240402090526" loading="lazy" src="https://pics-zhouxin.oss-cn-hangzhou.aliyuncs.com/Pasted%20image%2020240402090526.png"><br>
<img alt="Pasted image 20240402090610" loading="lazy" src="https://pics-zhouxin.oss-cn-hangzhou.aliyuncs.com/Pasted%20image%2020240402090610.png"></p>
<p>Once the sub-account is set up, install PicGo from a terminal with <code>winget install picgo</code>, or download it from <a href="https://picgo.github.io/PicGo-Doc/zh/guide/#%E4%B8%8B%E8%BD%BD%E5%AE%89%E8%A3%85">PicGo is Here | PicGo</a>. Under <code>PicGo-图床设置-阿里云OSS</code> (PicGo, Image Host Settings, Alibaba Cloud OSS), fill in the parameters:<br>
KeyId: the sub-user's AccessKey ID<br>
KeySecret: the sub-user's AccessKey Secret<br>
Bucket: the name of the Bucket created earlier</p>
<h1 id="参考文档">References</h1>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p><a href="https://zhuanlan.zhihu.com/p/638165744">02.Hugo中使用阿里云OSS作为图床 - 知乎</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>Installing and Switching to a Specific gcc or g&#43;&#43; Version</title>
      <link>https://www.zhouxin.space/notes/install-and-switch-to-specific-version-of-gcc/</link>
      <pubDate>Mon, 01 Apr 2024 10:58:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/install-and-switch-to-specific-version-of-gcc/</guid>
      <description>&lt;h1 id=&#34;知其然&#34;&gt;The How&lt;/h1&gt;
&lt;p&gt;&lt;strong&gt;Note:&lt;/strong&gt; This method downloads gcc/g++ from a PPA, which is slow to reach from mainland China. Consider setting up an apt proxy first, as described in &lt;a href=&#34;https://www.zhouxin.space/notes/config-proxy-for-apt/&#34;&gt;Configuring a Proxy for apt&lt;/a&gt;.&lt;br&gt;
Taking &lt;code&gt;g++ 13&lt;/code&gt; as an example (a specific minor version cannot be selected), the commands are &lt;sup id=&#34;fnref:1&#34;&gt;&lt;a href=&#34;#fn:1&#34; class=&#34;footnote-ref&#34; role=&#34;doc-noteref&#34;&gt;1&lt;/a&gt;&lt;/sup&gt;:&lt;/p&gt;</description>
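      <content:encoded><![CDATA[<h1 id="知其然">The How</h1>
<p><strong>Note:</strong> This method downloads gcc/g++ from a PPA, which is slow to reach from mainland China. Consider setting up an apt proxy first, as described in <a href="/notes/config-proxy-for-apt/">Configuring a Proxy for apt</a>.<br>
Taking <code>g++ 13</code> as an example (a specific minor version cannot be selected), the commands are <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup>:</p>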
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-bash" data-lang="bash"><span class="line"><span class="cl">sudo apt update
</span></span><span class="line"><span class="cl">sudo apt install software-properties-common -y
</span></span><span class="line"><span class="cl">sudo add-apt-repository ppa:ubuntu-toolchain-r/test -y <span class="o">&amp;&amp;</span> sudo apt update
</span></span><span class="line"><span class="cl">sudo apt install gcc-13 g++-13 -y
</span></span><span class="line"><span class="cl">sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-13 <span class="m">13</span> --slave /usr/bin/g++ g++ /usr/bin/g++-13
</span></span><span class="line"><span class="cl">sudo update-alternatives --config gcc
</span></span></code></pre></td></tr></table>
</div>
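</div><p>Note that the version number after gcc/g++ in the fourth and fifth commands must be changed to the version you need. The last command opens an interactive prompt for choosing which registered gcc version is the default.</p>
<h1 id="知其所以然">The Why</h1>
<p>The process above breaks down into:</p>
<ol>
<li>Add the PPA source</li>
<li>Install the specified version of gcc</li>
<li>Use <code>update-alternatives</code> to set priorities so that <code>gcc</code> points to <code>gcc-13</code> by default</li>
</ol>
<h2 id="ppa-源">PPA Sources</h2>
<p>PPA stands for Personal Package Archive, a concept defined in contrast to the official repositories. Ubuntu provides an official software repository and mirrors of it; packages there go through compatibility checks, so updates arrive slowly <sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup>.<br>
PPAs were introduced to address this: they are unofficial repositories maintained by developers themselves, and they make the latest software versions available.<br>
Here, to install <code>gcc 13</code>, the <code>add-apt-repository</code> command adds the PPA <code>ppa:ubuntu-toolchain-r/test</code>. Before that, <code>software-properties-common</code> is installed to make sure the <code>add-apt-repository</code> command is available.</p>
<h2 id="安装-gcc">Installing gcc</h2>
<p>Once the PPA has been added, <code>gcc</code> can be installed with <code>apt</code> as usual; the package name <code>gcc-13</code> selects the version. Only the major version can be chosen this way, not a minor version.</p>
<h2 id="update-alternatives-调整优先级">Setting Priorities with update-alternatives</h2>
<p><code>update-alternatives</code> is a tool shipped with Ubuntu (it originates from Debian) that maintains symbolic links and switches a program between several installed versions by updating those links. It is built around the notion of an alternative group: a set of commands that can stand in for one another. For example, <code>gcc-10</code> and <code>gcc-12</code> are alternatives for <code>gcc</code>. The command for adding an alternative is:</p>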
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-bash" data-lang="bash"><span class="line"><span class="cl">update-alternatives --install &lt;link&gt; &lt;name&gt; &lt;path&gt; &lt;priority&gt;
</span></span></code></pre></td></tr></table>
</div>
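</div><p><code>link</code> is the path of the symbolic link to create or update, for example <code>/usr/bin/gcc</code>;<br>
<code>name</code> is the identifier of the alternative group, for example <code>gcc</code>;<br>
<code>path</code> is the concrete program version or implementation that the link should point to, for example <code>/usr/bin/gcc-12</code>;<br>
<code>priority</code> is the priority of this <code>path</code> within the group. It is an integer, and a larger number means a higher priority; in this example the <code>gcc</code> version number is used as the priority.</p>
<p>You may have noticed that the command actually used is <code>sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-13 13 --slave /usr/bin/g++ g++ /usr/bin/g++-13</code>. Its second half carries an extra option, <code>--slave /usr/bin/g++ g++ /usr/bin/g++-13</code>, which attaches one or more slave links to the master alternative: whenever <code>gcc</code> is switched, the corresponding slave <code>g++</code> is switched with it. The syntax is:</p>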
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-bash" data-lang="bash"><span class="line"><span class="cl">update-alternatives  --install &lt;link&gt; &lt;name&gt; &lt;path&gt; &lt;priority&gt; <span class="o">[</span>--slave &lt;link&gt; &lt;name&gt; &lt;path&gt;<span class="o">]</span> ...
</span></span></code></pre></td></tr></table>
</div>
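</div><p>A slave link shares the master's priority, so none is specified for it.</p>
<p>When an alternative group has several candidates, run <code>sudo update-alternatives --config &lt;name&gt;</code> to choose one through an interactive prompt.</p>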
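<p>When several major versions are installed, the same registration command has to be repeated once per version. The sketch below only prints the commands so that they can be reviewed before being run; <code>alternatives_cmd</code> is a hypothetical helper, and it assumes the binaries are named <code>/usr/bin/gcc-N</code> and <code>/usr/bin/g++-N</code> as in the example above.</p>

```shell
#!/bin/sh
# Print (do not run) the update-alternatives command that registers one
# gcc/g++ major version, using the version number itself as the priority.
# Assumption: binaries are installed as /usr/bin/gcc-N and /usr/bin/g++-N.
alternatives_cmd() {
    ver="$1"
    printf 'sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-%s %s --slave /usr/bin/g++ g++ /usr/bin/g++-%s\n' \
        "$ver" "$ver" "$ver"
}
```

<p>For instance, <code>for v in 12 13; do alternatives_cmd "$v"; done</code> prints one command per version. Because the version number doubles as the priority, the newest registered version wins while the group is in automatic mode.</p>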
<h1 id="参考文档">References</h1>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p><a href="https://www.dedicatedcore.com/blog/install-gcc-compiler-ubuntu/">How to Install GCC Compiler on Ubuntu 22.04</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p><a href="https://www.jianshu.com/p/6aa5575e8a34">ubuntu ppa源管理 - 简书</a>&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>Configuring a Proxy for apt</title>
      <link>https://www.zhouxin.space/notes/config-proxy-for-apt/</link>
      <pubDate>Mon, 01 Apr 2024 10:50:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/config-proxy-for-apt/</guid>
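      <description>&lt;p&gt;In most cases switching apt to a mirror gives a good experience, but sometimes you need sources from abroad that are not mirrored, such as a &lt;code&gt;PPA&lt;/code&gt;, so it is worth knowing how to configure a proxy for apt.&lt;br&gt;
When apt is run through &lt;code&gt;sudo&lt;/code&gt;, proxy environment variables from your shell are usually not passed along, so the proxy has to be added manually to its configuration file &lt;code&gt;/etc/apt/apt.conf&lt;/code&gt;:&lt;/p&gt;</description>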
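      <content:encoded><![CDATA[<p>In most cases switching apt to a mirror gives a good experience, but sometimes you need sources from abroad that are not mirrored, such as a <code>PPA</code>, so it is worth knowing how to configure a proxy for apt.<br>
When apt is run through <code>sudo</code>, proxy environment variables from your shell are usually not passed along, so the proxy has to be added manually to its configuration file <code>/etc/apt/apt.conf</code>:</p>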
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-bash" data-lang="bash"><span class="line"><span class="cl"><span class="c1"># configuration format</span>
</span></span><span class="line"><span class="cl">Acquire::http::Proxy <span class="s2">&#34;http://USERNAME:PASSWORD@SERVER:PORT&#34;</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">Acquire::https::Proxy <span class="s2">&#34;https://USERNAME:PASSWORD@SERVER:PORT&#34;</span><span class="p">;</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>For example, for a proxy that needs no authentication, add the following to <code>/etc/apt/apt.conf</code>:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-bash" data-lang="bash"><span class="line"><span class="cl">Acquire::http::Proxy <span class="s2">&#34;http://127.0.0.1:7890&#34;</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">Acquire::https::Proxy <span class="s2">&#34;http://127.0.0.1:7890&#34;</span><span class="p">;</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="参考文档">References</h1>
<ul>
<li><a href="https://askubuntu.com/a/920242">Configure proxy for APT? - Ask Ubuntu</a></li>
</ul>
]]></content:encoded>
    </item>
    <item>
      <title>Setting Up a ZeroTier MOON Server</title>
      <link>https://www.zhouxin.space/notes/setup-zerotier-moon-server/</link>
      <pubDate>Sun, 31 Mar 2024 11:40:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/setup-zerotier-moon-server/</guid>
      <description>&lt;h1 id=&#34;资源存档&#34;&gt;Resource archive&lt;/h1&gt;
&lt;p&gt;Original article: &lt;a href=&#34;https://www.tpfuture.top/views/linux/net/ZerotierOneAddMoon.html&#34;&gt;ZeroTier-One搭建moon节点 | 一水轩&lt;/a&gt;&lt;br&gt;
ZeroTier website: &lt;a href=&#34;https://my.zerotier.com/&#34;&gt;ZeroTier Central&lt;/a&gt;&lt;/p&gt;
&lt;h1 id=&#34;搭建过程&#34;&gt;Setup process&lt;/h1&gt;
&lt;h2 id=&#34;在服务器上安装并配置-zerotier&#34;&gt;Install and configure ZeroTier on the server&lt;/h2&gt;
&lt;h3 id=&#34;安装-zerotier&#34;&gt;Install ZeroTier&lt;/h3&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;div class=&#34;chroma&#34;&gt;
&lt;table class=&#34;lntable&#34;&gt;&lt;tr&gt;&lt;td class=&#34;lntd&#34;&gt;
&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code&gt;&lt;span class=&#34;lnt&#34;&gt;1
&lt;/span&gt;&lt;span class=&#34;lnt&#34;&gt;2
&lt;/span&gt;&lt;span class=&#34;lnt&#34;&gt;3
&lt;/span&gt;&lt;span class=&#34;lnt&#34;&gt;4
&lt;/span&gt;&lt;span class=&#34;lnt&#34;&gt;5
&lt;/span&gt;&lt;span class=&#34;lnt&#34;&gt;6
&lt;/span&gt;&lt;span class=&#34;lnt&#34;&gt;7
&lt;/span&gt;&lt;span class=&#34;lnt&#34;&gt;8
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class=&#34;lntd&#34;&gt;
&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-sh&#34; data-lang=&#34;sh&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;curl -s https://install.zerotier.com &lt;span class=&#34;p&#34;&gt;|&lt;/span&gt; sudo bash
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;sudo systemctl start zerotier-one.service
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;sudo systemctl &lt;span class=&#34;nb&#34;&gt;enable&lt;/span&gt; zerotier-one.service
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;sudo zerotier-cli join &amp;lt;network ID&amp;gt; &lt;span class=&#34;c1&#34;&gt;# replace with the network ID of your network&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
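&lt;/div&gt;&lt;h3 id=&#34;在控制台勾选服务器&#34;&gt;Authorize the server in the console&lt;/h3&gt;
&lt;p&gt;Go to the console of the corresponding network at &lt;a href=&#34;https://my.zerotier.com/&#34;&gt;ZeroTier Central&lt;/a&gt; and authorize the device that was just added.&lt;/p&gt;</description>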
      <content:encoded><![CDATA[<h1 id="资源存档">Resource archive</h1>
<p>Original article: <a href="https://www.tpfuture.top/views/linux/net/ZerotierOneAddMoon.html">ZeroTier-One搭建moon节点 | 一水轩</a><br>
ZeroTier website: <a href="https://my.zerotier.com/">ZeroTier Central</a></p>
<h1 id="搭建过程">Setup process</h1>
<h2 id="在服务器上安装并配置-zerotier">Install and configure ZeroTier on the server</h2>
<h3 id="安装-zerotier">Install ZeroTier</h3>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl">curl -s https://install.zerotier.com <span class="p">|</span> sudo bash
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">sudo systemctl start zerotier-one.service
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">sudo systemctl <span class="nb">enable</span> zerotier-one.service
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">sudo zerotier-cli join &lt;network ID&gt; <span class="c1"># replace with the network ID of your network</span>
</span></span></code></pre></td></tr></table>
</div>
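</div><h3 id="在控制台勾选服务器">Authorize the server in the console</h3>
<p>Go to the console of the corresponding network at <a href="https://my.zerotier.com/">ZeroTier Central</a> and authorize the device that was just added.</p>
<h2 id="搭建-moon-服务器">Set up the MOON server</h2>
<h3 id="开放端口">Open the port</h3>
<p>MOON uses UDP port 9993 by default, so a matching inbound rule has to be opened in the server console.</p>
<h3 id="生成-moonjson-文件">Generate the <code>moon.json</code> file</h3>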
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl"><span class="nb">cd</span> /var/lib/zerotier-one/
</span></span><span class="line"><span class="cl">sudo zerotier-idtool initmoon identity.public &gt; moon.json
</span></span></code></pre></td></tr></table>
</div>
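</div><p>With a text editor such as <code>vim</code>, edit the freshly generated <code>moon.json</code> and set the value of <code>&quot;stableEndpoints&quot;</code> to the public IPv4 address of the server, followed by the port:</p>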
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl"><span class="o">{</span>
</span></span><span class="line"><span class="cl"> <span class="s2">&#34;id&#34;</span>: <span class="s2">&#34;xxxxx&#34;</span>, <span class="c1"># other devices use this value later to configure the moon</span>
</span></span><span class="line"><span class="cl"> <span class="s2">&#34;objtype&#34;</span>: <span class="s2">&#34;world&#34;</span>,
</span></span><span class="line"><span class="cl"> <span class="s2">&#34;roots&#34;</span>: <span class="o">[</span>
</span></span><span class="line"><span class="cl">  <span class="o">{</span>
</span></span><span class="line"><span class="cl">   <span class="s2">&#34;identity&#34;</span>: <span class="s2">&#34;xxxx:0:eeee&#34;</span>,
</span></span><span class="line"><span class="cl">   <span class="s2">&#34;stableEndpoints&#34;</span>: <span class="o">[</span><span class="s2">&#34;&lt;IPv4 address&gt;/9993&#34;</span><span class="o">]</span> <span class="c1"># edit here: replace &lt;IPv4 address&gt; with the public address</span>
</span></span><span class="line"><span class="cl">  <span class="o">}</span>
</span></span><span class="line"><span class="cl"> <span class="o">]</span>,
</span></span><span class="line"><span class="cl"> <span class="s2">&#34;signingKey&#34;</span>: <span class="s2">&#34;asdfasdfasdf&#34;</span>,
</span></span><span class="line"><span class="cl"> <span class="s2">&#34;signingKey_SECRET&#34;</span>: <span class="s2">&#34;asdfasdfasdfasd&#34;</span>,
</span></span><span class="line"><span class="cl"> <span class="s2">&#34;updatesMustBeSignedBy&#34;</span>: <span class="s2">&#34;asdfasdfasdf&#34;</span>,
</span></span><span class="line"><span class="cl"> <span class="s2">&#34;worldType&#34;</span>: <span class="s2">&#34;moon&#34;</span>
</span></span><span class="line"><span class="cl"><span class="o">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Note that the <code>id</code> field in this file uniquely identifies this machine; other nodes use this <code>id</code> when they configure the moon.</p>
<h3 id="生成签名文件">Generate the signed file</h3>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl">zerotier-idtool genmoon moon.json
</span></span></code></pre></td></tr></table>
</div>
</div><p>This command produces a <code>.moon</code> file, which is what lets the moon node join the network.</p>
<h3 id="将-moon-节点加入网络">Add the moon node to the network</h3>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl">mkdir moons.d
</span></span><span class="line"><span class="cl">mv *.moon moons.d/
</span></span><span class="line"><span class="cl">sudo systemctl restart zerotier-one
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="其它设备配置">Configuring other devices</h2>
<p>On every device that should use the MOON, install ZeroTier and join the network first, then configure the MOON node manually:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl">sudo zerotier-cli orbit &lt;id&gt; &lt;id&gt;  <span class="c1"># on Windows, run the command with administrator privileges instead</span>
</span></span></code></pre></td></tr></table>
</div>
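</div><p>Here <code>id</code> is the node id of the MOON server. It appears in the <code>json</code> produced in the <a href="#生成-moonjson-文件">Generate the <code>moon.json</code> file</a> step, and it can also be found for that device in the ZeroTier network console.</p>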
]]></content:encoded>
    </item>
    <item>
      <title>CS144 Lab Notes</title>
      <link>https://www.zhouxin.space/notes/cs144-winter-2024-labs/</link>
      <pubDate>Sat, 30 Mar 2024 19:33:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/cs144-winter-2024-labs/</guid>
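      <description>&lt;h1 id=&#34;资源存档&#34;&gt;Resource archive&lt;/h1&gt;
&lt;p&gt;These labs use the CS144 Winter 2024 version of the course code. Because CS144 officially asks that code not be published, to prevent plagiarism, I archived my solutions and the original starter code on Gitee (students abroad probably do not know this platform); take them if you need them: &lt;a href=&#34;https://gitee.com/littleherozzzx/CS144&#34;&gt;CS144: CSS144 Winter 2024 Labs.&lt;/a&gt; I also host a mirror of the course homepage. The resource links are listed below:&lt;/p&gt;</description>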
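      <content:encoded><![CDATA[<h1 id="资源存档">Resource archive</h1>
<p>These labs use the CS144 Winter 2024 version of the course code. Because CS144 officially asks that code not be published, to prevent plagiarism, I archived my solutions and the original starter code on Gitee (students abroad probably do not know this platform); take them if you need them: <a href="https://gitee.com/littleherozzzx/CS144">CS144: CSS144 Winter 2024 Labs.</a> I also host a mirror of the course homepage. The resource links are listed below:</p>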
<table>
  <thead>
      <tr>
          <th>Name</th>
          <th>Link</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Starter code and solutions</td>
          <td><a href="https://gitee.com/littleherozzzx/CS144">CS144: CSS144 Winter 2024 Labs.</a></td>
          <td>Starter code is on the archive branch; solutions are on the main branch</td>
      </tr>
      <tr>
          <td>Course homepage mirror</td>
          <td><a href="https://littleherozzzx.github.io/cs144Winter2024.github.io/">CS 144: Introduction to Computer Networking</a></td>
          <td></td>
      </tr>
      <tr>
          <td>VM image and setup guide</td>
          <td><a href="https://web.stanford.edu/class/cs144/vm_howto/vm-howto-image.html">Setting up your CS144 VM using VirtualBox</a></td>
          <td>Baidu Cloud link: <a href="https://pan.baidu.com/s/1s7xWKn5ccph64--rdJOz6g?pwd=ozb0">https://pan.baidu.com/s/1s7xWKn5ccph64--rdJOz6g?pwd=ozb0</a></td>
      </tr>
  </tbody>
</table>
<h2 id="虚拟机镜像">VM image</h2>
<p>The CS144 website provides a VirtualBox image together with setup instructions: <a href="https://web.stanford.edu/class/cs144/vm_howto/vm-howto-image.html">Setting up your CS144 VM using VirtualBox</a>.</p>
<h1 id="lab-0">Lab 0</h1>
<h2 id="环境配置">Environment setup</h2>
<p>I use <code>Ubuntu 22.04 @ WSL2</code>. The original handout gives a single command for setting up the environment:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-sh" data-lang="sh"><span class="line"><span class="cl">sudo apt update <span class="o">&amp;&amp;</span> sudo apt install git cmake gdb build-essential clang clang-tidy clang-format gcc-doc pkg-config glibc-doc tcpdump tshark
</span></span></code></pre></td></tr></table>
</div>
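</div><p>The handout states that the test environment is <code>Ubuntu 23.10</code> with <code>g++ 13.2</code>, and the command above does not install that version of gcc. To install a newer <code>g++</code>, see this note: <a href="../../Ubuntu/%E5%AE%89%E8%A3%85%E5%B9%B6%E5%88%87%E6%8D%A2%E6%8C%87%E5%AE%9Agcc%E6%88%96%E8%80%85g++%E7%89%88%E6%9C%AC.md">Installing and switching to a specific gcc or g++ version</a>. On Ubuntu 22 the newest g++ that can be installed this way is 13.1. All later labs were done on this setup.</p>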
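<h2 id="现代-c">Modern C++</h2>
<p>The labs require a modern C++ style. The basic idea is that every object exposes as small a public interface as possible, carries plenty of internal safety checks, and releases its resources correctly once it is no longer used, while avoiding paired operations (such as <code>new</code> and <code>delete</code>). Instead, resources are acquired in constructors and released in destructors, following the "Resource Acquisition Is Initialization" (RAII) idea.</p>
<p>Concretely, the coding style has the following requirements:</p>
<ul>
<li>Consult <a href="https://en.cppreference.com/w/">cppreference.com</a> while coding</li>
<li>Do not use <code>malloc</code>, <code>free</code>, <code>new</code>, or <code>delete</code></li>
<li>Do not use raw pointers; use smart pointers</li>
<li>Do not use templates, threads, locks, or virtual functions</li>
<li>Do not use <code>C</code>-style strings (<code>char*</code>) or related functions such as <code>strlen()</code></li>
<li>Do not use <code>C</code>-style casts; convert with the <code>C++</code> <code>static_cast</code></li>
<li>Mark function parameters <code>const</code> wherever possible</li>
<li>Mark variables and member functions <code>const</code> wherever possible</li>
<li>Avoid global variables, and give every variable the smallest possible scope</li>
<li>Before submitting, run <code>cmake --build build --target tidy</code> for suggestions on code style, and <code>cmake --build build --target format</code> to format the code.</li>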
</ul>
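<p>The guidelines above can be illustrated with a small sketch. The <code>Packet</code> class and the helper functions are hypothetical names made up for this example and are not part of the CS144 starter code; the sketch only assumes a C++14 or newer compiler.</p>

```cpp
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical example type, not part of the CS144 starter code.
class Packet
{
public:
  explicit Packet( std::string payload ) : payload_( std::move( payload ) ) {}

  // A method that does not mutate the object is marked const.
  uint64_t size() const { return payload_.size(); }

private:
  std::string payload_; // owns its memory (RAII); no manual cleanup needed
};

// Smart pointers own the packets; no new/delete appears anywhere.
std::vector<std::unique_ptr<Packet>> make_packets()
{
  std::vector<std::unique_ptr<Packet>> packets;
  packets.push_back( std::make_unique<Packet>( "hello" ) );
  packets.push_back( std::make_unique<Packet>( "cs144" ) );
  return packets;
}

// The parameter is a const reference: the function only reads the packets.
uint64_t total_size( const std::vector<std::unique_ptr<Packet>>& packets )
{
  uint64_t total = 0;
  for ( const auto& packet : packets ) {
    total += packet->size();
  }
  return total;
}

// static_cast replaces a C-style cast.
double average_size( const std::vector<std::unique_ptr<Packet>>& packets )
{
  if ( packets.empty() ) {
    return 0.0;
  }
  return static_cast<double>( total_size( packets ) ) / static_cast<double>( packets.size() );
}
```

<p>Calling <code>total_size( make_packets() )</code> sums the payload lengths, and every <code>Packet</code> is destroyed automatically when the vector goes out of scope.</p>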
<h2 id="writing-webget">Writing webget</h2>
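<p>Skipping the earlier parts about fetching a web page and sending email through <code>telnet</code>, the first coding task is to complete <code>Webget</code> so that it can fetch a web page. The task is fairly simple and involves a little network programming.<br>
The overall flow is: build the host address from the function arguments, open a TCP connection to that host, send an HTTP request (containing the resource path from the arguments), print the response, and close the TCP connection.<br>
The implementation is:</p>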
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl"><span class="kt">void</span> <span class="nf">get_URL</span><span class="p">(</span> <span class="k">const</span> <span class="n">string</span><span class="o">&amp;</span> <span class="n">host</span><span class="p">,</span> <span class="k">const</span> <span class="n">string</span><span class="o">&amp;</span> <span class="n">path</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="n">Address</span> <span class="n">addr</span> <span class="o">=</span> <span class="n">Address</span><span class="p">(</span><span class="n">host</span><span class="p">,</span> <span class="s">&#34;http&#34;</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">TCPSocket</span> <span class="n">sock</span> <span class="o">=</span> <span class="n">TCPSocket</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">  <span class="n">sock</span><span class="p">.</span><span class="n">connect</span><span class="p">(</span><span class="n">addr</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="k">const</span> <span class="n">string</span> <span class="n">message</span> <span class="o">=</span> <span class="s">&#34;GET &#34;</span> <span class="o">+</span> <span class="n">path</span> <span class="o">+</span> <span class="s">&#34; HTTP/1.1</span><span class="se">\r\n</span><span class="s">&#34;</span>
</span></span><span class="line"><span class="cl">                         <span class="o">+</span> <span class="s">&#34;Host: &#34;</span> <span class="o">+</span> <span class="n">host</span> <span class="o">+</span> <span class="s">&#34;</span><span class="se">\r\n</span><span class="s">&#34;</span>
</span></span><span class="line"><span class="cl">                         <span class="o">+</span> <span class="s">&#34;Connection: close</span><span class="se">\r\n\r\n</span><span class="s">&#34;</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">sock</span><span class="p">.</span><span class="n">write</span><span class="p">(</span><span class="n">message</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="k">while</span><span class="p">(</span><span class="o">!</span><span class="n">sock</span><span class="p">.</span><span class="n">eof</span><span class="p">()){</span>
</span></span><span class="line"><span class="cl">    <span class="n">string</span> <span class="n">response</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">sock</span><span class="p">.</span><span class="n">read</span><span class="p">(</span><span class="n">response</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">cout</span> <span class="o">&lt;&lt;</span> <span class="n">response</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="n">sock</span><span class="p">.</span><span class="n">close</span><span class="p">();</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="an-in-memory-reliable-byte-stream">An in-memory reliable byte stream</h2>
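<p>The second task is to implement a reliable in-memory byte stream, with the following requirements:</p>
<ul>
<li>Bytes come out of the reading end in the same order they went into the writing end, terminated by EOF</li>
<li>Flow control: the byte stream has a capacity limit</li>
<li>The capacity limits how much data the stream holds at any moment, not the total number of bytes the writer may send. In my implementation, input beyond the remaining capacity is simply truncated</li>
<li>The stream is used from a single thread, so concurrent reads and writes need not be considered</li>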
</ul>
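<p>The capacity rule can be made concrete with a short sketch. <code>accepted_bytes</code> is a hypothetical helper written for this note, not part of the lab interface, and it assumes <code>buffered</code> never exceeds <code>capacity</code>.</p>

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper: how many of the `incoming` bytes fit into the stream
// right now. Assumes buffered <= capacity.
uint64_t accepted_bytes( uint64_t capacity, uint64_t buffered, uint64_t incoming )
{
  const uint64_t available = capacity - buffered; // free space at this moment
  return std::min( available, incoming );         // the rest is truncated
}
```

<p>Because only the bytes currently buffered count against the limit, a stream of capacity 4 can carry any number of bytes in total as long as the reader keeps popping.</p>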
<p>The task requires implementing the following interface:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Writer</span> <span class="o">:</span> <span class="k">public</span> <span class="n">ByteStream</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">  <span class="kt">void</span> <span class="n">push</span><span class="p">(</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">data</span> <span class="p">);</span> <span class="c1">// Push data to stream, but only as much as available capacity allows.
</span></span></span><span class="line"><span class="cl">  <span class="kt">void</span> <span class="nf">close</span><span class="p">();</span>                  <span class="c1">// Signal that the stream has reached its ending. Nothing more will be written.
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="kt">bool</span> <span class="nf">is_closed</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>              <span class="c1">// Has the stream been closed?
</span></span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="nf">available_capacity</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span> <span class="c1">// How many bytes can be pushed to the stream right now?
</span></span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="nf">bytes_pushed</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>       <span class="c1">// Total number of bytes cumulatively pushed to the stream
</span></span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Reader</span> <span class="o">:</span> <span class="k">public</span> <span class="n">ByteStream</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl"><span class="k">public</span><span class="o">:</span>
</span></span><span class="line"><span class="cl">  <span class="n">std</span><span class="o">::</span><span class="n">string_view</span> <span class="n">peek</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span> <span class="c1">// Peek at the next bytes in the buffer
</span></span></span><span class="line"><span class="cl">  <span class="kt">void</span> <span class="nf">pop</span><span class="p">(</span> <span class="kt">uint64_t</span> <span class="n">len</span> <span class="p">);</span>      <span class="c1">// Remove `len` bytes from the buffer
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="kt">bool</span> <span class="nf">is_finished</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>        <span class="c1">// Is the stream finished (closed and fully popped)?
</span></span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="nf">bytes_buffered</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span> <span class="c1">// Number of bytes currently buffered (pushed and not popped)
</span></span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="nf">bytes_popped</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>   <span class="c1">// Total number of bytes cumulatively popped from stream
</span></span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
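</div><p>To track the cumulative bytes written and read, how much of the capacity is in use, and whether the stream has been closed, I added the following member variables to <code>ByteStream</code> (remember to initialize them in the constructor):</p>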
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl">  <span class="n">std</span><span class="o">::</span><span class="n">queue</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="o">&gt;</span> <span class="n">buffer_</span><span class="p">;</span> <span class="c1">// buffer: queue of pushed strings
</span></span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="n">amount_</span><span class="p">;</span> <span class="c1">// bytes currently buffered
</span></span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="n">total_pushed_</span><span class="p">;</span> <span class="c1">// total bytes written
</span></span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="n">total_poped_</span><span class="p">;</span> <span class="c1">// total bytes read
</span></span></span><span class="line"><span class="cl">  <span class="kt">bool</span> <span class="n">close_</span><span class="p">;</span> <span class="c1">// whether the stream has been closed
</span></span></span></code></pre></td></tr></table>
</div>
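</div><p>The implementation is fairly simple: a queue of strings, <code>queue&lt;string&gt;</code>, backs the reads and writes. In <code>Writer::push</code>, data that exceeds the remaining capacity of the buffer is simply truncated. Note that pop uses a "lazy pop" scheme: when bytes are popped, the leading characters of the string at the head of the queue are not erased; instead a variable (<code>first_string_left_size</code> in the code) records how many bytes of the head string have not been popped yet.</p>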
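<p>The lazy pop idea can be sketched in isolation as follows. <code>LazyBuffer</code> and its member names are hypothetical and chosen for this illustration only; this is a minimal sketch of the technique, not the lab solution itself.</p>

```cpp
#include <algorithm>
#include <cstdint>
#include <queue>
#include <string>
#include <string_view>
#include <utility>

// Hypothetical stand-alone sketch; the names do not come from the lab code.
class LazyBuffer
{
public:
  void push( std::string data )
  {
    if ( data.empty() ) {
      return;
    }
    buffered_ += data.size();
    chunks_.push( std::move( data ) );
    if ( chunks_.size() == 1 ) {
      front_left_ = chunks_.front().size(); // new head: nothing popped yet
    }
  }

  // View of the part of the head chunk that has not been popped.
  std::string_view peek() const
  {
    if ( chunks_.empty() ) {
      return {};
    }
    const std::string& front = chunks_.front();
    return std::string_view( front ).substr( front.size() - front_left_ );
  }

  void pop( uint64_t len )
  {
    len = std::min( len, buffered_ );
    buffered_ -= len;
    while ( len > 0 ) {
      if ( len < front_left_ ) {
        front_left_ -= len; // lazy: only the counter moves, the string is untouched
        return;
      }
      len -= front_left_; // the head chunk is fully consumed, drop it whole
      chunks_.pop();
      front_left_ = chunks_.empty() ? 0 : chunks_.front().size();
    }
  }

  uint64_t bytes_buffered() const { return buffered_; }

private:
  std::queue<std::string> chunks_ {};
  uint64_t front_left_ { 0 }; // bytes of the head chunk not yet popped
  uint64_t buffered_ { 0 };   // total bytes not yet popped
};
```

<p>The point of the design is that popping never erases characters from the front of a string, which would shift the remaining bytes each time; a chunk is released only once it has been consumed completely.</p>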
<p>The implementation of <code>byte_stream.cc</code> is as follows:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span><span class="lnt">50
</span><span class="lnt">51
</span><span class="lnt">52
</span><span class="lnt">53
</span><span class="lnt">54
</span><span class="lnt">55
</span><span class="lnt">56
</span><span class="lnt">57
</span><span class="lnt">58
</span><span class="lnt">59
</span><span class="lnt">60
</span><span class="lnt">61
</span><span class="lnt">62
</span><span class="lnt">63
</span><span class="lnt">64
</span><span class="lnt">65
</span><span class="lnt">66
</span><span class="lnt">67
</span><span class="lnt">68
</span><span class="lnt">69
</span><span class="lnt">70
</span><span class="lnt">71
</span><span class="lnt">72
</span><span class="lnt">73
</span><span class="lnt">74
</span><span class="lnt">75
</span><span class="lnt">76
</span><span class="lnt">77
</span><span class="lnt">78
</span><span class="lnt">79
</span><span class="lnt">80
</span><span class="lnt">81
</span><span class="lnt">82
</span><span class="lnt">83
</span><span class="lnt">84
</span><span class="lnt">85
</span><span class="lnt">86
</span><span class="lnt">87
</span><span class="lnt">88
</span><span class="lnt">89
</span><span class="lnt">90
</span><span class="lnt">91
</span><span class="lnt">92
</span><span class="lnt">93
</span><span class="lnt">94
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&#34;byte_stream.hh&#34;</span><span class="cp">
</span></span></span><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&lt;iostream&gt;</span><span class="cp">
</span></span></span><span class="line"><span class="cl"><span class="k">using</span> <span class="k">namespace</span> <span class="n">std</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">ByteStream</span><span class="o">::</span><span class="n">ByteStream</span><span class="p">(</span> <span class="kt">uint64_t</span> <span class="n">capacity</span> <span class="p">)</span> <span class="o">:</span>
</span></span><span class="line"><span class="cl">  <span class="n">capacity_</span><span class="p">(</span> <span class="n">capacity</span> <span class="p">),</span> <span class="n">buffer_</span><span class="p">(),</span> <span class="n">amount_</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">total_pushed_</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">  <span class="n">total_poped_</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">first_string_left_size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">close_</span><span class="p">(</span> <span class="nb">false</span> <span class="p">),</span> <span class="n">error_</span><span class="p">(</span> <span class="nb">false</span> <span class="p">)</span>  <span class="p">{}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">bool</span> <span class="n">Writer</span><span class="o">::</span><span class="n">is_closed</span><span class="p">()</span> <span class="k">const</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="n">close_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">Writer</span><span class="o">::</span><span class="n">push</span><span class="p">(</span> <span class="n">string</span> <span class="n">data</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="n">free_capacity</span> <span class="o">=</span> <span class="n">available_capacity</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="n">to_push_size</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">free_capacity</span><span class="p">,</span> <span class="n">data</span><span class="p">.</span><span class="n">size</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">to_push_size</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span>  <span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">data</span><span class="p">.</span><span class="n">resize</span><span class="p">(</span><span class="n">to_push_size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">buffer_</span><span class="p">.</span><span class="n">emplace</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">move</span><span class="p">(</span><span class="n">data</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">buffer_</span><span class="p">.</span><span class="n">size</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">first_string_left_size</span> <span class="o">=</span> <span class="n">to_push_size</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">total_pushed_</span> <span class="o">+=</span> <span class="n">to_push_size</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">amount_</span> <span class="o">+=</span> <span class="n">to_push_size</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">Writer</span><span class="o">::</span><span class="n">close</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">    <span class="n">close_</span> <span class="o">=</span> <span class="nb">true</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">Writer</span><span class="o">::</span><span class="n">available_capacity</span><span class="p">()</span> <span class="k">const</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="n">capacity_</span> <span class="o">-</span> <span class="n">amount_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">Writer</span><span class="o">::</span><span class="n">bytes_pushed</span><span class="p">()</span> <span class="k">const</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="n">total_pushed_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">bool</span> <span class="n">Reader</span><span class="o">::</span><span class="n">is_finished</span><span class="p">()</span> <span class="k">const</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="n">amount_</span> <span class="o">==</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">close_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">Reader</span><span class="o">::</span><span class="n">bytes_popped</span><span class="p">()</span> <span class="k">const</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="n">total_poped_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">string_view</span> <span class="n">Reader</span><span class="o">::</span><span class="n">peek</span><span class="p">()</span> <span class="k">const</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">amount_</span> <span class="o">==</span> <span class="mi">0</span> <span class="o">||</span> <span class="n">buffer_</span><span class="p">.</span><span class="n">empty</span><span class="p">()){</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">string_view</span><span class="p">{};</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="k">const</span> <span class="n">string</span><span class="o">&amp;</span> <span class="n">front</span> <span class="o">=</span> <span class="n">buffer_</span><span class="p">.</span><span class="n">front</span><span class="p">();</span>
</span></span><span class="line"><span class="cl"><span class="c1">//  return string_view(front.data()+front.size()-first_string_left_size,1);
</span></span></span><span class="line"><span class="cl"><span class="c1">//  return string_view(&amp;front[front.size()-first_string_left_size]);
</span></span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="nf">string_view</span><span class="p">(</span><span class="n">front</span><span class="p">).</span><span class="n">substr</span><span class="p">(</span><span class="n">front</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="o">-</span><span class="n">first_string_left_size</span><span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">Reader</span><span class="o">::</span><span class="n">pop</span><span class="p">(</span> <span class="kt">uint64_t</span> <span class="n">len</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="n">total_poped_</span> <span class="o">+=</span> <span class="n">len</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">amount_</span> <span class="o">-=</span> <span class="n">len</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="k">while</span><span class="p">(</span><span class="n">len</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">len</span> <span class="o">&gt;=</span> <span class="n">first_string_left_size</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">      <span class="n">len</span> <span class="o">-=</span> <span class="n">first_string_left_size</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="n">buffer_</span><span class="p">.</span><span class="n">pop</span><span class="p">();</span>
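</span></span><span class="line"><span class="cl">      <span class="n">first_string_left_size</span> <span class="o">=</span> <span class="n">buffer_</span><span class="p">.</span><span class="n">empty</span><span class="p">()</span> <span class="o">?</span> <span class="mi">0</span> <span class="o">:</span> <span class="n">buffer_</span><span class="p">.</span><span class="n">front</span><span class="p">().</span><span class="n">size</span><span class="p">();</span>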
</span></span><span class="line"><span class="cl">    <span class="p">}</span> <span class="k">else</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="n">first_string_left_size</span> <span class="o">-=</span> <span class="n">len</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="n">len</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">Reader</span><span class="o">::</span><span class="n">bytes_buffered</span><span class="p">()</span> <span class="k">const</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="n">amount_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Throughput peaked at 34 Gbit/s.<br>
<img alt="Lab 0 results" loading="lazy" src="http://pics.zhouxin.space/20240410101229.png"></p>
<h1 id="lab-1">Lab 1</h1>
<h2 id="putting-substrings-in-sequence">Putting substrings in sequence</h2>
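<p>This part asks for a TCP segment reassembler. To me it feels like the receive window of a sliding-window protocol from a networking course (closer to Selective Repeat than to GBN, since a GBN receiver does not buffer out-of-order packets): cache the TCP segments that fall inside the receive window, reassemble them in order, and write them promptly into the reliable in-memory byte stream built in Lab 0. Working through it, I found the task comes down to these requirements:</p>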
<ul>
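<li>Reassemble segments, handling out-of-order, duplicate, stale, and truncated ones</li>
<li>The module's buffer must not be larger than the available capacity of the in-memory byte stream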
<ul>
<li>If a segment is too long, truncate it and store only what fits</li>
</ul>
</li>
</ul>
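<p>Each arriving segment carries three fields: the payload <code>data</code>, its starting index <code>first_index</code>, and the flag <code>is_last_substring</code> that marks the final segment. Segments that arrive out of order have to be stashed together with this information, so I store each one in the following struct:</p>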
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl"><span class="k">struct</span> <span class="nc">reassembler_item</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">data</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="n">first_index</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="n">last_index</span><span class="p">;</span> <span class="c1">// half-open interval: [first_index, last_index)
</span></span></span><span class="line"><span class="cl">  <span class="kt">bool</span> <span class="n">is_last</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="kt">bool</span> <span class="k">operator</span> <span class="o">&lt;</span> <span class="p">(</span><span class="k">const</span> <span class="n">reassembler_item</span><span class="o">&amp;</span> <span class="n">x</span><span class="p">)</span> <span class="k">const</span><span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">first_index</span> <span class="o">&lt;</span> <span class="n">x</span><span class="p">.</span><span class="n">first_index</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="n">reassembler_item</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">data1</span><span class="p">,</span> <span class="kt">uint64_t</span> <span class="n">first_index1</span><span class="p">,</span> <span class="kt">uint64_t</span> <span class="n">last_index1</span><span class="p">,</span> <span class="kt">bool</span> <span class="n">is_last1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="o">:</span> <span class="n">data</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">move</span><span class="p">(</span><span class="n">data1</span><span class="p">)),</span>
</span></span><span class="line"><span class="cl">    <span class="n">first_index</span><span class="p">(</span><span class="n">first_index1</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">    <span class="n">last_index</span><span class="p">(</span><span class="n">last_index1</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">    <span class="n">is_last</span><span class="p">(</span><span class="n">is_last1</span><span class="p">)</span> <span class="p">{}</span>
</span></span><span class="line"><span class="cl"><span class="p">};</span>
</span></span></code></pre></td></tr></table>
</div>
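</div><p>To make comparisons easier, I added a field that records the index range covered by the segment's data. The interval is half-open because some segments are empty strings (used to signal that the data has ended): with a closed interval, the right endpoint of an empty segment at index 0 would be -1, which underflows for an unsigned integer.</p>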
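<p>I use a <code>vector</code> to stash out-of-order segments, and keep it sorted and free of duplicates at all times. Concretely, each insertion uses <code>std::lower_bound</code> to binary-search for the insert position. Once the position is found, the new segment may extend over several segments that were already received (for example, the new data covers 100~200, while segments covering 110~120 and 190~210 arrived earlier and already sit in this module's buffer). So I check the elements after the insert position that might be covered: a segment fully covered by the new one is simply dropped, while a partially covered one is first appended to the new segment and then dropped. Likewise, the new segment may itself be covered by the segment just before the insert position: if it is fully covered, the new segment is dropped; if it is partially covered, its tail is appended to the preceding segment and the new segment is then dropped. Only a segment that is not covered by its predecessor needs to be inserted into the internal staging area as a separate entry.<br>
Note that "covered" here includes the case of no overlap but direct adjacency, so [1,200) and [200,300) can also be merged. This guarantees that whenever a string can be written to the in-memory buffer, it is exactly the first segment in the staging area and nothing else.</p>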
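<p>The "overlapping or adjacent" test on half-open intervals can be sketched in isolation as follows. This is only an illustration: <code>Range</code>, <code>mergeable</code> and <code>merge</code> are hypothetical names that do not appear in the lab code, and payload handling is left out on purpose.</p>

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper type (not part of the lab code):
// a half-open index range [first, last).
struct Range {
  uint64_t first;
  uint64_t last;
};

// True when a and b overlap or touch, e.g. [1,200) and [200,300).
constexpr bool mergeable(const Range& a, const Range& b) {
  return a.first <= b.last && b.first <= a.last;
}

// Union of two mergeable ranges.
constexpr Range merge(const Range& a, const Range& b) {
  return Range{std::min(a.first, b.first), std::max(a.last, b.last)};
}

static_assert(mergeable(Range{1, 200}, Range{200, 300}), "adjacent ranges merge");
static_assert(!mergeable(Range{1, 200}, Range{201, 300}), "a one-byte gap blocks merging");
```

<p>Because the right endpoint is exclusive, adjacency is just <code>a.last == b.first</code>, so a single pair of <code>&lt;=</code> comparisons covers both overlap and adjacency.</p>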
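<p>Only when a new segment is inserted into the staging area do we need to check whether staged data can be written to the in-memory buffer. The <code>insert</code> method is implemented as follows:</p>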
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span><span class="lnt">50
</span><span class="lnt">51
</span><span class="lnt">52
</span><span class="lnt">53
</span><span class="lnt">54
</span><span class="lnt">55
</span><span class="lnt">56
</span><span class="lnt">57
</span><span class="lnt">58
</span><span class="lnt">59
</span><span class="lnt">60
</span><span class="lnt">61
</span><span class="lnt">62
</span><span class="lnt">63
</span><span class="lnt">64
</span><span class="lnt">65
</span><span class="lnt">66
</span><span class="lnt">67
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">Reassembler</span><span class="o">::</span><span class="n">insert</span><span class="p">(</span> <span class="kt">uint64_t</span> <span class="n">first_index</span><span class="p">,</span> <span class="n">string</span> <span class="n">data</span><span class="p">,</span> <span class="kt">bool</span> <span class="n">is_last_substring</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="n">capacity</span> <span class="o">=</span> <span class="n">output_</span><span class="p">.</span><span class="n">writer</span><span class="p">().</span><span class="n">available_capacity</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Acceptable index range: [current_index_, current_index_+capacity), half-open
</span></span></span><span class="line"><span class="cl">  <span class="c1">// Index range of data: [first_index, first_index+data.size())
</span></span></span><span class="line"><span class="cl">  <span class="c1">// Intersect the two; an empty result means the segment is stale or arrived too early
</span></span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="n">left_bound</span> <span class="o">=</span> <span class="n">max</span><span class="p">(</span><span class="n">first_index</span><span class="p">,</span> <span class="n">current_index_</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="n">right_bound</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">current_index_</span><span class="o">+</span><span class="n">capacity</span><span class="p">,</span> <span class="n">first_index</span><span class="o">+</span><span class="n">data</span><span class="p">.</span><span class="n">size</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">right_bound</span> <span class="o">&lt;</span> <span class="n">left_bound</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// equal bounds mean an empty string, which is still accepted (it may carry the last-substring flag)
</span></span></span><span class="line"><span class="cl">    <span class="k">return</span><span class="p">;</span> <span class="c1">// buffer_ is unchanged, so nothing will be written to the byte stream
</span></span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="n">reassembler_item</span> <span class="n">item</span> <span class="o">=</span> <span class="n">reassembler_item</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="n">data</span><span class="p">.</span><span class="n">substr</span><span class="p">(</span> <span class="n">left_bound</span><span class="o">-</span><span class="n">first_index</span><span class="p">,</span> <span class="n">right_bound</span><span class="o">-</span><span class="n">left_bound</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">    <span class="n">left_bound</span><span class="p">,</span> <span class="n">right_bound</span><span class="p">,</span> <span class="n">is_last_substring</span> <span class="o">&amp;&amp;</span> <span class="n">right_bound</span> <span class="o">==</span> <span class="n">first_index</span><span class="o">+</span><span class="n">data</span><span class="p">.</span><span class="n">size</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">  <span class="n">pending_size_</span> <span class="o">+=</span> <span class="n">item</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">size</span><span class="p">();</span> <span class="c1">// count everything first, then subtract whatever turns out to be overlapped
</span></span></span><span class="line"><span class="cl">  <span class="k">auto</span> <span class="n">insert_iter</span> <span class="o">=</span> <span class="n">lower_bound</span><span class="p">(</span><span class="n">buffer_</span><span class="p">.</span><span class="n">begin</span><span class="p">(),</span> <span class="n">buffer_</span><span class="p">.</span><span class="n">end</span><span class="p">(),</span> <span class="n">item</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// First check whether item extends over segments already stored in buffer_; if so, merge them
</span></span></span><span class="line"><span class="cl">  <span class="k">auto</span> <span class="n">iter</span> <span class="o">=</span> <span class="n">insert_iter</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="k">while</span> <span class="p">(</span><span class="n">iter</span> <span class="o">!=</span> <span class="n">buffer_</span><span class="p">.</span><span class="n">end</span><span class="p">()</span> <span class="o">&amp;&amp;</span> <span class="n">item</span><span class="p">.</span><span class="n">last_index</span> <span class="o">&gt;=</span> <span class="n">iter</span><span class="o">-&gt;</span><span class="n">first_index</span> <span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">item</span><span class="p">.</span><span class="n">last_index</span> <span class="o">&lt;</span> <span class="n">iter</span><span class="o">-&gt;</span><span class="n">last_index</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// merge only on partial overlap; a fully covered segment is just erased
</span></span></span><span class="line"><span class="cl">      <span class="n">item</span><span class="p">.</span><span class="n">data</span> <span class="o">+=</span> <span class="n">iter</span><span class="o">-&gt;</span><span class="n">data</span><span class="p">.</span><span class="n">substr</span><span class="p">(</span><span class="n">item</span><span class="p">.</span><span class="n">last_index</span><span class="o">-</span><span class="n">iter</span><span class="o">-&gt;</span><span class="n">first_index</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">      <span class="c1">// overlap length is item.last_index - iter-&gt;first_index
</span></span></span><span class="line"><span class="cl">      <span class="n">pending_size_</span> <span class="o">-=</span> <span class="n">item</span><span class="p">.</span><span class="n">last_index</span> <span class="o">-</span> <span class="n">iter</span><span class="o">-&gt;</span><span class="n">first_index</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="n">item</span><span class="p">.</span><span class="n">last_index</span> <span class="o">=</span> <span class="n">iter</span><span class="o">-&gt;</span><span class="n">last_index</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="n">item</span><span class="p">.</span><span class="n">is_last</span> <span class="o">|=</span> <span class="n">iter</span><span class="o">-&gt;</span><span class="n">is_last</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="n">pending_size_</span> <span class="o">-=</span> <span class="n">iter</span><span class="o">-&gt;</span><span class="n">data</span><span class="p">.</span><span class="n">size</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="n">iter</span> <span class="o">=</span> <span class="n">insert_iter</span> <span class="o">=</span> <span class="n">buffer_</span><span class="p">.</span><span class="n">erase</span><span class="p">(</span><span class="n">iter</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Then check whether the preceding segment covers item
</span></span></span><span class="line"><span class="cl">  <span class="c1">// If it does, update the preceding element in place; item itself is not inserted
</span></span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">insert_iter</span> <span class="o">!=</span> <span class="n">buffer_</span><span class="p">.</span><span class="n">begin</span><span class="p">()){</span>
</span></span><span class="line"><span class="cl">    <span class="n">iter</span> <span class="o">=</span> <span class="n">insert_iter</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">iter</span><span class="o">-&gt;</span><span class="n">last_index</span> <span class="o">&gt;=</span> <span class="n">item</span><span class="p">.</span><span class="n">first_index</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">      <span class="k">if</span><span class="p">(</span><span class="n">iter</span><span class="o">-&gt;</span><span class="n">last_index</span> <span class="o">&lt;</span> <span class="n">item</span><span class="p">.</span><span class="n">last_index</span><span class="p">){</span> <span class="c1">// partially covered
</span></span></span><span class="line"><span class="cl">        <span class="n">iter</span><span class="o">-&gt;</span><span class="n">data</span> <span class="o">+=</span> <span class="n">item</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">substr</span><span class="p">(</span><span class="n">iter</span><span class="o">-&gt;</span><span class="n">last_index</span><span class="o">-</span><span class="n">item</span><span class="p">.</span><span class="n">first_index</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="n">pending_size_</span> <span class="o">-=</span> <span class="n">iter</span><span class="o">-&gt;</span><span class="n">last_index</span> <span class="o">-</span> <span class="n">item</span><span class="p">.</span><span class="n">first_index</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="n">iter</span><span class="o">-&gt;</span><span class="n">last_index</span> <span class="o">=</span> <span class="n">item</span><span class="p">.</span><span class="n">last_index</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="n">iter</span><span class="o">-&gt;</span><span class="n">is_last</span> <span class="o">|=</span> <span class="n">item</span><span class="p">.</span><span class="n">is_last</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span> <span class="k">else</span> <span class="p">{</span> <span class="c1">// fully covered
</span></span></span><span class="line"><span class="cl">        <span class="n">pending_size_</span> <span class="o">-=</span> <span class="n">item</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">size</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="c1">// Nothing was inserted, so there is nothing to erase
</span></span></span><span class="line"><span class="cl">      <span class="c1">// Return right away and skip the insertion code below
</span></span></span><span class="line"><span class="cl">      <span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// insert item into buffer_
</span></span></span><span class="line"><span class="cl">  <span class="n">buffer_</span><span class="p">.</span><span class="n">insert</span><span class="p">(</span><span class="n">insert_iter</span><span class="p">,</span> <span class="n">item</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Only a newly inserted item can make data writable to the byte stream
</span></span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">buffer_</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">first_index</span> <span class="o">==</span> <span class="n">current_index_</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="k">auto</span><span class="o">&amp;</span> <span class="n">to_write_item</span> <span class="o">=</span> <span class="n">buffer_</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">    <span class="n">output_</span><span class="p">.</span><span class="n">writer</span><span class="p">().</span><span class="n">push</span><span class="p">(</span><span class="n">to_write_item</span><span class="p">.</span><span class="n">data</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">pending_size_</span> <span class="o">-=</span> <span class="n">to_write_item</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">size</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="n">current_index_</span> <span class="o">=</span> <span class="n">to_write_item</span><span class="p">.</span><span class="n">last_index</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">to_write_item</span><span class="p">.</span><span class="n">is_last</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">      <span class="n">output_</span><span class="p">.</span><span class="n">writer</span><span class="p">().</span><span class="n">close</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="n">buffer_</span><span class="p">.</span><span class="n">erase</span><span class="p">(</span><span class="n">buffer_</span><span class="p">.</span><span class="n">begin</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The reassembler's throughput peaked at 10 Gbit/s.<br>
<img alt="Lab 1 results" loading="lazy" src="http://pics.zhouxin.space/20240410101326.png"></p>
<h1 id="lab-2">Lab 2</h1>
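<p>So far we have built the in-memory reliable byte stream and the TCP segment reassembler. The reassembler puts incoming TCP segments back in order and writes them to the byte stream as soon as it can. Next we need a TCP receiver module, which takes messages from the peer's sender and replies with the ACK and the receive window size.</p>
<p>Before that, there is a data-format problem to solve. The first two modules use a <code>uint64</code> for sequence numbers, but a TCP segment has only 32 bits for its sequence number, and the sequence number of the initial segment (SYN) may be random. So the first step is a module that converts between 32-bit TCP sequence numbers and 64-bit absolute sequence numbers. The former start at a random value and wrap around modulo 2<sup>32</sup> as they increase; the latter always start at 0, and we assume the total amount of data never exceeds 2<sup>64</sup> bytes, i.e. 2<sup>34</sup> GiB.</p>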
<h2 id="translating-between-64-bit-indexes-and-32-bit-seqnos">Translating between 64-bit indexes and 32-bit seqnos</h2>
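<p><img alt="Correspondence between TCP sequence numbers, absolute sequence numbers, and stream indices" loading="lazy" src="http://pics.zhouxin.space/20240411101104.png"><br>
From the definitions in the figure above, seqno and abs seqno are related as follows:</p>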
<p>$$
seqno = (absSeqno + zeroPoint) \bmod 2^{32}
$$</p>
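<p>Converting from 64 bits to 32 bits follows this formula directly. The explicit modulo by 2<sup>32</sup> is unnecessary, because a 32-bit integer automatically discards the upper 32 bits.</p>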
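<p>A minimal sketch of the 64-to-32-bit direction, assuming plain integers instead of the lab's <code>Wrap32</code> class. The free function <code>wrap_seqno</code> is a hypothetical name used only for illustration. The cast keeps the low 32 bits, which plays the role of the modulo in the formula:</p>

```cpp
#include <cstdint>

// Hypothetical helper, not part of the CS144 starter code:
// maps a 64-bit absolute sequence number to a 32-bit TCP seqno.
constexpr uint32_t wrap_seqno( uint64_t abs_seqno, uint32_t zero_point )
{
  // Truncating to uint32_t keeps the low 32 bits (abs_seqno mod 2^32);
  // the unsigned 32-bit addition then wraps modulo 2^32 as well.
  return static_cast<uint32_t>( abs_seqno ) + zero_point;
}
```

<p>For instance, an absolute sequence number of 1 with a zero point of <code>0xFFFFFFFF</code> wraps around to seqno 0.</p>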
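<p>Converting from 32 bits to 64 bits requires handling the upper and lower 32 bits separately. Start with the lower 32 bits, which give the position $absSeqno \bmod 2^{32}$ within the whole sequence. How do we get that position from $seqno$? As $seqno$ increases it is repeatedly reduced modulo $2^{32}$; call the value without that reduction $seqno'$. The position in the whole sequence is then $seqno' - zeroPoint$, and $seqno' = seqno + n\times 2^{32}$, so:</p>
<p>$$
\begin{aligned}
absSeqno \bmod 2^{32} &amp;= (seqno' - zeroPoint) \bmod 2^{32} \\
&amp;= (seqno + n\times 2^{32} - zeroPoint) \bmod 2^{32} \\
&amp;= (seqno - zeroPoint + 2^{32}) \bmod 2^{32}
\end{aligned}
$$</p>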
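<p>That gives the lower 32 bits of the absolute sequence number. The upper 32 bits then come from checkPoint. To land as close to checkPoint as possible, the upper 32 bits should also be as close as possible to checkPoint's, so they are either checkPoint's upper 32 bits or that value ±1. Compare the three candidates and keep the one closest to checkPoint.</p>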
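<p>The three-candidate comparison described above can be sketched directly. This is an illustrative free function on plain integers (<code>unwrap_seqno</code> is a hypothetical name), not the <code>Wrap32::unwrap</code> member from the lab, and it assumes that candidates which would underflow or overflow 64 bits are simply skipped:</p>

```cpp
#include <cstdint>

// Hypothetical free-function version of unwrap, for illustration only.
uint64_t unwrap_seqno( uint32_t seqno, uint32_t zero_point, uint64_t checkpoint )
{
  constexpr uint64_t kWrap = 1ULL << 32;

  // Lower 32 bits of the absolute seqno; the cast makes the subtraction wrap modulo 2^32.
  const uint64_t low = static_cast<uint32_t>( seqno - zero_point );
  // Upper 32 bits of the checkpoint, left in place (lower 32 bits cleared).
  const uint64_t high = checkpoint & ~( kWrap - 1 );

  // Distance between two unsigned values without underflow.
  auto dist = []( uint64_t a, uint64_t b ) { return a > b ? a - b : b - a; };

  // Candidate with the same upper half as the checkpoint.
  const uint64_t mid = high + low;
  uint64_t best = mid;
  // Candidate one wrap above (skipped if it would overflow 64 bits).
  if ( mid <= UINT64_MAX - kWrap && dist( mid + kWrap, checkpoint ) < dist( best, checkpoint ) ) {
    best = mid + kWrap;
  }
  // Candidate one wrap below (skipped if it would go below zero).
  if ( mid >= kWrap && dist( mid - kWrap, checkpoint ) < dist( best, checkpoint ) ) {
    best = mid - kWrap;
  }
  return best;
}
```

<p>Compared with a branch on whether the offset lies left or right of the checkpoint, this form trades a few extra comparisons for a structure that mirrors the prose one to one.</p>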
<p><code>wrapping_integers.cc</code> is implemented as follows:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&#34;wrapping_integers.hh&#34;</span><span class="cp">
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">using</span> <span class="k">namespace</span> <span class="n">std</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">Wrap32</span> <span class="n">Wrap32</span><span class="o">::</span><span class="n">wrap</span><span class="p">(</span> <span class="kt">uint64_t</span> <span class="n">n</span><span class="p">,</span> <span class="n">Wrap32</span> <span class="n">zero_point</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="n">Wrap32</span> <span class="p">{</span> <span class="n">Wrap32</span><span class="p">(</span><span class="n">n</span><span class="p">)</span> <span class="o">+</span> <span class="n">zero_point</span><span class="p">.</span><span class="n">raw_value_</span>  <span class="p">};</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">Wrap32</span><span class="o">::</span><span class="n">unwrap</span><span class="p">(</span> <span class="n">Wrap32</span> <span class="n">zero_point</span><span class="p">,</span> <span class="kt">uint64_t</span> <span class="n">checkpoint</span> <span class="p">)</span> <span class="k">const</span>
</span></span><span class="line"><span class="cl"><span class="c1">// Convert to an absolute sequence number that starts at 0
</span></span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="c1">// Let checkpoint = high32 + left, where left is its lower 32 bits
</span></span></span><span class="line"><span class="cl">  <span class="c1">// The closest value lies on one side of checkpoint or the other; one candidate is always high32 + offset
</span></span></span><span class="line"><span class="cl">  <span class="c1">// If offset &lt; left, the other candidate is above checkpoint: high32 + offset + 0x1&#39;0000&#39;0000
</span></span></span><span class="line"><span class="cl">  <span class="c1">// So check whether checkpoint is closer to high32 + offset or to high32 + offset + 0x1&#39;0000&#39;0000
</span></span></span><span class="line"><span class="cl">  <span class="c1">// Subtracting high32 + offset from both: is left - offset closer to 0 or to 0x1&#39;0000&#39;0000
</span></span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="n">offset</span> <span class="o">=</span> <span class="p">(</span><span class="n">raw_value_</span><span class="o">+</span><span class="mh">0x1&#39;0000&#39;0000</span><span class="o">-</span><span class="n">zero_point</span><span class="p">.</span><span class="n">raw_value_</span><span class="p">)</span><span class="o">%</span><span class="mh">0x1&#39;0000&#39;0000</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="kt">uint32_t</span> <span class="n">left</span> <span class="o">=</span> <span class="n">checkpoint</span> <span class="o">%</span> <span class="mh">0x1&#39;0000&#39;0000</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="n">high32</span> <span class="o">=</span> <span class="n">checkpoint</span> <span class="o">-</span> <span class="n">left</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="c1">//  if( offset == left) {
</span></span></span><span class="line"><span class="cl"><span class="c1">//    return high32+checkpoint;
</span></span></span><span class="line"><span class="cl"><span class="c1">//  } else if ( offset &lt; left){
</span></span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">offset</span> <span class="o">&lt;</span> <span class="n">left</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">left</span><span class="o">-</span> <span class="n">offset</span> <span class="o">&lt;=</span> <span class="mh">0x8000&#39;0000</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// closer to high32 + offset
</span></span></span><span class="line"><span class="cl">      <span class="k">return</span> <span class="n">high32</span><span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="k">return</span> <span class="n">high32</span><span class="o">+</span> <span class="n">offset</span> <span class="o">+</span><span class="mh">0x1&#39;0000&#39;0000</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// Likewise, if offset &gt;= left, the other candidate lies below checkpoint: high32 + offset - 0x1&#39;0000&#39;0000
</span></span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span> <span class="n">high32</span> <span class="o">==</span> <span class="mi">0</span> <span class="o">||</span> <span class="n">offset</span> <span class="o">-</span><span class="n">left</span> <span class="o">&lt;=</span> <span class="mh">0x8000&#39;0000</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// closer to high32 + offset, or no lower candidate exists
</span></span></span><span class="line"><span class="cl">      <span class="k">return</span> <span class="n">high32</span><span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="k">return</span> <span class="n">high32</span><span class="o">+</span> <span class="n">offset</span> <span class="o">-</span><span class="mh">0x1&#39;0000&#39;0000</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h2 id="implementing-the-tcp-receiver">Implementing the TCP receiver</h2>
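<p>With that in place we can implement the TCP receiver. During the lab, take care to keep the five kinds of sequence numbers apart, since they are easy to mix up. A few key pieces of logic are worth mentioning:</p>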
<ul>
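<li>If RST is received, report the error to the in-memory byte stream (oddly, the <code>set_error</code> method belongs to <code>Reader</code> rather than <code>Writer</code>);</li>
<li>On receiving SYN, update <code>zero_point</code> and <code>ack_</code>;</li>
<li>Data can only be accepted after SYN has been received;</li>
<li>After handing data to the reassembler, the number of bytes written to the stream gives the sequence number of the first byte still awaited, which is used to update <code>ack_</code>;</li>
<li>Once all data has been received, <code>ack_</code> needs an extra +1 (FIN occupies one sequence number); whether reception is complete is determined from <code>writer.is_closed</code>.</li>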
</ul>
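<p>The ackno bookkeeping in the list above can be sketched on its own. The struct and function names here (<code>ReceiverState</code>, <code>next_ackno</code>) are hypothetical stand-ins for the lab's classes, and the sketch assumes the byte count and the closed flag are already known:</p>

```cpp
#include <cstdint>
#include <optional>

// Hypothetical, simplified receiver state (not the lab's actual classes).
struct ReceiverState
{
  std::optional<uint32_t> zero_point; // ISN, set once SYN has arrived
  uint64_t bytes_pushed = 0;          // bytes the reassembler has written to the stream
  bool stream_closed = false;         // true once FIN has been reassembled
};

// Returns the ackno to advertise, or nullopt before SYN has been received.
std::optional<uint32_t> next_ackno( const ReceiverState& s )
{
  if ( !s.zero_point.has_value() ) {
    return std::nullopt;
  }
  // +1 because SYN occupies a sequence number, +1 more for FIN once the stream is closed.
  const uint64_t abs_ackno = s.bytes_pushed + 1 + ( s.stream_closed ? 1 : 0 );
  // Wrap the absolute ackno back into 32-bit sequence space.
  return static_cast<uint32_t>( abs_ackno ) + *s.zero_point;
}
```

<p>Keeping the ackno as a pure function of the stream state avoids having to update it by hand in several places.</p>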
<p><code>TCP_receiver</code> is implemented as follows:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&#34;tcp_receiver.hh&#34;</span><span class="cp">
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">using</span> <span class="k">namespace</span> <span class="n">std</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">TCPReceiver</span><span class="o">::</span><span class="n">receive</span><span class="p">(</span> <span class="n">TCPSenderMessage</span> <span class="n">message</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">message</span><span class="p">.</span><span class="n">RST</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">reader</span><span class="p">().</span><span class="n">set_error</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">message</span><span class="p">.</span><span class="n">SYN</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="n">zero_point_</span> <span class="o">=</span> <span class="n">Wrap32</span><span class="p">(</span><span class="n">message</span><span class="p">.</span><span class="n">seqno</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">ack_</span><span class="p">.</span><span class="n">emplace</span><span class="p">(</span><span class="n">message</span><span class="p">.</span><span class="n">seqno</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">ack_</span><span class="p">.</span><span class="n">has_value</span><span class="p">())</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="k">const</span> <span class="kt">uint64_t</span> <span class="n">check_point</span> <span class="o">=</span> <span class="n">writer</span><span class="p">().</span><span class="n">bytes_pushed</span><span class="p">()</span><span class="o">+</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">uint64_t</span> <span class="n">first_index</span>
</span></span><span class="line"><span class="cl">      <span class="o">=</span> <span class="n">Wrap32</span><span class="p">(</span> <span class="n">message</span><span class="p">.</span><span class="n">SYN</span> <span class="o">?</span> <span class="n">message</span><span class="p">.</span><span class="n">seqno</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">:</span> <span class="n">message</span><span class="p">.</span><span class="n">seqno</span> <span class="p">).</span><span class="n">unwrap</span><span class="p">(</span> <span class="n">zero_point_</span><span class="p">,</span> <span class="n">check_point</span> <span class="p">)</span><span class="o">-</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">reassembler_</span><span class="p">.</span><span class="n">insert</span><span class="p">(</span> <span class="n">first_index</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">move</span><span class="p">(</span><span class="n">message</span><span class="p">.</span><span class="n">payload</span><span class="p">),</span> <span class="n">message</span><span class="p">.</span><span class="n">FIN</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">ack_</span> <span class="o">=</span> <span class="n">ack_</span><span class="o">-&gt;</span><span class="n">wrap</span><span class="p">(</span><span class="n">writer</span><span class="p">().</span><span class="n">bytes_pushed</span><span class="p">()</span><span class="o">+</span><span class="mi">1</span><span class="o">+</span><span class="n">writer</span><span class="p">().</span><span class="n">is_closed</span><span class="p">()</span> <span class="p">,</span> <span class="n">zero_point_</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">TCPReceiverMessage</span> <span class="n">TCPReceiver</span><span class="o">::</span><span class="n">send</span><span class="p">()</span> <span class="k">const</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="p">{</span><span class="n">ack_</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">           <span class="k">static_cast</span><span class="o">&lt;</span><span class="kt">uint16_t</span><span class="o">&gt;</span><span class="p">(</span><span class="n">min</span><span class="p">(</span><span class="n">reassembler_</span><span class="p">.</span><span class="n">writer</span><span class="p">().</span><span class="n">available_capacity</span><span class="p">(),</span> <span class="k">static_cast</span><span class="o">&lt;</span><span class="kt">uint64_t</span><span class="o">&gt;</span><span class="p">(</span><span class="n">UINT16_MAX</span><span class="p">))),</span>
</span></span><span class="line"><span class="cl">           <span class="n">reader</span><span class="p">().</span><span class="n">has_error</span><span class="p">()};</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The test run produces:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-text" data-lang="text"><span class="line"><span class="cl">-- Building in &#39;Debug&#39; mode.
</span></span><span class="line"><span class="cl">-- Configuring done (0.3s)
</span></span><span class="line"><span class="cl">-- Generating done (0.1s)
</span></span><span class="line"><span class="cl">-- Build files have been written to: /home/zhouxin/projects/CS144/build
</span></span><span class="line"><span class="cl">Test project /home/zhouxin/projects/CS144/build
</span></span><span class="line"><span class="cl">      Start  1: compile with bug-checkers
</span></span><span class="line"><span class="cl"> 1/29 Test  #1: compile with bug-checkers ........   Passed   19.75 sec
</span></span><span class="line"><span class="cl">      Start  3: byte_stream_basics
</span></span><span class="line"><span class="cl"> 2/29 Test  #3: byte_stream_basics ...............   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start  4: byte_stream_capacity
</span></span><span class="line"><span class="cl"> 3/29 Test  #4: byte_stream_capacity .............   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start  5: byte_stream_one_write
</span></span><span class="line"><span class="cl"> 4/29 Test  #5: byte_stream_one_write ............   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start  6: byte_stream_two_writes
</span></span><span class="line"><span class="cl"> 5/29 Test  #6: byte_stream_two_writes ...........   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start  7: byte_stream_many_writes
</span></span><span class="line"><span class="cl"> 6/29 Test  #7: byte_stream_many_writes ..........   Passed    0.06 sec
</span></span><span class="line"><span class="cl">      Start  8: byte_stream_stress_test
</span></span><span class="line"><span class="cl"> 7/29 Test  #8: byte_stream_stress_test ..........   Passed    0.05 sec
</span></span><span class="line"><span class="cl">      Start  9: reassembler_single
</span></span><span class="line"><span class="cl"> 8/29 Test  #9: reassembler_single ...............   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 10: reassembler_cap
</span></span><span class="line"><span class="cl"> 9/29 Test #10: reassembler_cap ..................   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 11: reassembler_seq
</span></span><span class="line"><span class="cl">10/29 Test #11: reassembler_seq ..................   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 12: reassembler_dup
</span></span><span class="line"><span class="cl">11/29 Test #12: reassembler_dup ..................   Passed    0.05 sec
</span></span><span class="line"><span class="cl">      Start 13: reassembler_holes
</span></span><span class="line"><span class="cl">12/29 Test #13: reassembler_holes ................   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 14: reassembler_overlapping
</span></span><span class="line"><span class="cl">13/29 Test #14: reassembler_overlapping ..........   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 15: reassembler_win
</span></span><span class="line"><span class="cl">14/29 Test #15: reassembler_win ..................   Passed    0.15 sec
</span></span><span class="line"><span class="cl">      Start 16: wrapping_integers_cmp
</span></span><span class="line"><span class="cl">15/29 Test #16: wrapping_integers_cmp ............   Passed    0.04 sec
</span></span><span class="line"><span class="cl">      Start 17: wrapping_integers_wrap
</span></span><span class="line"><span class="cl">16/29 Test #17: wrapping_integers_wrap ...........   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 18: wrapping_integers_unwrap
</span></span><span class="line"><span class="cl">17/29 Test #18: wrapping_integers_unwrap .........   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 19: wrapping_integers_roundtrip
</span></span><span class="line"><span class="cl">18/29 Test #19: wrapping_integers_roundtrip ......   Passed    0.56 sec
</span></span><span class="line"><span class="cl">      Start 20: wrapping_integers_extra
</span></span><span class="line"><span class="cl">19/29 Test #20: wrapping_integers_extra ..........   Passed    0.12 sec
</span></span><span class="line"><span class="cl">      Start 21: recv_connect
</span></span><span class="line"><span class="cl">20/29 Test #21: recv_connect .....................   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 22: recv_transmit
</span></span><span class="line"><span class="cl">21/29 Test #22: recv_transmit ....................   Passed    0.12 sec
</span></span><span class="line"><span class="cl">      Start 23: recv_window
</span></span><span class="line"><span class="cl">22/29 Test #23: recv_window ......................   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 24: recv_reorder
</span></span><span class="line"><span class="cl">23/29 Test #24: recv_reorder .....................   Passed    0.04 sec
</span></span><span class="line"><span class="cl">      Start 25: recv_reorder_more
</span></span><span class="line"><span class="cl">24/29 Test #25: recv_reorder_more ................   Passed    0.36 sec
</span></span><span class="line"><span class="cl">      Start 26: recv_close
</span></span><span class="line"><span class="cl">25/29 Test #26: recv_close .......................   Passed    0.04 sec
</span></span><span class="line"><span class="cl">      Start 27: recv_special
</span></span><span class="line"><span class="cl">26/29 Test #27: recv_special .....................   Passed    0.04 sec
</span></span><span class="line"><span class="cl">      Start 37: compile with optimization
</span></span><span class="line"><span class="cl">27/29 Test #37: compile with optimization ........   Passed    1.93 sec
</span></span><span class="line"><span class="cl">      Start 38: byte_stream_speed_test
</span></span><span class="line"><span class="cl">             ByteStream throughput: 18.15 Gbit/s
</span></span><span class="line"><span class="cl">28/29 Test #38: byte_stream_speed_test ...........   Passed    0.06 sec
</span></span><span class="line"><span class="cl">      Start 39: reassembler_speed_test
</span></span><span class="line"><span class="cl">             Reassembler throughput: 9.03 Gbit/s
</span></span><span class="line"><span class="cl">29/29 Test #39: reassembler_speed_test ...........   Passed    0.11 sec
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">100% tests passed, 0 tests failed out of 29
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">Total Test time (real) =  23.60 sec
</span></span><span class="line"><span class="cl">Built target check2
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="lab-3">Lab 3</h1>
<p>Lab 3 asks for a sender, which is where TCP's timeout retransmission and congestion control logic is implemented. The following methods are required:</p>
<ul>
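<li><code>uint64_t TCPSender::sequence_numbers_in_flight() const</code>: returns the number of sequence numbers (payload bytes, plus SYN and FIN) still awaiting acknowledgment</li>
<li><code>uint64_t TCPSender::consecutive_retransmissions() const</code>: returns the number of consecutive retransmissions</li>
<li><code>void TCPSender::push( const TransmitFunction&amp; transmit )</code>: reads pending data from the in-memory byte stream and fills the receive window as far as possible</li>
<li><code>TCPSenderMessage TCPSender::make_empty_message() const</code>: produces an empty message that occupies no sequence number</li>
<li><code>void TCPSender::receive( const TCPReceiverMessage&amp; msg )</code>: handles acknowledgments from the receiver and tracks the receive window size</li>
<li><code>void TCPSender::tick( uint64_t ms_since_last_tick, const TransmitFunction&amp; transmit )</code>: uses the externally supplied elapsed time to decide whether to retransmit and apply congestion control</li>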
</ul>
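<p>To make the role of <code>tick</code> more concrete, here is a rough sketch of a retransmission timer. It assumes the rule, as I understand it from the lab handout, that the RTO doubles on each timeout while the window is nonzero and resets when new data is acknowledged. The class <code>RetransmitTimer</code> and all of its members are hypothetical names, not the author's implementation:</p>

```cpp
#include <cstdint>

// Hypothetical, simplified retransmission timer (all names are illustrative).
class RetransmitTimer
{
public:
  explicit RetransmitTimer( uint64_t initial_rto_ms )
    : initial_rto_ms_( initial_rto_ms ), rto_ms_( initial_rto_ms )
  {}

  // Start (or restart) counting from zero with the current RTO.
  void start()
  {
    running_ = true;
    elapsed_ms_ = 0;
  }
  void stop() { running_ = false; }

  // Advance time; returns true when the oldest outstanding segment should be resent.
  bool tick( uint64_t ms_since_last_tick, bool window_nonzero )
  {
    if ( !running_ ) {
      return false;
    }
    elapsed_ms_ += ms_since_last_tick;
    if ( elapsed_ms_ < rto_ms_ ) {
      return false;
    }
    if ( window_nonzero ) {
      // Assumed rule: back off exponentially only when the window is nonzero.
      ++consecutive_retransmissions_;
      rto_ms_ *= 2;
    }
    elapsed_ms_ = 0; // restart the timer for the retransmitted segment
    return true;
  }

  // Called when an ACK acknowledges new data: reset the backoff state.
  void on_new_ack()
  {
    rto_ms_ = initial_rto_ms_;
    consecutive_retransmissions_ = 0;
    elapsed_ms_ = 0;
  }

  uint64_t consecutive_retransmissions() const { return consecutive_retransmissions_; }
  uint64_t rto_ms() const { return rto_ms_; }

private:
  uint64_t initial_rto_ms_;
  uint64_t rto_ms_;
  uint64_t elapsed_ms_ = 0;
  uint64_t consecutive_retransmissions_ = 0;
  bool running_ = false;
};
```

<p>Because the timer only sees elapsed milliseconds passed in by the caller, it can be tested deterministically without touching a real clock.</p>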
<p>A few points deserve attention when implementing <code>push</code>:</p>
<ul>
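<li>The field <code>current_seq_</code> records the next sequence number to send; when the connection is first established (<code>current_seq_ = 0</code>), the <code>SYN</code> flag must be set to <code>true</code>;</li>
<li><code>push</code> is only used to send a message for the first time. Every message that has been sent is kept in a queue, awaiting retransmission or acknowledgment. Once the <code>FIN</code> segment has been sent, <code>push</code> must not send anything more; retransmission is the job of <code>tick</code>;</li>
<li>The lab handout says that if the receive window is 0, it should be treated as 1 when sending;</li>
<li><code>push</code> needs a loop to handle the case where the receive window is large and the pending data exceeds the payload limit of a single TCP segment, so that several segments have to be sent.</li>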
</ul>
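<p>The loop and the zero-window rule from the list above might look roughly like this. The sketch is deliberately simplified: <code>Segment</code>, <code>fill_window</code>, and the <code>max_payload</code> parameter are hypothetical, FIN handling is left out, and pending data is modelled as a plain <code>std::string</code> rather than the lab's byte stream:</p>

```cpp
#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified segment type (the lab's TCPSenderMessage has more fields).
struct Segment
{
  bool syn = false;
  std::string payload;
  // SYN occupies one sequence number in addition to the payload bytes.
  uint64_t sequence_length() const { return ( syn ? 1 : 0 ) + payload.size(); }
};

// Cuts pending bytes into segments until the window is full or the data runs out.
// max_payload stands in for the per-segment payload limit.
std::vector<Segment> fill_window( std::string& pending, bool syn_sent, uint64_t window_size,
                                  uint64_t in_flight, uint64_t max_payload )
{
  std::vector<Segment> out;
  // A zero window is treated as size 1 so the sender keeps probing.
  const uint64_t effective_window = std::max<uint64_t>( window_size, 1 );

  while ( in_flight < effective_window ) {
    Segment seg;
    uint64_t room = effective_window - in_flight;
    if ( !syn_sent ) {
      seg.syn = true;
      syn_sent = true;
      --room;
    }
    // Payload is limited by window room, segment size, and the data available.
    const uint64_t len = std::min( { room, max_payload, static_cast<uint64_t>( pending.size() ) } );
    seg.payload = pending.substr( 0, len );
    pending.erase( 0, len );
    if ( seg.sequence_length() == 0 ) {
      break; // nothing left to send
    }
    in_flight += seg.sequence_length();
    out.push_back( std::move( seg ) );
  }
  return out;
}
```

<p>With a window of 8, a payload limit of 3, and ten pending bytes on a fresh connection, this yields three segments (SYN plus 3 bytes, 3 bytes, 1 byte) and leaves three bytes waiting for the window to open.</p>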
<p>The rest follows the logic of the handout, debugging against the test cases. I used the following member variables in <code>tcp_sender.hh</code>:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl"><span class="n">ByteStream</span> <span class="n">input_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="n">Wrap32</span> <span class="n">isn_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">initial_RTO_ms_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">current_time_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">ack_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">in_flight_cnt_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">expire_time_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">retrans_cnt_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">window_size_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">rto_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">current_seq_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="n">Wrap32</span> <span class="n">zero_point_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="n">std</span><span class="o">::</span><span class="n">deque</span><span class="o">&lt;</span><span class="n">TCPSenderMessage</span><span class="o">&gt;</span> <span class="n">outstanding_msg_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="kt">bool</span> <span class="n">is_fin_sent</span><span class="p">;</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The functions in <code>tcp_sender.cc</code> are implemented as follows:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">  1
</span><span class="lnt">  2
</span><span class="lnt">  3
</span><span class="lnt">  4
</span><span class="lnt">  5
</span><span class="lnt">  6
</span><span class="lnt">  7
</span><span class="lnt">  8
</span><span class="lnt">  9
</span><span class="lnt"> 10
</span><span class="lnt"> 11
</span><span class="lnt"> 12
</span><span class="lnt"> 13
</span><span class="lnt"> 14
</span><span class="lnt"> 15
</span><span class="lnt"> 16
</span><span class="lnt"> 17
</span><span class="lnt"> 18
</span><span class="lnt"> 19
</span><span class="lnt"> 20
</span><span class="lnt"> 21
</span><span class="lnt"> 22
</span><span class="lnt"> 23
</span><span class="lnt"> 24
</span><span class="lnt"> 25
</span><span class="lnt"> 26
</span><span class="lnt"> 27
</span><span class="lnt"> 28
</span><span class="lnt"> 29
</span><span class="lnt"> 30
</span><span class="lnt"> 31
</span><span class="lnt"> 32
</span><span class="lnt"> 33
</span><span class="lnt"> 34
</span><span class="lnt"> 35
</span><span class="lnt"> 36
</span><span class="lnt"> 37
</span><span class="lnt"> 38
</span><span class="lnt"> 39
</span><span class="lnt"> 40
</span><span class="lnt"> 41
</span><span class="lnt"> 42
</span><span class="lnt"> 43
</span><span class="lnt"> 44
</span><span class="lnt"> 45
</span><span class="lnt"> 46
</span><span class="lnt"> 47
</span><span class="lnt"> 48
</span><span class="lnt"> 49
</span><span class="lnt"> 50
</span><span class="lnt"> 51
</span><span class="lnt"> 52
</span><span class="lnt"> 53
</span><span class="lnt"> 54
</span><span class="lnt"> 55
</span><span class="lnt"> 56
</span><span class="lnt"> 57
</span><span class="lnt"> 58
</span><span class="lnt"> 59
</span><span class="lnt"> 60
</span><span class="lnt"> 61
</span><span class="lnt"> 62
</span><span class="lnt"> 63
</span><span class="lnt"> 64
</span><span class="lnt"> 65
</span><span class="lnt"> 66
</span><span class="lnt"> 67
</span><span class="lnt"> 68
</span><span class="lnt"> 69
</span><span class="lnt"> 70
</span><span class="lnt"> 71
</span><span class="lnt"> 72
</span><span class="lnt"> 73
</span><span class="lnt"> 74
</span><span class="lnt"> 75
</span><span class="lnt"> 76
</span><span class="lnt"> 77
</span><span class="lnt"> 78
</span><span class="lnt"> 79
</span><span class="lnt"> 80
</span><span class="lnt"> 81
</span><span class="lnt"> 82
</span><span class="lnt"> 83
</span><span class="lnt"> 84
</span><span class="lnt"> 85
</span><span class="lnt"> 86
</span><span class="lnt"> 87
</span><span class="lnt"> 88
</span><span class="lnt"> 89
</span><span class="lnt"> 90
</span><span class="lnt"> 91
</span><span class="lnt"> 92
</span><span class="lnt"> 93
</span><span class="lnt"> 94
</span><span class="lnt"> 95
</span><span class="lnt"> 96
</span><span class="lnt"> 97
</span><span class="lnt"> 98
</span><span class="lnt"> 99
</span><span class="lnt">100
</span><span class="lnt">101
</span><span class="lnt">102
</span><span class="lnt">103
</span><span class="lnt">104
</span><span class="lnt">105
</span><span class="lnt">106
</span><span class="lnt">107
</span><span class="lnt">108
</span><span class="lnt">109
</span><span class="lnt">110
</span><span class="lnt">111
</span><span class="lnt">112
</span><span class="lnt">113
</span><span class="lnt">114
</span><span class="lnt">115
</span><span class="lnt">116
</span><span class="lnt">117
</span><span class="lnt">118
</span><span class="lnt">119
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&#34;tcp_sender.hh&#34;</span><span class="cp">
</span></span></span><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&#34;tcp_config.hh&#34;</span><span class="cp">
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">using</span> <span class="k">namespace</span> <span class="n">std</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">TCPSender</span><span class="o">::</span><span class="n">sequence_numbers_in_flight</span><span class="p">()</span> <span class="k">const</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="n">in_flight_cnt_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">TCPSender</span><span class="o">::</span><span class="n">consecutive_retransmissions</span><span class="p">()</span> <span class="k">const</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="n">retrans_cnt_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">TCPSender</span><span class="o">::</span><span class="n">push</span><span class="p">(</span> <span class="k">const</span> <span class="n">TransmitFunction</span><span class="o">&amp;</span> <span class="n">transmit</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="kt">bool</span> <span class="n">window_zero</span> <span class="o">=</span> <span class="n">window_size_</span> <span class="o">==</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="kt">uint64_t</span> <span class="n">available_window</span>
</span></span><span class="line"><span class="cl">    <span class="o">=</span> <span class="p">(</span> <span class="n">window_size_</span> <span class="o">+</span> <span class="n">window_zero</span> <span class="p">)</span> <span class="o">&lt;</span> <span class="n">in_flight_cnt_</span> <span class="o">?</span> <span class="mi">0</span> <span class="o">:</span> <span class="n">window_size_</span> <span class="o">+</span> <span class="n">window_zero</span> <span class="o">-</span> <span class="n">in_flight_cnt_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="k">do</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="c1">// Handle SYN and RST first; FIN can only be decided once the buffer has been read empty
</span></span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span> <span class="n">is_fin_sent</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl">      <span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="kt">uint64_t</span> <span class="n">pay_load_size</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span> <span class="n">reader</span><span class="p">().</span><span class="n">bytes_buffered</span><span class="p">(),</span> <span class="n">TCPConfig</span><span class="o">::</span><span class="n">MAX_PAYLOAD_SIZE</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="kt">uint64_t</span> <span class="n">seq_size</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span> <span class="n">available_window</span><span class="p">,</span> <span class="n">pay_load_size</span> <span class="o">+</span> <span class="p">(</span> <span class="n">current_seq_</span> <span class="o">==</span> <span class="mi">0</span> <span class="p">)</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">pay_load_size</span> <span class="o">=</span> <span class="n">seq_size</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">TCPSenderMessage</span> <span class="n">msg</span> <span class="o">=</span> <span class="n">TCPSenderMessage</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span> <span class="n">current_seq_</span> <span class="o">==</span> <span class="mi">0</span> <span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="n">msg</span><span class="p">.</span><span class="n">SYN</span> <span class="o">=</span> <span class="nb">true</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="n">pay_load_size</span><span class="o">--</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span> <span class="n">reader</span><span class="p">().</span><span class="n">has_error</span><span class="p">()</span> <span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="n">msg</span><span class="p">.</span><span class="n">RST</span> <span class="o">=</span> <span class="nb">true</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">while</span> <span class="p">(</span> <span class="n">msg</span><span class="p">.</span><span class="n">payload</span><span class="p">.</span><span class="n">size</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">pay_load_size</span> <span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="n">string_view</span> <span class="n">front_view</span> <span class="o">=</span> <span class="n">reader</span><span class="p">().</span><span class="n">peek</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">      <span class="kt">uint64_t</span> <span class="n">bytes_to_read</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span> <span class="n">front_view</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="n">pay_load_size</span> <span class="o">-</span> <span class="n">msg</span><span class="p">.</span><span class="n">payload</span><span class="p">.</span><span class="n">size</span><span class="p">()</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl">      <span class="n">msg</span><span class="p">.</span><span class="n">payload</span> <span class="o">+=</span> <span class="n">front_view</span><span class="p">.</span><span class="n">substr</span><span class="p">(</span> <span class="mi">0</span><span class="p">,</span> <span class="n">bytes_to_read</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl">      <span class="n">input_</span><span class="p">.</span><span class="n">reader</span><span class="p">().</span><span class="n">pop</span><span class="p">(</span> <span class="n">bytes_to_read</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span> <span class="n">reader</span><span class="p">().</span><span class="n">is_finished</span><span class="p">()</span> <span class="o">&amp;&amp;</span> <span class="n">seq_size</span> <span class="o">&lt;</span> <span class="n">available_window</span> <span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="n">msg</span><span class="p">.</span><span class="n">FIN</span> <span class="o">=</span> <span class="nb">true</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="n">seq_size</span><span class="o">++</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="n">is_fin_sent</span> <span class="o">=</span> <span class="nb">true</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span> <span class="n">msg</span><span class="p">.</span><span class="n">sequence_length</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl">      <span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">msg</span><span class="p">.</span><span class="n">seqno</span> <span class="o">=</span> <span class="n">Wrap32</span><span class="o">::</span><span class="n">wrap</span><span class="p">(</span> <span class="n">current_seq_</span><span class="p">,</span> <span class="n">zero_point_</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">current_seq_</span> <span class="o">+=</span> <span class="n">msg</span><span class="p">.</span><span class="n">sequence_length</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="n">in_flight_cnt_</span> <span class="o">+=</span> <span class="n">msg</span><span class="p">.</span><span class="n">sequence_length</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="n">outstanding_msg_</span><span class="p">.</span><span class="n">push_back</span><span class="p">(</span> <span class="n">msg</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="n">transmit</span><span class="p">(</span> <span class="n">msg</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span> <span class="n">expire_time_</span> <span class="o">==</span> <span class="n">UINT64_MAX</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl">      <span class="n">expire_time_</span> <span class="o">=</span> <span class="n">current_time_</span> <span class="o">+</span> <span class="n">rto_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">available_window</span>
</span></span><span class="line"><span class="cl">      <span class="o">=</span> <span class="p">(</span> <span class="n">window_size_</span> <span class="o">+</span> <span class="n">window_zero</span> <span class="p">)</span> <span class="o">&lt;</span> <span class="n">in_flight_cnt_</span> <span class="o">?</span> <span class="mi">0</span> <span class="o">:</span> <span class="n">window_size_</span> <span class="o">+</span> <span class="n">window_zero</span> <span class="o">-</span> <span class="n">in_flight_cnt_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span> <span class="k">while</span> <span class="p">(</span> <span class="n">reader</span><span class="p">().</span><span class="n">bytes_buffered</span><span class="p">()</span> <span class="o">!=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">available_window</span> <span class="o">!=</span> <span class="mi">0</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">TCPSenderMessage</span> <span class="n">TCPSender</span><span class="o">::</span><span class="n">make_empty_message</span><span class="p">()</span> <span class="k">const</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="p">{</span> <span class="n">Wrap32</span><span class="o">::</span><span class="n">wrap</span><span class="p">(</span> <span class="n">current_seq_</span><span class="p">,</span> <span class="n">zero_point_</span> <span class="p">),</span> <span class="nb">false</span><span class="p">,</span> <span class="n">string</span><span class="p">(),</span> <span class="nb">false</span><span class="p">,</span> <span class="n">reader</span><span class="p">().</span><span class="n">has_error</span><span class="p">()</span> <span class="p">};</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">TCPSender</span><span class="o">::</span><span class="n">receive</span><span class="p">(</span> <span class="k">const</span> <span class="n">TCPReceiverMessage</span><span class="o">&amp;</span> <span class="n">msg</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">if</span> <span class="p">(</span> <span class="n">msg</span><span class="p">.</span><span class="n">ackno</span><span class="p">.</span><span class="n">has_value</span><span class="p">()</span> <span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="kt">uint64_t</span> <span class="n">ack_from_recv</span> <span class="o">=</span> <span class="n">unwarp</span><span class="p">(</span> <span class="n">msg</span><span class="p">.</span><span class="n">ackno</span><span class="p">.</span><span class="n">value</span><span class="p">()</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span> <span class="n">ack_from_recv</span> <span class="o">&gt;</span> <span class="n">ack_</span> <span class="o">&amp;&amp;</span> <span class="n">ack_from_recv</span> <span class="o">&lt;=</span> <span class="n">current_seq_</span> <span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="n">ack_</span> <span class="o">=</span> <span class="n">ack_from_recv</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="n">rto_</span> <span class="o">=</span> <span class="n">initial_RTO_ms_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="n">expire_time_</span> <span class="o">=</span> <span class="n">current_time_</span> <span class="o">+</span> <span class="n">rto_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="n">retrans_cnt_</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="k">while</span> <span class="p">(</span> <span class="o">!</span><span class="n">outstanding_msg_</span><span class="p">.</span><span class="n">empty</span><span class="p">()</span> <span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="k">auto</span><span class="o">&amp;</span> <span class="n">front_msg</span> <span class="o">=</span> <span class="n">outstanding_msg_</span><span class="p">.</span><span class="n">front</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span> <span class="p">(</span> <span class="n">unwarp</span><span class="p">(</span> <span class="n">front_msg</span><span class="p">.</span><span class="n">seqno</span> <span class="p">)</span> <span class="o">+</span> <span class="n">front_msg</span><span class="p">.</span><span class="n">sequence_length</span><span class="p">()</span> <span class="o">&gt;</span> <span class="n">ack_</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl">          <span class="k">break</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">        <span class="n">in_flight_cnt_</span> <span class="o">-=</span> <span class="n">front_msg</span><span class="p">.</span><span class="n">sequence_length</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="n">outstanding_msg_</span><span class="p">.</span><span class="n">pop_front</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="k">if</span> <span class="p">(</span> <span class="n">outstanding_msg_</span><span class="p">.</span><span class="n">empty</span><span class="p">()</span> <span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">expire_time_</span> <span class="o">=</span> <span class="n">UINT64_MAX</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="n">window_size_</span> <span class="o">=</span> <span class="n">msg</span><span class="p">.</span><span class="n">window_size</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span> <span class="p">(</span> <span class="n">msg</span><span class="p">.</span><span class="n">RST</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">writer</span><span class="p">().</span><span class="n">set_error</span><span class="p">();</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">TCPSender</span><span class="o">::</span><span class="n">tick</span><span class="p">(</span> <span class="kt">uint64_t</span> <span class="n">ms_since_last_tick</span><span class="p">,</span> <span class="k">const</span> <span class="n">TransmitFunction</span><span class="o">&amp;</span> <span class="n">transmit</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="n">current_time_</span> <span class="o">+=</span> <span class="n">ms_since_last_tick</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span> <span class="p">(</span> <span class="n">expire_time_</span> <span class="o">!=</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="n">current_time_</span> <span class="o">&gt;=</span> <span class="n">expire_time_</span> <span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">transmit</span><span class="p">(</span> <span class="n">outstanding_msg_</span><span class="p">.</span><span class="n">front</span><span class="p">()</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="c1">//  auto msg = outstanding_msg_.front();
</span></span></span><span class="line"><span class="cl">    <span class="c1">//  outstanding_msg_.pop_front();
</span></span></span><span class="line"><span class="cl">    <span class="c1">//  outstanding_msg_.push_back(msg);
</span></span></span><span class="line"><span class="cl">    <span class="c1">//  transmit(msg);
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">if</span> <span class="p">(</span> <span class="n">window_size_</span> <span class="o">!=</span> <span class="mi">0</span> <span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="n">retrans_cnt_</span><span class="o">++</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="n">rto_</span> <span class="o">*=</span> <span class="mi">2</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="n">expire_time_</span> <span class="o">=</span> <span class="n">current_time_</span> <span class="o">+</span> <span class="n">rto_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="kt">uint64_t</span> <span class="n">TCPSender</span><span class="o">::</span><span class="n">unwarp</span><span class="p">(</span> <span class="k">const</span> <span class="n">Wrap32</span><span class="o">&amp;</span> <span class="n">seq</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="k">return</span> <span class="n">seq</span><span class="p">.</span><span class="n">unwrap</span><span class="p">(</span> <span class="n">zero_point_</span><span class="p">,</span> <span class="n">ack_</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>Running the tests produces the following output:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span><span class="lnt">50
</span><span class="lnt">51
</span><span class="lnt">52
</span><span class="lnt">53
</span><span class="lnt">54
</span><span class="lnt">55
</span><span class="lnt">56
</span><span class="lnt">57
</span><span class="lnt">58
</span><span class="lnt">59
</span><span class="lnt">60
</span><span class="lnt">61
</span><span class="lnt">62
</span><span class="lnt">63
</span><span class="lnt">64
</span><span class="lnt">65
</span><span class="lnt">66
</span><span class="lnt">67
</span><span class="lnt">68
</span><span class="lnt">69
</span><span class="lnt">70
</span><span class="lnt">71
</span><span class="lnt">72
</span><span class="lnt">73
</span><span class="lnt">74
</span><span class="lnt">75
</span><span class="lnt">76
</span><span class="lnt">77
</span><span class="lnt">78
</span><span class="lnt">79
</span><span class="lnt">80
</span><span class="lnt">81
</span><span class="lnt">82
</span><span class="lnt">83
</span><span class="lnt">84
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-text" data-lang="text"><span class="line"><span class="cl">-- Building in &#39;Debug&#39; mode.
</span></span><span class="line"><span class="cl">-- Configuring done (0.3s)
</span></span><span class="line"><span class="cl">-- Generating done (0.3s)
</span></span><span class="line"><span class="cl">-- Build files have been written to: /home/zhouxin/projects/CS144/build
</span></span><span class="line"><span class="cl">Test project /home/zhouxin/projects/CS144/build
</span></span><span class="line"><span class="cl">      Start  1: compile with bug-checkers
</span></span><span class="line"><span class="cl"> 1/36 Test  #1: compile with bug-checkers ........   Passed   40.66 sec
</span></span><span class="line"><span class="cl">      Start  3: byte_stream_basics
</span></span><span class="line"><span class="cl"> 2/36 Test  #3: byte_stream_basics ...............   Passed    0.02 sec
</span></span><span class="line"><span class="cl">      Start  4: byte_stream_capacity
</span></span><span class="line"><span class="cl"> 3/36 Test  #4: byte_stream_capacity .............   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start  5: byte_stream_one_write
</span></span><span class="line"><span class="cl"> 4/36 Test  #5: byte_stream_one_write ............   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start  6: byte_stream_two_writes
</span></span><span class="line"><span class="cl"> 5/36 Test  #6: byte_stream_two_writes ...........   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start  7: byte_stream_many_writes
</span></span><span class="line"><span class="cl"> 6/36 Test  #7: byte_stream_many_writes ..........   Passed    0.05 sec
</span></span><span class="line"><span class="cl">      Start  8: byte_stream_stress_test
</span></span><span class="line"><span class="cl"> 7/36 Test  #8: byte_stream_stress_test ..........   Passed    0.05 sec
</span></span><span class="line"><span class="cl">      Start  9: reassembler_single
</span></span><span class="line"><span class="cl"> 8/36 Test  #9: reassembler_single ...............   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 10: reassembler_cap
</span></span><span class="line"><span class="cl"> 9/36 Test #10: reassembler_cap ..................   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 11: reassembler_seq
</span></span><span class="line"><span class="cl">10/36 Test #11: reassembler_seq ..................   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 12: reassembler_dup
</span></span><span class="line"><span class="cl">11/36 Test #12: reassembler_dup ..................   Passed    0.05 sec
</span></span><span class="line"><span class="cl">      Start 13: reassembler_holes
</span></span><span class="line"><span class="cl">12/36 Test #13: reassembler_holes ................   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 14: reassembler_overlapping
</span></span><span class="line"><span class="cl">13/36 Test #14: reassembler_overlapping ..........   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 15: reassembler_win
</span></span><span class="line"><span class="cl">14/36 Test #15: reassembler_win ..................   Passed    0.17 sec
</span></span><span class="line"><span class="cl">      Start 16: wrapping_integers_cmp
</span></span><span class="line"><span class="cl">15/36 Test #16: wrapping_integers_cmp ............   Passed    0.04 sec
</span></span><span class="line"><span class="cl">      Start 17: wrapping_integers_wrap
</span></span><span class="line"><span class="cl">16/36 Test #17: wrapping_integers_wrap ...........   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 18: wrapping_integers_unwrap
</span></span><span class="line"><span class="cl">17/36 Test #18: wrapping_integers_unwrap .........   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 19: wrapping_integers_roundtrip
</span></span><span class="line"><span class="cl">18/36 Test #19: wrapping_integers_roundtrip ......   Passed    0.55 sec
</span></span><span class="line"><span class="cl">      Start 20: wrapping_integers_extra
</span></span><span class="line"><span class="cl">19/36 Test #20: wrapping_integers_extra ..........   Passed    0.12 sec
</span></span><span class="line"><span class="cl">      Start 21: recv_connect
</span></span><span class="line"><span class="cl">20/36 Test #21: recv_connect .....................   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 22: recv_transmit
</span></span><span class="line"><span class="cl">21/36 Test #22: recv_transmit ....................   Passed    0.13 sec
</span></span><span class="line"><span class="cl">      Start 23: recv_window
</span></span><span class="line"><span class="cl">22/36 Test #23: recv_window ......................   Passed    0.01 sec
</span></span><span class="line"><span class="cl">      Start 24: recv_reorder
</span></span><span class="line"><span class="cl">23/36 Test #24: recv_reorder .....................   Passed    0.04 sec
</span></span><span class="line"><span class="cl">      Start 25: recv_reorder_more
</span></span><span class="line"><span class="cl">24/36 Test #25: recv_reorder_more ................   Passed    0.39 sec
</span></span><span class="line"><span class="cl">      Start 26: recv_close
</span></span><span class="line"><span class="cl">25/36 Test #26: recv_close .......................   Passed    0.04 sec
</span></span><span class="line"><span class="cl">      Start 27: recv_special
</span></span><span class="line"><span class="cl">26/36 Test #27: recv_special .....................   Passed    0.04 sec
</span></span><span class="line"><span class="cl">      Start 28: send_connect
</span></span><span class="line"><span class="cl">27/36 Test #28: send_connect .....................   Passed    0.04 sec
</span></span><span class="line"><span class="cl">      Start 29: send_transmit
</span></span><span class="line"><span class="cl">28/36 Test #29: send_transmit ....................   Passed    0.18 sec
</span></span><span class="line"><span class="cl">      Start 30: send_retx
</span></span><span class="line"><span class="cl">29/36 Test #30: send_retx ........................   Passed    0.04 sec
</span></span><span class="line"><span class="cl">      Start 31: send_window
</span></span><span class="line"><span class="cl">30/36 Test #31: send_window ......................   Passed    0.07 sec
</span></span><span class="line"><span class="cl">      Start 32: send_ack
</span></span><span class="line"><span class="cl">31/36 Test #32: send_ack .........................   Passed    0.04 sec
</span></span><span class="line"><span class="cl">      Start 33: send_close
</span></span><span class="line"><span class="cl">32/36 Test #33: send_close .......................   Passed    0.04 sec
</span></span><span class="line"><span class="cl">      Start 34: send_extra
</span></span><span class="line"><span class="cl">33/36 Test #34: send_extra .......................   Passed    0.05 sec
</span></span><span class="line"><span class="cl">      Start 37: compile with optimization
</span></span><span class="line"><span class="cl">34/36 Test #37: compile with optimization ........   Passed    2.29 sec
</span></span><span class="line"><span class="cl">      Start 38: byte_stream_speed_test
</span></span><span class="line"><span class="cl">             ByteStream throughput: 19.14 Gbit/s
</span></span><span class="line"><span class="cl">35/36 Test #38: byte_stream_speed_test ...........   Passed    0.06 sec
</span></span><span class="line"><span class="cl">      Start 39: reassembler_speed_test
</span></span><span class="line"><span class="cl">             Reassembler throughput: 8.26 Gbit/s
</span></span><span class="line"><span class="cl">36/36 Test #39: reassembler_speed_test ...........   Passed    0.12 sec
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">100% tests passed, 0 tests failed out of 36
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">Total Test time (real) =  45.37 sec
</span></span><span class="line"><span class="cl">Built target check3
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="lab-4">Lab 4</h1>
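<p>Lab 4 uses the TCP module written in the earlier labs to communicate with the real Internet. If the earlier implementations are correct, no new code is needed here. Following the lab handout, the test passed without issues; the output was:</p>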
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-text" data-lang="text"><span class="line"><span class="cl">Test project /home/zhouxin/projects/CS144/build
</span></span><span class="line"><span class="cl">    Start 1: compile with bug-checkers
</span></span><span class="line"><span class="cl">1/2 Test #1: compile with bug-checkers ........   Passed    0.11 sec
</span></span><span class="line"><span class="cl">    Start 2: t_webget
</span></span><span class="line"><span class="cl">2/2 Test #2: t_webget .........................   Passed    1.03 sec
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="lab-5">Lab 5</h1>
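<p>Lab 5 implements the ARP protocol, which translates IP addresses into MAC addresses and sends out the IP datagrams handed down from the upper layer. A few details are worth noting:</p>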
<ul>
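<li>An ARP table is kept in memory, and each entry is valid for only 30 seconds</li>
<li>ARP requests for the same target IP must be at least 5 seconds apart</li>
<li>When sending data, if the ARP table has no entry for the target, an ARP request is sent first</li>
<li>When an ARP reply arrives, every datagram that was waiting on that entry must be sent</li>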
</ul>
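<p>The two timing rules above can be sketched in isolation. This is a minimal sketch, not the lab's code: <code>ArpTimingSketch</code>, <code>MacAddress</code>, <code>learn</code>, <code>lookup</code> and <code>should_request</code> are hypothetical names introduced for illustration, time is assumed to be counted in milliseconds, and the boundary handling (an entry is still valid at exactly 30 s, a new request is allowed only after more than 5 s) is an assumption.</p>

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <unordered_map>

// Hypothetical stand-in for the lab's EthernetAddress type.
using MacAddress = std::array<std::uint8_t, 6>;

// Sketch of the two timing rules: 30 s entry lifetime, 5 s request spacing.
class ArpTimingSketch
{
public:
  static constexpr std::size_t ENTRY_TTL_MS = 30000;
  static constexpr std::size_t REQUEST_GAP_MS = 5000;

  // Advance the internal clock by `ms` milliseconds.
  void tick( std::size_t ms ) { now_ms_ += ms; }

  // Remember ip -> mac together with the time at which the mapping expires.
  void learn( std::uint32_t ip, const MacAddress& mac )
  {
    table_[ip] = Entry { mac, now_ms_ + ENTRY_TTL_MS };
  }

  // Return the MAC address only while the entry has not expired.
  std::optional<MacAddress> lookup( std::uint32_t ip ) const
  {
    const auto it = table_.find( ip );
    if ( it == table_.end() || it->second.expires_at_ms < now_ms_ ) {
      return std::nullopt;
    }
    return it->second.mac;
  }

  // Return true if an ARP request for `ip` may be sent now, and record
  // the earliest time at which the next request is allowed.
  bool should_request( std::uint32_t ip )
  {
    const auto it = next_request_ms_.find( ip );
    if ( it != next_request_ms_.end() && now_ms_ <= it->second ) {
      return false;
    }
    next_request_ms_[ip] = now_ms_ + REQUEST_GAP_MS;
    return true;
  }

private:
  struct Entry
  {
    MacAddress mac {};
    std::size_t expires_at_ms = 0;
  };

  std::size_t now_ms_ = 0;
  std::unordered_map<std::uint32_t, Entry> table_ {};
  std::unordered_map<std::uint32_t, std::size_t> next_request_ms_ {};
};
```

<p>Storing the expiry time (rather than the time of insertion) keeps each check down to a single comparison against the current clock.</p>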
<p>For the implementation, I added three data members:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl"><span class="c1">// Current time, in milliseconds
</span></span></span><span class="line"><span class="cl"><span class="n">size_t</span> <span class="n">current_time_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1">// ARP table: IP -&gt; (MAC address, expiry time)
</span></span></span><span class="line"><span class="cl"><span class="n">std</span><span class="o">::</span><span class="n">unordered_map</span><span class="o">&lt;</span><span class="kt">uint32_t</span> <span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o">&lt;</span><span class="n">EthernetAddress</span><span class="p">,</span> <span class="n">size_t</span><span class="o">&gt;&gt;</span> <span class="n">arp_table_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1">// Per-IP queue of frames waiting for an ARP reply, plus the earliest time the next ARP request may be sent
</span></span></span><span class="line"><span class="cl"><span class="n">std</span><span class="o">::</span><span class="n">unordered_map</span><span class="o">&lt;</span><span class="kt">uint32_t</span> <span class="p">,</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">queue</span><span class="o">&lt;</span><span class="n">EthernetFrame</span><span class="o">&gt;</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">optional</span><span class="o">&lt;</span><span class="n">size_t</span><span class="o">&gt;&gt;&gt;</span> <span class="n">frame_queue_</span><span class="p">;</span>
</span></span></code></pre></td></tr></table>
</div>
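</div><p>Each ARP table entry is valid for only 30 seconds, so every row records the MAC address for an IP address together with its expiry time. In the send method, if the MAC address of the target IP is not yet known, the frame is placed in a waiting queue and sent once the ARP reply arrives (conceptually, this is the same synchronization pattern as waiting on a semaphore). The time of the most recent ARP request for each target IP is also tracked, so that the same IP is not queried too frequently.</p>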
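<p>The waiting-queue idea can be illustrated on its own. The following is a minimal sketch under simplifying assumptions, not the lab's code: <code>FrameSketch</code> and <code>PendingFramesSketch</code> are hypothetical types, MAC addresses and payloads are plain strings, and "sending" is modelled by returning the released frames to the caller.</p>

```cpp
#include <cstddef>
#include <cstdint>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical, simplified frame: only the fields needed for the illustration.
struct FrameSketch
{
  std::string dst_mac;
  std::string payload;
};

// Frames whose destination MAC is still unknown, grouped by next-hop IP.
class PendingFramesSketch
{
public:
  // Park a frame until the MAC address of `next_hop_ip` is known.
  void park( std::uint32_t next_hop_ip, std::string payload )
  {
    waiting_[next_hop_ip].push( FrameSketch { "", std::move( payload ) } );
  }

  // Called when an ARP reply maps `ip` to `mac`: fill in the destination
  // of every parked frame and hand the frames back in FIFO order.
  std::vector<FrameSketch> release( std::uint32_t ip, const std::string& mac )
  {
    std::vector<FrameSketch> out;
    const auto it = waiting_.find( ip );
    if ( it == waiting_.end() ) {
      return out;
    }
    while ( !it->second.empty() ) {
      FrameSketch frame = std::move( it->second.front() );
      it->second.pop();
      frame.dst_mac = mac;
      out.push_back( std::move( frame ) );
    }
    waiting_.erase( it );
    return out;
  }

  // Number of frames currently parked for `ip`.
  std::size_t parked( std::uint32_t ip ) const
  {
    const auto it = waiting_.find( ip );
    return it == waiting_.end() ? 0 : it->second.size();
  }

private:
  std::unordered_map<std::uint32_t, std::queue<FrameSketch>> waiting_ {};
};
```

<p>Keying the queues by next-hop IP means a single ARP reply releases exactly the frames that were blocked on it, while frames for other addresses stay parked.</p>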
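<p><code>send_datagram</code> works as follows: first fill in every field of the Ethernet frame except the destination MAC, then look up the ARP table. If a valid entry for the target IP exists, fill in the MAC and transmit the frame; otherwise put the frame in the queue for that IP and send an ARP request.</p>
<p><code>recv_frame</code> works as follows: first check the destination MAC field, and only process frames addressed to this interface or to the broadcast address. Then parse the payload according to the type field. An IP datagram is parsed and handed to the upper-layer queue. For an ARP message, the ARP table is updated from the header fields; if it is an ARP request, an ARP reply carrying this interface's MAC is constructed and sent; if it is an ARP reply, all frames queued for the corresponding IP are sent.</p>
<p>The full implementation is:</p>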
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">  1
</span><span class="lnt">  2
</span><span class="lnt">  3
</span><span class="lnt">  4
</span><span class="lnt">  5
</span><span class="lnt">  6
</span><span class="lnt">  7
</span><span class="lnt">  8
</span><span class="lnt">  9
</span><span class="lnt"> 10
</span><span class="lnt"> 11
</span><span class="lnt"> 12
</span><span class="lnt"> 13
</span><span class="lnt"> 14
</span><span class="lnt"> 15
</span><span class="lnt"> 16
</span><span class="lnt"> 17
</span><span class="lnt"> 18
</span><span class="lnt"> 19
</span><span class="lnt"> 20
</span><span class="lnt"> 21
</span><span class="lnt"> 22
</span><span class="lnt"> 23
</span><span class="lnt"> 24
</span><span class="lnt"> 25
</span><span class="lnt"> 26
</span><span class="lnt"> 27
</span><span class="lnt"> 28
</span><span class="lnt"> 29
</span><span class="lnt"> 30
</span><span class="lnt"> 31
</span><span class="lnt"> 32
</span><span class="lnt"> 33
</span><span class="lnt"> 34
</span><span class="lnt"> 35
</span><span class="lnt"> 36
</span><span class="lnt"> 37
</span><span class="lnt"> 38
</span><span class="lnt"> 39
</span><span class="lnt"> 40
</span><span class="lnt"> 41
</span><span class="lnt"> 42
</span><span class="lnt"> 43
</span><span class="lnt"> 44
</span><span class="lnt"> 45
</span><span class="lnt"> 46
</span><span class="lnt"> 47
</span><span class="lnt"> 48
</span><span class="lnt"> 49
</span><span class="lnt"> 50
</span><span class="lnt"> 51
</span><span class="lnt"> 52
</span><span class="lnt"> 53
</span><span class="lnt"> 54
</span><span class="lnt"> 55
</span><span class="lnt"> 56
</span><span class="lnt"> 57
</span><span class="lnt"> 58
</span><span class="lnt"> 59
</span><span class="lnt"> 60
</span><span class="lnt"> 61
</span><span class="lnt"> 62
</span><span class="lnt"> 63
</span><span class="lnt"> 64
</span><span class="lnt"> 65
</span><span class="lnt"> 66
</span><span class="lnt"> 67
</span><span class="lnt"> 68
</span><span class="lnt"> 69
</span><span class="lnt"> 70
</span><span class="lnt"> 71
</span><span class="lnt"> 72
</span><span class="lnt"> 73
</span><span class="lnt"> 74
</span><span class="lnt"> 75
</span><span class="lnt"> 76
</span><span class="lnt"> 77
</span><span class="lnt"> 78
</span><span class="lnt"> 79
</span><span class="lnt"> 80
</span><span class="lnt"> 81
</span><span class="lnt"> 82
</span><span class="lnt"> 83
</span><span class="lnt"> 84
</span><span class="lnt"> 85
</span><span class="lnt"> 86
</span><span class="lnt"> 87
</span><span class="lnt"> 88
</span><span class="lnt"> 89
</span><span class="lnt"> 90
</span><span class="lnt"> 91
</span><span class="lnt"> 92
</span><span class="lnt"> 93
</span><span class="lnt"> 94
</span><span class="lnt"> 95
</span><span class="lnt"> 96
</span><span class="lnt"> 97
</span><span class="lnt"> 98
</span><span class="lnt"> 99
</span><span class="lnt">100
</span><span class="lnt">101
</span><span class="lnt">102
</span><span class="lnt">103
</span><span class="lnt">104
</span><span class="lnt">105
</span><span class="lnt">106
</span><span class="lnt">107
</span><span class="lnt">108
</span><span class="lnt">109
</span><span class="lnt">110
</span><span class="lnt">111
</span><span class="lnt">112
</span><span class="lnt">113
</span><span class="lnt">114
</span><span class="lnt">115
</span><span class="lnt">116
</span><span class="lnt">117
</span><span class="lnt">118
</span><span class="lnt">119
</span><span class="lnt">120
</span><span class="lnt">121
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&lt;iostream&gt;</span><span class="cp">
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&#34;arp_message.hh&#34;</span><span class="cp">
</span></span></span><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&#34;exception.hh&#34;</span><span class="cp">
</span></span></span><span class="line"><span class="cl"><span class="cp">#include</span> <span class="cpf">&#34;network_interface.hh&#34;</span><span class="cp">
</span></span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">using</span> <span class="k">namespace</span> <span class="n">std</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1">//! \param[in] ethernet_address Ethernet (what ARP calls &#34;hardware&#34;) address of the interface
</span></span></span><span class="line"><span class="cl"><span class="c1">//! \param[in] ip_address IP (what ARP calls &#34;protocol&#34;) address of the interface
</span></span></span><span class="line"><span class="cl"><span class="n">NetworkInterface</span><span class="o">::</span><span class="n">NetworkInterface</span><span class="p">(</span> <span class="n">string_view</span> <span class="n">name</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                    <span class="n">shared_ptr</span><span class="o">&lt;</span><span class="n">OutputPort</span><span class="o">&gt;</span> <span class="n">port</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                    <span class="k">const</span> <span class="n">EthernetAddress</span><span class="o">&amp;</span> <span class="n">ethernet_address</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                    <span class="k">const</span> <span class="n">Address</span><span class="o">&amp;</span> <span class="n">ip_address</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl">  <span class="o">:</span> <span class="n">name_</span><span class="p">(</span> <span class="n">name</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl">  <span class="p">,</span> <span class="n">port_</span><span class="p">(</span> <span class="n">notnull</span><span class="p">(</span> <span class="s">&#34;OutputPort&#34;</span><span class="p">,</span> <span class="n">move</span><span class="p">(</span> <span class="n">port</span> <span class="p">)</span> <span class="p">)</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl">  <span class="p">,</span> <span class="n">ethernet_address_</span><span class="p">(</span> <span class="n">ethernet_address</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl">  <span class="p">,</span> <span class="n">ip_address_</span><span class="p">(</span> <span class="n">ip_address</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl">  <span class="p">,</span> <span class="n">current_time_</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">  <span class="p">,</span> <span class="n">arp_table_</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">  <span class="p">,</span> <span class="n">frame_queue_</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="n">cerr</span> <span class="o">&lt;&lt;</span> <span class="s">&#34;DEBUG: Network interface has Ethernet address &#34;</span> <span class="o">&lt;&lt;</span> <span class="n">to_string</span><span class="p">(</span> <span class="n">ethernet_address</span> <span class="p">)</span> <span class="o">&lt;&lt;</span> <span class="s">&#34; and IP address &#34;</span>
</span></span><span class="line"><span class="cl">       <span class="o">&lt;&lt;</span> <span class="n">ip_address</span><span class="p">.</span><span class="n">ip</span><span class="p">()</span> <span class="o">&lt;&lt;</span> <span class="s">&#34;</span><span class="se">\n</span><span class="s">&#34;</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1">//! \param[in] dgram the IPv4 datagram to be sent
</span></span></span><span class="line"><span class="cl"><span class="c1">//! \param[in] next_hop the IP address of the interface to send it to (typically a router or default gateway, but
</span></span></span><span class="line"><span class="cl"><span class="c1">//! may also be another host if directly connected to the same network as the destination) Note: the Address type
</span></span></span><span class="line"><span class="cl"><span class="c1">//! can be converted to a uint32_t (raw 32-bit IP address) by using the Address::ipv4_numeric() method.
</span></span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">NetworkInterface</span><span class="o">::</span><span class="n">send_datagram</span><span class="p">(</span> <span class="k">const</span> <span class="n">InternetDatagram</span><span class="o">&amp;</span> <span class="n">dgram</span><span class="p">,</span> <span class="k">const</span> <span class="n">Address</span><span class="o">&amp;</span> <span class="n">next_hop</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="n">EthernetFrame</span> <span class="n">message</span> <span class="o">=</span> <span class="n">EthernetFrame</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">  <span class="k">const</span> <span class="kt">uint32_t</span> <span class="n">target_ip</span> <span class="o">=</span> <span class="n">next_hop</span><span class="p">.</span><span class="n">ipv4_numeric</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">  <span class="n">message</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">src</span> <span class="o">=</span> <span class="n">ethernet_address_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">message</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">type</span> <span class="o">=</span> <span class="n">EthernetHeader</span><span class="o">::</span><span class="n">TYPE_IPv4</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">message</span><span class="p">.</span><span class="n">payload</span> <span class="o">=</span> <span class="n">serialize</span><span class="p">(</span><span class="n">dgram</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="o">!</span><span class="n">arp_table_</span><span class="p">.</span><span class="n">contains</span><span class="p">(</span><span class="n">target_ip</span><span class="p">)</span> <span class="o">||</span> <span class="n">arp_table_</span><span class="p">[</span><span class="n">target_ip</span><span class="p">].</span><span class="n">second</span> <span class="o">&lt;</span> <span class="n">current_time_</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="n">frame_queue_</span><span class="p">[</span><span class="n">target_ip</span><span class="p">].</span><span class="n">first</span><span class="p">.</span><span class="n">push</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">move</span><span class="p">(</span><span class="n">message</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">    <span class="n">EthernetFrame</span> <span class="n">arp_request_frame</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">send_arp_request</span><span class="p">(</span> <span class="n">target_ip</span><span class="p">,</span> <span class="n">arp_request_frame</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="n">message</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">dst</span> <span class="o">=</span> <span class="n">arp_table_</span><span class="p">[</span><span class="n">target_ip</span><span class="p">].</span><span class="n">first</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">    <span class="n">transmit</span><span class="p">(</span><span class="n">message</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">NetworkInterface</span><span class="o">::</span><span class="n">send_arp_request</span><span class="p">(</span> <span class="k">const</span> <span class="kt">uint32_t</span> <span class="n">target_ip</span><span class="p">,</span> <span class="n">EthernetFrame</span><span class="o">&amp;</span> <span class="n">arp_request_frame</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">frame_queue_</span><span class="p">.</span><span class="n">contains</span><span class="p">(</span><span class="n">target_ip</span><span class="p">)</span> <span class="o">&amp;&amp;</span> <span class="n">frame_queue_</span><span class="p">[</span><span class="n">target_ip</span><span class="p">].</span><span class="n">second</span><span class="p">.</span><span class="n">has_value</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">       <span class="o">&amp;&amp;</span> <span class="n">frame_queue_</span><span class="p">[</span><span class="n">target_ip</span><span class="p">].</span><span class="n">second</span> <span class="o">&gt;=</span> <span class="n">current_time_</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">arp_request_frame</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">type</span> <span class="o">=</span> <span class="n">EthernetHeader</span><span class="o">::</span><span class="n">TYPE_ARP</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">arp_request_frame</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">dst</span> <span class="o">=</span> <span class="n">ETHERNET_BROADCAST</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">arp_request_frame</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">src</span> <span class="o">=</span> <span class="n">ethernet_address_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">ARPMessage</span> <span class="n">arp_request_message</span> <span class="o">=</span> <span class="n">ARPMessage</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">  <span class="n">arp_request_message</span><span class="p">.</span><span class="n">sender_ethernet_address</span> <span class="o">=</span> <span class="n">ethernet_address_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">arp_request_message</span><span class="p">.</span><span class="n">sender_ip_address</span> <span class="o">=</span> <span class="n">ip_address_</span><span class="p">.</span><span class="n">ipv4_numeric</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">  <span class="n">arp_request_message</span><span class="p">.</span><span class="n">opcode</span> <span class="o">=</span> <span class="n">ARPMessage</span><span class="o">::</span><span class="n">OPCODE_REQUEST</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">arp_request_message</span><span class="p">.</span><span class="n">target_ip_address</span> <span class="o">=</span> <span class="n">target_ip</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="c1">//  arp_request_message.target_ethernet_address = ETHERNET_BROADCAST;
</span></span></span><span class="line"><span class="cl">  <span class="n">arp_request_frame</span><span class="p">.</span><span class="n">payload</span> <span class="o">=</span> <span class="n">serialize</span><span class="p">(</span><span class="n">arp_request_message</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">transmit</span><span class="p">(</span><span class="n">arp_request_frame</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="n">frame_queue_</span><span class="p">[</span><span class="n">target_ip</span><span class="p">].</span><span class="n">second</span> <span class="o">=</span> <span class="n">current_time_</span> <span class="o">+</span> <span class="mi">5000</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1">//! \param[in] frame the incoming Ethernet frame
</span></span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">NetworkInterface</span><span class="o">::</span><span class="n">recv_frame</span><span class="p">(</span> <span class="k">const</span> <span class="n">EthernetFrame</span><span class="o">&amp;</span> <span class="n">frame</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">if</span><span class="p">(</span><span class="n">frame</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">dst</span> <span class="o">==</span> <span class="n">ethernet_address_</span> <span class="o">||</span> <span class="n">frame</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">dst</span> <span class="o">==</span> <span class="n">ETHERNET_BROADCAST</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="k">if</span><span class="p">(</span><span class="n">frame</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">EthernetHeader</span><span class="o">::</span><span class="n">TYPE_ARP</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">      <span class="n">ARPMessage</span> <span class="n">message</span> <span class="o">=</span> <span class="n">ARPMessage</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">      <span class="k">if</span><span class="p">(</span><span class="n">parse</span><span class="p">(</span><span class="n">message</span><span class="p">,</span> <span class="n">frame</span><span class="p">.</span><span class="n">payload</span><span class="p">)</span> <span class="o">&amp;&amp;</span> <span class="n">message</span><span class="p">.</span><span class="n">target_ip_address</span> <span class="o">==</span> <span class="n">ip_address_</span><span class="p">.</span><span class="n">ipv4_numeric</span><span class="p">())</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">arp_table_</span><span class="p">[</span><span class="n">message</span><span class="p">.</span><span class="n">sender_ip_address</span><span class="p">]</span> <span class="o">=</span> <span class="n">make_pair</span><span class="p">(</span><span class="n">message</span><span class="p">.</span><span class="n">sender_ethernet_address</span><span class="p">,</span> <span class="n">current_time_</span><span class="o">+</span><span class="mi">30000</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">message</span><span class="p">.</span><span class="n">opcode</span> <span class="o">==</span> <span class="n">ARPMessage</span><span class="o">::</span><span class="n">OPCODE_REQUEST</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">          <span class="n">EthernetFrame</span> <span class="n">response</span> <span class="o">=</span> <span class="n">EthernetFrame</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">          <span class="n">make_arp_response</span><span class="p">(</span> <span class="n">message</span><span class="p">,</span> <span class="n">response</span> <span class="p">);</span>
</span></span><span class="line"><span class="cl">          <span class="n">transmit</span><span class="p">(</span><span class="n">response</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">          <span class="c1">// On an ARP reply, flush any frames queued for the sender's IP
</span></span></span><span class="line"><span class="cl">          <span class="n">queue</span><span class="o">&lt;</span><span class="n">EthernetFrame</span><span class="o">&gt;&amp;</span> <span class="n">ip_queue</span> <span class="o">=</span> <span class="n">frame_queue_</span><span class="p">[</span><span class="n">message</span><span class="p">.</span><span class="n">sender_ip_address</span><span class="p">].</span><span class="n">first</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">          <span class="k">while</span> <span class="p">(</span><span class="o">!</span><span class="n">ip_queue</span><span class="p">.</span><span class="n">empty</span><span class="p">()){</span>
</span></span><span class="line"><span class="cl">            <span class="n">ip_queue</span><span class="p">.</span><span class="n">front</span><span class="p">().</span><span class="n">header</span><span class="p">.</span><span class="n">dst</span> <span class="o">=</span> <span class="n">message</span><span class="p">.</span><span class="n">sender_ethernet_address</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">            <span class="n">transmit</span><span class="p">(</span><span class="n">ip_queue</span><span class="p">.</span><span class="n">front</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">            <span class="n">ip_queue</span><span class="p">.</span><span class="n">pop</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">          <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="p">}</span> <span class="k">else</span> <span class="nf">if</span><span class="p">(</span><span class="n">frame</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">EthernetHeader</span><span class="o">::</span><span class="n">TYPE_IPv4</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">      <span class="n">InternetDatagram</span> <span class="n">message</span> <span class="o">=</span> <span class="n">InternetDatagram</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">      <span class="k">if</span><span class="p">(</span><span class="n">parse</span><span class="p">(</span><span class="n">message</span><span class="p">,</span> <span class="n">frame</span><span class="p">.</span><span class="n">payload</span><span class="p">)){</span>
</span></span><span class="line"><span class="cl">        <span class="n">datagrams_received_</span><span class="p">.</span><span class="n">emplace</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">move</span><span class="p">(</span><span class="n">message</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">NetworkInterface</span><span class="o">::</span><span class="n">make_arp_response</span><span class="p">(</span> <span class="k">const</span> <span class="n">ARPMessage</span><span class="o">&amp;</span> <span class="n">message</span><span class="p">,</span> <span class="n">EthernetFrame</span><span class="o">&amp;</span> <span class="n">response</span> <span class="p">)</span> <span class="k">const</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="n">EthernetHeader</span><span class="o">&amp;</span> <span class="n">header</span> <span class="o">=</span> <span class="n">response</span><span class="p">.</span><span class="n">header</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">header</span><span class="p">.</span><span class="n">dst</span> <span class="o">=</span> <span class="n">message</span><span class="p">.</span><span class="n">sender_ethernet_address</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">header</span><span class="p">.</span><span class="n">src</span> <span class="o">=</span> <span class="n">ethernet_address_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">header</span><span class="p">.</span><span class="n">type</span> <span class="o">=</span> <span class="n">EthernetHeader</span><span class="o">::</span><span class="n">TYPE_ARP</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">ARPMessage</span> <span class="n">arp_response_message</span> <span class="o">=</span> <span class="n">ARPMessage</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">  <span class="n">arp_response_message</span><span class="p">.</span><span class="n">opcode</span> <span class="o">=</span> <span class="n">ARPMessage</span><span class="o">::</span><span class="n">OPCODE_REPLY</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">arp_response_message</span><span class="p">.</span><span class="n">sender_ethernet_address</span> <span class="o">=</span> <span class="n">ethernet_address_</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">arp_response_message</span><span class="p">.</span><span class="n">sender_ip_address</span> <span class="o">=</span> <span class="n">ip_address_</span><span class="p">.</span><span class="n">ipv4_numeric</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">  <span class="n">arp_response_message</span><span class="p">.</span><span class="n">target_ethernet_address</span> <span class="o">=</span> <span class="n">message</span><span class="p">.</span><span class="n">sender_ethernet_address</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">arp_response_message</span><span class="p">.</span><span class="n">target_ip_address</span> <span class="o">=</span> <span class="n">message</span><span class="p">.</span><span class="n">sender_ip_address</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">  <span class="n">response</span><span class="p">.</span><span class="n">payload</span> <span class="o">=</span> <span class="n">serialize</span><span class="p">(</span><span class="n">arp_response_message</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">  <span class="k">return</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1">//! \param[in] ms_since_last_tick the number of milliseconds since the last call to this method
</span></span></span><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">NetworkInterface</span><span class="o">::</span><span class="n">tick</span><span class="p">(</span> <span class="k">const</span> <span class="n">size_t</span> <span class="n">ms_since_last_tick</span> <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="n">current_time_</span> <span class="o">+=</span> <span class="n">ms_since_last_tick</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The test output is:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-text" data-lang="text"><span class="line"><span class="cl">Test project /home/zhouxin/projects/CS144/build
</span></span><span class="line"><span class="cl">    Start  1: compile with bug-checkers
</span></span><span class="line"><span class="cl">1/2 Test  #1: compile with bug-checkers ........   Passed    8.79 sec
</span></span><span class="line"><span class="cl">    Start 35: net_interface
</span></span><span class="line"><span class="cl">2/2 Test #35: net_interface ....................   Passed    0.01 sec
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">100% tests passed, 0 tests failed out of 2
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">Total Test time (real) =   8.80 sec
</span></span><span class="line"><span class="cl">Built target check5
</span></span></code></pre></td></tr></table>
</div>
</div><h1 id="lab-6">Lab 6</h1>
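<p>In Lab 6 we implement routing and forwarding. Concretely, the router keeps a routing table in memory and performs longest-prefix matching against it to forward datagrams at the network layer.</p>
<p>A prefix tree (trie) would be the ideal data structure for the routing table, but building one almost inevitably involves smart pointers, so I dropped the idea. The lab handout also says that O(n) lookup is acceptable, so I ended up storing the routing table in a <code>vector</code>. Instead of storing the prefix length in each entry, I convert it to a subnet mask, which makes the later matching easier.</p>
<p>Matching uses a bitwise AND: <code>ip</code> matches <code>prefix</code> if and only if <code>(ip &amp; mask) == prefix</code>. The parentheses matter in C++, where <code>==</code> binds tighter than <code>&amp;</code>. One ip may match several prefixes; the longest match is the entry with the numerically largest mask.</p>
<p>Once the longest match is found, if the routing table entry has a next hop, the datagram is forwarded to the next-hop ip. If there is no next hop, the destination is directly attached, so the datagram is sent straight to the destination ip.</p>
<p>The implementation of <code>route()</code> is as follows:</p>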
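<p>The mask conversion and the longest-prefix match described above can be sketched in isolation. This is a minimal sketch, not the lab's actual code: <code>RouteEntry</code>, <code>mask_from_length</code> and <code>longest_match</code> are hypothetical names introduced only for illustration, and the entry holds just the fields the match needs.</p>

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical routing-table entry (illustrative field names, not the lab's).
struct RouteEntry {
  std::uint32_t prefix;  // network prefix, already masked
  std::uint32_t mask;    // subnet mask derived from the prefix length
  int interface_num;     // outgoing interface index
};

// Convert a prefix length (0..32) into a subnet mask.
// Shifting a 32-bit value by 32 bits is undefined behaviour,
// so a prefix length of 0 is handled separately.
std::uint32_t mask_from_length(std::uint8_t prefix_length) {
  return prefix_length == 0 ? 0u : ~std::uint32_t{0} << (32 - prefix_length);
}

// Linear scan for the matching entry with the largest mask.
// For contiguous prefix masks, a longer prefix gives a numerically larger mask.
std::optional<RouteEntry> longest_match(const std::vector<RouteEntry>& table,
                                        std::uint32_t ip) {
  std::optional<RouteEntry> best;
  for (const auto& entry : table) {
    if ((ip & entry.mask) == entry.prefix &&
        (!best.has_value() || entry.mask > best->mask)) {
      best = entry;
    }
  }
  return best;
}
```

<p>With entries for 0.0.0.0/0, 10.0.0.0/8 and 10.1.0.0/16, an address such as 10.1.2.3 matches all three and the /16 entry wins, while an address outside 10.0.0.0/8 falls back to the default route.</p>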
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C++" data-lang="C++"><span class="line"><span class="cl"><span class="kt">void</span> <span class="n">Router</span><span class="o">::</span><span class="n">route</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="c1">// Your code here.
</span></span></span><span class="line"><span class="cl">  <span class="k">for</span><span class="p">(</span> <span class="k">auto</span><span class="o">&amp;</span> <span class="nl">interface</span><span class="p">:</span> <span class="n">_interfaces</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">    <span class="k">auto</span><span class="o">&amp;</span> <span class="n">data_queue</span> <span class="o">=</span> <span class="n">interface</span><span class="o">-&gt;</span><span class="n">datagrams_received</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="k">while</span><span class="p">(</span><span class="o">!</span><span class="n">data_queue</span><span class="p">.</span><span class="n">empty</span><span class="p">()){</span>
</span></span><span class="line"><span class="cl">      <span class="n">InternetDatagram</span> <span class="o">&amp;</span><span class="n">data</span> <span class="o">=</span> <span class="n">data_queue</span><span class="p">.</span><span class="n">front</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">      <span class="k">if</span><span class="p">(</span><span class="n">data</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">ttl</span> <span class="o">==</span> <span class="mi">0</span> <span class="o">||</span> <span class="n">data</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">ttl</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="n">data_queue</span><span class="p">.</span><span class="n">pop</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">        <span class="k">continue</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="n">data</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">ttl</span> <span class="o">-=</span> <span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="n">data</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">compute_checksum</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">      <span class="kt">uint32_t</span> <span class="n">ip</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">dst</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="n">optional</span><span class="o">&lt;</span><span class="n">routing_item</span><span class="o">&gt;</span> <span class="n">best_match</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">      <span class="k">for</span><span class="p">(</span><span class="kt">uint32_t</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">routing_table_</span><span class="p">.</span><span class="n">size</span><span class="p">();</span> <span class="n">i</span><span class="o">++</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">        <span class="k">auto</span><span class="o">&amp;</span> <span class="n">item</span> <span class="o">=</span> <span class="n">routing_table_</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">item</span><span class="p">.</span><span class="n">route_prefix_</span> <span class="o">==</span> <span class="p">(</span><span class="n">ip</span> <span class="o">&amp;</span> <span class="n">item</span><span class="p">.</span><span class="n">mask_</span><span class="p">)){</span>
</span></span><span class="line"><span class="cl">          <span class="k">if</span><span class="p">(</span><span class="o">!</span><span class="n">best_match</span><span class="p">.</span><span class="n">has_value</span><span class="p">()</span> <span class="o">||</span> <span class="n">best_match</span><span class="o">-&gt;</span><span class="n">mask_</span> <span class="o">&lt;</span> <span class="n">item</span><span class="p">.</span><span class="n">mask_</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">            <span class="n">best_match</span> <span class="o">=</span> <span class="n">item</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">          <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="k">if</span><span class="p">(</span><span class="n">best_match</span><span class="p">.</span><span class="n">has_value</span><span class="p">()){</span>
</span></span><span class="line"><span class="cl">        <span class="k">auto</span> <span class="o">&amp;</span><span class="n">next_interface</span> <span class="o">=</span> <span class="n">_interfaces</span><span class="p">.</span><span class="n">at</span><span class="p">(</span><span class="n">best_match</span><span class="o">-&gt;</span><span class="n">interface_num_</span><span class="p">);</span>
</span></span><span class="line"><span class="cl">        <span class="k">if</span><span class="p">(</span><span class="n">best_match</span><span class="o">-&gt;</span><span class="n">next_hop_</span><span class="p">.</span><span class="n">has_value</span><span class="p">()){</span>
</span></span><span class="line"><span class="cl">          <span class="n">next_interface</span><span class="o">-&gt;</span><span class="n">send_datagram</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">best_match</span><span class="o">-&gt;</span><span class="n">next_hop_</span><span class="p">.</span><span class="n">value</span><span class="p">());</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">          <span class="n">next_interface</span><span class="o">-&gt;</span><span class="n">send_datagram</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">Address</span><span class="o">::</span><span class="n">from_ipv4_numeric</span><span class="p">(</span><span class="n">data</span><span class="p">.</span><span class="n">header</span><span class="p">.</span><span class="n">dst</span><span class="p">));</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="n">data_queue</span><span class="p">.</span><span class="n">pop</span><span class="p">();</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>The test output is:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span><span class="lnt">8
</span><span class="lnt">9
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-text" data-lang="text"><span class="line"><span class="cl">Test project /home/zhouxin/projects/CS144/build
</span></span><span class="line"><span class="cl">    Start  1: compile with bug-checkers
</span></span><span class="line"><span class="cl">1/3 Test  #1: compile with bug-checkers ........   Passed    9.56 sec
</span></span><span class="line"><span class="cl">    Start 35: net_interface
</span></span><span class="line"><span class="cl">2/3 Test #35: net_interface ....................   Passed    0.02 sec
</span></span><span class="line"><span class="cl">    Start 36: router
</span></span><span class="line"><span class="cl">3/3 Test #36: router ...........................   Passed    0.01 sec
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">100% tests passed, 0 tests failed out of 3
</span></span></code></pre></td></tr></table>
</div>
</div>]]></content:encoded>
    </item>
    <item>
      <title>Binary Search: Boundary Conditions and Where the Answer Ends Up</title>
      <link>https://www.zhouxin.space/notes/boundary-of-binary-search/</link>
      <pubDate>Tue, 26 Mar 2024 12:18:00 +0800</pubDate>
      <guid>https://www.zhouxin.space/notes/boundary-of-binary-search/</guid>
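      <description>&lt;h1 id=&#34;引入&#34;&gt;Introduction&lt;/h1&gt;
&lt;p&gt;Binary search is a common search algorithm for sorted arrays, with a time complexity of $O(\log n)$. The skeleton of the algorithm is easy to understand, but in practice I was never quite sure about some of the details, such as the boundary condition of the while loop, early exit, and the index of the answer. After consulting the STL source code and other references, I settled on a general scheme that resolves this whole family of details.&lt;/p&gt;</description>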
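      <content:encoded><![CDATA[<h1 id="引入">Introduction</h1>
<p>Binary search is a common search algorithm for sorted arrays, with a time complexity of $O(\log n)$. The skeleton of the algorithm is easy to understand, but in practice I was never quite sure about some of the details, such as the boundary condition of the while loop, early exit, and the index of the answer. After consulting the STL source code and other references, I settled on a general scheme that resolves this whole family of details.</p>
<h1 id="标准二分查找">Standard Binary Search</h1>
<p>We start with standard binary search: given a strictly increasing array <code>nums</code> and a target value <code>target</code>, return the index of <code>target</code> in <code>nums</code>, or <code>-1</code> if it is not present. One workable C implementation is:</p>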
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C" data-lang="C"><span class="line"><span class="cl"><span class="kt">int</span> <span class="nf">binary_search</span><span class="p">(</span><span class="kt">int</span> <span class="o">*</span><span class="n">nums</span><span class="p">,</span> <span class="kt">int</span> <span class="n">numsSize</span><span class="p">,</span> <span class="kt">int</span> <span class="n">target</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="kt">int</span> <span class="n">left</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="kt">int</span> <span class="n">right</span> <span class="o">=</span> <span class="n">numsSize</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="k">while</span><span class="p">(</span><span class="n">right</span> <span class="o">&gt;=</span> <span class="n">left</span><span class="p">){</span> <span class="c1">// loop condition
</span></span></span><span class="line"><span class="cl">		<span class="kt">int</span> <span class="n">mid</span> <span class="o">=</span> <span class="p">(</span><span class="n">left</span><span class="o">+</span><span class="n">right</span><span class="p">)</span><span class="o">/</span><span class="mi">2</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="k">if</span><span class="p">(</span><span class="n">nums</span><span class="p">[</span><span class="n">mid</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">target</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">			<span class="n">right</span> <span class="o">=</span> <span class="n">mid</span><span class="o">-</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="k">else</span> <span class="k">if</span><span class="p">(</span><span class="n">nums</span><span class="p">[</span><span class="n">mid</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">target</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">			<span class="n">left</span> <span class="o">=</span> <span class="n">mid</span><span class="o">+</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="k">else</span> 
</span></span><span class="line"><span class="cl">			<span class="k">return</span> <span class="n">mid</span><span class="p">;</span> <span class="c1">// early-exit condition
</span></span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="o">-</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
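</div><p>During the search, the closed interval [left, right] denotes the positions where <code>target</code> may still be. The loop can therefore end in only two ways: <code>target</code> has been found, or the interval has become empty. These correspond to the early-exit condition and the loop condition marked by the two comments in the code. The loop condition depends on whether the ends of the interval are open or closed. With a half-open interval [left, right), for example, the interval is empty when <code>right == left</code>, so the loop condition becomes <code>left &lt; right</code>.<br>
As analyzed above, if <code>target</code> is found, the function always returns the index <code>mid</code> through the early exit. So if the loop ends through the loop condition, the target is not in the array and the function returns -1.</p>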
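<p>For comparison, here is a minimal sketch of the half-open variant mentioned above. It is written in C++ with <code>std::vector</code> rather than a pointer and a size, and <code>binary_search_half_open</code> is a hypothetical name chosen for this illustration. The only changes from the closed-interval version are the initial value of <code>right</code>, the loop condition, and the update of <code>right</code>.</p>

```cpp
#include <vector>

// Standard binary search over the half-open interval [left, right).
// Returns the index of target in a strictly increasing array, or -1 if absent.
int binary_search_half_open(const std::vector<int>& nums, int target) {
  int left = 0;
  int right = static_cast<int>(nums.size());  // one past the last candidate
  while (left < right) {                      // interval is empty when left == right
    int mid = left + (right - left) / 2;      // avoids overflow of left + right
    if (nums[mid] > target) {
      right = mid;                            // mid is excluded because right is open
    } else if (nums[mid] < target) {
      left = mid + 1;                         // mid is excluded because left is closed
    } else {
      return mid;                             // early exit
    }
  }
  return -1;
}
```

<p>Because <code>right</code> is never a candidate position, an empty array needs no special case: the loop body simply never runs.</p>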
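<h1 id="二分查找左边界">Binary Search for the Left Boundary</h1>
<p>The left-boundary problem is defined as follows: given a non-strictly increasing array <code>nums</code> and a target value <code>target</code>, return the smallest index at which <code>target</code> can be inserted into <code>nums</code>. For example, with <code>nums = {1,2,2,3}</code> and <code>target = 2</code>, the left boundary is 1.<br>
As in standard binary search, the closed interval [left, right] denotes the range that contains the index we are looking for. To locate the boundary, we keep pulling <code>right</code> toward it. How? When <code>nums[mid] &lt; target</code>, <code>left</code> moves past <code>mid</code> exactly as in standard binary search. The new case is <code>nums[mid] == target</code>: here <code>mid</code> may itself be the left boundary, so <code>right</code> moves to <code>mid</code> rather than past it. One solution is given below:</p>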
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C" data-lang="C"><span class="line"><span class="cl"><span class="kt">int</span> <span class="nf">lower_bound</span><span class="p">(</span><span class="kt">int</span> <span class="o">*</span><span class="n">nums</span><span class="p">,</span> <span class="kt">int</span> <span class="n">numsSize</span><span class="p">,</span> <span class="kt">int</span> <span class="n">target</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="kt">int</span> <span class="n">left</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="kt">int</span> <span class="n">right</span> <span class="o">=</span> <span class="n">numsSize</span><span class="p">;</span> <span class="c1">// the insertion index may be numsSize
</span></span></span><span class="line"><span class="cl">	<span class="k">while</span><span class="p">(</span><span class="n">right</span> <span class="o">&gt;</span> <span class="n">left</span><span class="p">)</span> <span class="c1">// loop condition
</span></span></span><span class="line"><span class="cl">	<span class="p">{</span>
</span></span><span class="line"><span class="cl">		<span class="kt">int</span> <span class="n">mid</span> <span class="o">=</span> <span class="p">(</span><span class="n">left</span><span class="o">+</span><span class="n">right</span><span class="p">)</span><span class="o">/</span><span class="mi">2</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="k">if</span><span class="p">(</span><span class="n">nums</span><span class="p">[</span><span class="n">mid</span><span class="p">]</span> <span class="o">==</span> <span class="n">target</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">			<span class="n">right</span> <span class="o">=</span> <span class="n">mid</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="k">else</span> <span class="k">if</span><span class="p">(</span><span class="n">nums</span><span class="p">[</span><span class="n">mid</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">target</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">			<span class="n">right</span> <span class="o">=</span> <span class="n">mid</span><span class="p">;</span> <span class="c1">// mid itself may be the insertion index
</span></span></span><span class="line"><span class="cl">		<span class="k">else</span> 
</span></span><span class="line"><span class="cl">			<span class="n">left</span> <span class="o">=</span> <span class="n">mid</span><span class="o">+</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">left</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="cm">/* If -1 must be returned when target is absent, replace the return above with:
</span></span></span><span class="line"><span class="cl"><span class="cm">	** if(left == numsSize || nums[left]!=target)
</span></span></span><span class="line"><span class="cl"><span class="cm">	**     return -1;
</span></span></span><span class="line"><span class="cl"><span class="cm">	** else
</span></span></span><span class="line"><span class="cl"><span class="cm">	**     return left;
</span></span></span><span class="line"><span class="cl"><span class="cm">	*/</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
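</div><p>When <code>nums[mid] == target</code>, the scheme above sets <code>right</code> to <code>mid</code>. Compared with standard binary search, the loop condition no longer includes equality, and there is no early exit. That is because the <code>lower_bound</code> function we wrote returns the index at which <code>target</code> would be inserted into <code>nums</code>, so once the interval has shrunk to a single position, that position is the return value and the loop can stop.<br>
Some problems require returning -1 when <code>target</code> is not in <code>nums</code>. In that case, check after the loop whether <code>nums[left]</code> equals the target. Note that the insertion position of <code>target</code> may lie just past the last element of <code>nums</code>, so the index must be checked against the array bounds first.</p>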
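<p>The same behaviour can be cross-checked against the C++ standard library. The sketch below assumes C++ rather than the C used above and introduces two hypothetical helper names, <code>lower_bound_index</code> and <code>first_occurrence_or_minus_one</code>. It relies on <code>std::lower_bound</code>, which returns an iterator to the first element not less than the target, and then applies the bounds check described above before reading the element.</p>

```cpp
#include <algorithm>
#include <vector>

// Smallest index at which target can be inserted while keeping nums sorted.
int lower_bound_index(const std::vector<int>& nums, int target) {
  auto it = std::lower_bound(nums.begin(), nums.end(), target);
  return static_cast<int>(it - nums.begin());
}

// Index of the first occurrence of target, or -1 if target is absent.
int first_occurrence_or_minus_one(const std::vector<int>& nums, int target) {
  int idx = lower_bound_index(nums, target);
  // idx may equal nums.size(), so check the bound before reading nums[idx].
  if (idx == static_cast<int>(nums.size()) || nums[idx] != target) {
    return -1;
  }
  return idx;
}
```

<p>For <code>nums = {1,2,2,3}</code>, the insertion index for 2 is 1 and the insertion index for 4 is 4, which is one past the last element. That second case is exactly where the bounds check is needed.</p>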
<h1 id="二分查找右边界">Binary search for the right boundary</h1>
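<p>The right-boundary problem is defined as follows: given a non-strictly increasing array <code>nums</code> and a target value <code>target</code>, return the largest index that <code>target</code> can occupy in <code>nums</code>, that is, the index of the last element equal to <code>target</code>. For example, with <code>nums = {1,2,2,3}</code> and <code>target = 2</code>, the right boundary is 2.<br>
Following the idea in <a href=".md#%E4%BA%8C%E5%88%86%E6%9F%A5%E6%89%BE%E5%B7%A6%E8%BE%B9%E7%95%8C">Binary search for the left boundary</a> and repeatedly tightening the left bound, one can write a right-boundary search that loops forever:</p>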
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C" data-lang="C"><span class="line"><span class="cl"><span class="kt">int</span> <span class="nf">upper_bound</span><span class="p">(</span><span class="kt">int</span> <span class="o">*</span><span class="n">nums</span><span class="p">,</span> <span class="kt">int</span> <span class="n">numsSize</span><span class="p">,</span> <span class="kt">int</span> <span class="n">target</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="kt">int</span> <span class="n">left</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="kt">int</span> <span class="n">right</span> <span class="o">=</span> <span class="n">numsSize</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="k">while</span><span class="p">(</span><span class="n">right</span> <span class="o">&gt;</span> <span class="n">left</span><span class="p">)</span> 
</span></span><span class="line"><span class="cl">	<span class="p">{</span>
</span></span><span class="line"><span class="cl">		<span class="kt">int</span> <span class="n">mid</span> <span class="o">=</span> <span class="p">(</span><span class="n">left</span><span class="o">+</span><span class="n">right</span><span class="p">)</span><span class="o">/</span><span class="mi">2</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="k">if</span><span class="p">(</span><span class="n">nums</span><span class="p">[</span><span class="n">mid</span><span class="p">]</span> <span class="o">==</span> <span class="n">target</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">			<span class="n">left</span> <span class="o">=</span> <span class="n">mid</span><span class="p">;</span> <span class="c1">// tighten the left bound (makes no progress when mid == left)
</span></span></span><span class="line"><span class="cl">		<span class="k">else</span> <span class="k">if</span><span class="p">(</span><span class="n">nums</span><span class="p">[</span><span class="n">mid</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">target</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">			<span class="n">right</span> <span class="o">=</span> <span class="n">mid</span><span class="o">-</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="k">else</span> 
</span></span><span class="line"><span class="cl">			<span class="n">left</span> <span class="o">=</span> <span class="n">mid</span><span class="o">+</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">right</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
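</div><p>Why does it loop forever? In some cases <code>left</code> equals <code>mid</code> while <code>nums[mid] == target</code>, so the assignment <code>left = mid</code> changes nothing and the loop never makes progress. For example, with <code>nums = {1,2,2,3}</code> and <code>target = 2</code>, the search reaches <code>left = 2</code>, <code>right = 3</code>, where <code>mid</code> is 2 on every iteration. To fix this, set <code>left = mid+1</code> so that every update to <code>left</code> actually moves it.<br>
That change introduces a new problem: when the loop exits, <code>mid</code> may point at the first element greater than <code>target</code> or at <code>target</code> itself, and <code>right</code> is greater than or equal to <code>mid</code>, so what <code>right</code> points at is not determined. We can therefore simply make <code>right</code> point at the first element greater than <code>target</code> (a position that may be one past the end of the array) and return <code>right-1</code> at the end. On top of the change above, this means updating <code>right</code> to <code>mid</code> in the <code>nums[mid]&gt;target</code> case.<br>
Based on these ideas, the right-boundary binary search is as follows:</p>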
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-C" data-lang="C"><span class="line"><span class="cl"><span class="kt">int</span> <span class="nf">upper_bound</span><span class="p">(</span><span class="kt">int</span> <span class="o">*</span><span class="n">nums</span><span class="p">,</span> <span class="kt">int</span> <span class="n">numsSize</span><span class="p">,</span> <span class="kt">int</span> <span class="n">target</span><span class="p">){</span>
</span></span><span class="line"><span class="cl">	<span class="kt">int</span> <span class="n">left</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="kt">int</span> <span class="n">right</span> <span class="o">=</span> <span class="n">numsSize</span><span class="p">;</span> <span class="c1">// right may end one past the last element
</span></span></span><span class="line"><span class="cl">	<span class="k">while</span><span class="p">(</span><span class="n">right</span> <span class="o">&gt;</span> <span class="n">left</span><span class="p">)</span> <span class="c1">// loop condition
</span></span></span><span class="line"><span class="cl">	<span class="p">{</span>
</span></span><span class="line"><span class="cl">		<span class="kt">int</span> <span class="n">mid</span> <span class="o">=</span> <span class="p">(</span><span class="n">left</span><span class="o">+</span><span class="n">right</span><span class="p">)</span><span class="o">/</span><span class="mi">2</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="k">if</span><span class="p">(</span><span class="n">nums</span><span class="p">[</span><span class="n">mid</span><span class="p">]</span> <span class="o">==</span> <span class="n">target</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">			<span class="n">left</span> <span class="o">=</span> <span class="n">mid</span><span class="o">+</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">		<span class="k">else</span> <span class="k">if</span><span class="p">(</span><span class="n">nums</span><span class="p">[</span><span class="n">mid</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">target</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">			<span class="n">right</span> <span class="o">=</span> <span class="n">mid</span><span class="p">;</span> <span class="c1">// mid may be the first element greater than target
</span></span></span><span class="line"><span class="cl">		<span class="k">else</span> 
</span></span><span class="line"><span class="cl">			<span class="n">left</span> <span class="o">=</span> <span class="n">mid</span><span class="o">+</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl">	<span class="p">}</span>
</span></span><span class="line"><span class="cl">	<span class="k">return</span> <span class="n">right</span><span class="o">-</span><span class="mi">1</span><span class="p">;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></td></tr></table>
</div>
</div>]]></content:encoded>
    </item>
  </channel>
</rss>
