0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認識你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

TVM學(xué)習(xí)(八)pass總結(jié)

djelje ? 來源:djelje ? 作者:djelje ? 2022-08-02 09:43 ? 次閱讀

什么是pass?

Pass是TVM中基于relay IR進行的優(yōu)化,目的是去除冗余算子,進行硬件友好的算子轉(zhuǎn)換,最終能夠提高硬件運行效率。由tensorflow深度學(xué)習(xí)框架生成的圖機構(gòu)中,含有很多可以優(yōu)化的算子,比如expand_dim,len等,其實在編譯階段完全可以優(yōu)化掉,從而能夠減少硬件的計算,以及避免出現(xiàn)硬件不支持的算子。

TVM中在include/tvm/ir/transform.h中對pass進行了抽象,主要包括PassContext,PassInfo,Pass,以及Sequential。其中PassContext包含了pass執(zhí)行依賴的一些參數(shù),比如優(yōu)化level,analysis report等。PassInfo是一個用于記錄pass信息的類,包括pass的opt-level,名稱等。和PassContext的區(qū)別是PassContext是pass執(zhí)行所需要獲取的條件。Pass就是執(zhí)行pass的主體,主要就是pass的函數(shù)。比如RemoveUnusedFunctions就是執(zhí)行pass的一個主體函數(shù),目的就是去除冗余算子。Sequential是一個container,裝載所有pass。

一些pass

01. RemoveUnusedFunctions

位于src/relay/backend/vm/removed_unused_funcs.cc中,顧名思義就是去除relay IR中的冗余函數(shù)。通過從main函數(shù)開始遍歷,如果一個函數(shù)體沒有引用其它函數(shù),而同時又沒有被其它函數(shù)調(diào)用,即從relay圖上看是一個孤立算子,那么就從IRModule中刪除。

 void VisitExpr_(const FunctionNode* func_node) final {
    auto func = GetRef(func_node);
    if (visiting_.find(func) == visiting_.end()) {
      visiting_.insert(func);
      for (auto param : func_node->params) {
        ExprVisitor::VisitExpr(param);
      }
      ExprVisitor::VisitExpr(func_node-> body);
    }
  }

02. ToBasicBlockNormalForm

函數(shù)在文件src/relay/trnaforms/to_basic_block_normal_from.cc中。通過遍歷IRModule中的每個function,將每個function轉(zhuǎn)換為基本塊形式。轉(zhuǎn)換函數(shù)是ToBasicBlockNormalFormAux。這個函數(shù)包括兩個步驟:一是找到基本塊(basic block)的邊界,TVM中對邊界進行了一步抽象,判斷每個expr是否屬于同一個scope,如果scope相同那么就可以將這些表達式放在一個基本塊中;第二步根據(jù)每個表達式所屬的scope將表達式歸屬到一個基本塊中。

Expr ToBasicBlockNormalFormAux(const Expr& e) {
  // calculate all the dependency between nodes.
  support::Arena arena;
  DependencyGraph dg = DependencyGraph::Create(&arena, e);
  /* The scope of the whole expr is global.
   * The scope of any subexpr, is the lowest common ancestor of all incoming edge.
   * We also record the set of expressions whose scope is lifted.
   */
  std::pair scopes = CalcScope(dg);
  return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);
}

DependencyGraph是一個表達式相互依賴的圖結(jié)構(gòu),通過遍歷圖中每個節(jié)點,找到每個節(jié)點的scope。CalcScope在文件src/relay/transforms/to_a_normal_from.cc中。這個函數(shù)中重點關(guān)注以下代碼:

…
        s = LCA(s, expr_scope.at(iit->value));
…
    if (n->new_scope) {
      auto child_scope = std::make_shared(s);
      expr_scope.insert({n, child_scope});
    } else {
      expr_scope.insert({n, s});
}

LCA是獲得當前節(jié)點的父節(jié)點的scope的LCA(least common ancestor),然后將這個scope作為這個節(jié)點的scope。了解基本塊原理的都知道,尋找基本塊首先要找到首指令的位置,然后一個首指令到下一個首指令之間的指令就屬于一個基本塊。而首指令就是那些具有條件和無條件跳轉(zhuǎn)的指令。在TVM中通過new_scope來標記這些節(jié)點,比如Ifnode,F(xiàn)unctionNode,LetNode在建立dependency圖的時候,這些節(jié)點就被標記為new_scope。這樣就建立了dependency節(jié)點到scope節(jié)點的對應(yīng)map。同時scope節(jié)點也被建立起樹結(jié)構(gòu)。

接下來就是建立Fill類,這個類中包含了dependency圖以及scope的信息,通過其函數(shù)ToBasicBlockNormalForm實現(xiàn)基本塊轉(zhuǎn)換。它的基本邏輯通過VisitExpr函數(shù)遍歷dependency節(jié)點,將具有相同scope的節(jié)點壓入到同一個let_list中。Let_list文檔中是這樣解釋的:

/*!
 * \file let_list.h
 * \brief LetList record let binding and insert let expression implicitly.
 *  using it, one can treat AST as value instead of expression,
 *  and pass them around freely without fear of AST explosion (or effect duplication).
 *  for example, if one write 'b = a + a; c = b + b; d = c + c', the AST will contain 8 'a'.
 *  if one instead write 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);',
 *  the AST will contain 2 'a', as b and c are now variables.

Let_list使得抽象語法樹簡潔化,不會因為變量的復(fù)制導(dǎo)致樹的爆炸。具有相同的scope的expr被約束到相同的let_list中,用一個var來表達,這樣就將表達式轉(zhuǎn)化為var的形式。一個var也就對應(yīng)了一個基本塊。

03. Legalize

Legalize是實現(xiàn)等價函數(shù)的轉(zhuǎn)換。主要代碼在src/relay/transforms/legalize.cc中。主函數(shù)是:

Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
  auto rewriter = Legalizer(legalize_map_attr_name);
  return PostOrderRewrite(expr, &rewriter);
}

在legalize.cc文件中定義了一個繼承了ExprRewriter的類,在這個類中實現(xiàn)了對function的替換。我們追蹤一下調(diào)用的過程。PostOrderRewrite在文件src/relay/ir/expr_functor.cc中。首先建立一個PostOrderRewriter類,然后訪問每個節(jié)點。在訪問節(jié)點過程中調(diào)用了ExpandDataFlow函數(shù),看一下這個函數(shù)的描述:

*
 * ExpandDataflow manually manages a stack and performs DFS to determine the processing
 * order of nodes in an input graph.
 *
 * If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
 * need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
 * and continues iteratively to process the top of the stack. When it finds a node that doesn't
 * match the dataflow types, or a node who's inputs have all been processed, it visits the current
 * leaf via fvisit_leaf.
 *
 * This function should be used internally to other classes to implement mixed-mode traversals. The
 * expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
 * hits a non-dataflow node.
 *
 * fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
 */

主要目的是有區(qū)別的去處理graph中的節(jié)點,如果fcheck_visited已經(jīng)確定該節(jié)點處理過或者不需要處理,就跳過,通過fvisit_leaf繼續(xù)訪問下一個節(jié)點。而在VisitLeaf函數(shù)中就調(diào)用了legalizer類中的rewrite_函數(shù)實現(xiàn)了legalize功能。在Rewrite_中,通過映射表legalize_map_attr_name實現(xiàn)函數(shù)的等價轉(zhuǎn)換。

04. SimplifyInference

實現(xiàn)對batch normalization, layer normalization, instance normalization, group normalization, L2 normalization算子的分解,這樣做的目的是可以在之后的優(yōu)化中,將這些算子融合到其它算子上,減少計算量。代碼在src/relay/transforms/simplify_inference.cc中。文件中定義了一個InferenceSimplifier類來處理這個問題??匆幌逻@幾個normalization的公式:

1 BN:

pYYBAGGYIDKAMYXkAALAFPdMTWI678.png

2 LN:獲得均值和方差是基于同一層不同神經(jīng)元的數(shù)據(jù)。歸一化公式相同。

3 GN: 將每個輸入樣本沿著通道進行分組,在每個組內(nèi)進行歸一化。

4 IN:對每個通道的數(shù)據(jù)進行歸一化。

來看一下bacth normalization的處理代碼:

Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean,
                            Expr moving_var, Type tdata) {
  auto ttype = tdata.as();
  CHECK(ttype);
  const auto param = attrs.as< BatchNormAttrs>();
  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast(param->epsilon));
  Expr var_add_eps = Add(moving_var, epsilon);
  Expr sqrt_var = Sqrt(var_add_eps);
  Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var);


  if (param->scale) {
    scale = Multiply(scale, gamma);
  }
  Expr neg_mean = Negative(moving_mean);
  Expr shift = Multiply(neg_mean, scale);
  if (param->center) {
    shift = Add(shift, beta);
  }


  auto ndim = ttype->shape.size();
  int axis = (param->axis <  0) ? param->axis + ndim : param->axis;
  scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
  shift = ExpandBiasToMatchAxis(shift, ndim, {axis});


  Expr out = Multiply(data, scale);
  out = Add(out, shift);
  return out;
}

可以看到就是將batch norm算子分解成最基本的加減乘除算子。

05. EliminateCommonSubexpr

顧名思義,這個pass的目的是消除公共子表達式。公共子表達式類似這種:

a=b+c

d=b+c

兩個表達式具有相同的op,同時又有相同的args,而且args的順序也一樣。那么就可以用一個表達式替換。

這個pass的實現(xiàn)在文件src/relay/transforms/eliminate_common_subexpr.cc中。TVM定義了類CommonSubexprEliminator來處理。重載函數(shù)Rewrite_實現(xiàn)了對expr的遍歷和重寫操作。

 Expr Rewrite_(const CallNode* call, const Expr& post) final {
…
    if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef< Op>(op), false)) {
      return new_expr;
    }
    if (fskip_ != nullptr && fskip_(new_expr)) {
      return new_expr;
    }


    auto it = expr_map_.find(new_call->op);
    if (it != expr_map_.end()) {
      for (const Expr& candidate_expr : it->second) {
        if (const CallNode* candidate = candidate_expr.as< CallNode>()) {
          bool is_equivalent = true;
          if (!attrs_equal(new_call->attrs, candidate->attrs)) {
            continue;
          }
          for (size_t i = 0; i <  new_call->args.size(); i++) {
            if (!new_call->args[i].same_as(candidate->args[i]) &&
                !IsEqualScalar(new_call->args[i], candidate->args[i])) {
              is_equivalent = false;
              break;
            }
          }
          if (!is_equivalent) continue;
          return GetRef(candidate);
        }
      }
    }
    expr_map_[new_call->op].push_back(new_expr);
    return new_expr;
  }

使用一個expr_map_映射記錄已經(jīng)遍歷過的具有相同op的expr,之后每次遇到相同的op都會對已經(jīng)記錄的expr進行匹配,匹配包括attrs以及args,如果二者都一樣的話,證明就是公共子表達式。

沒有看過的pass

以上是實現(xiàn)相對簡單的pass,TVM中還實現(xiàn)了其它很多pass,就沒有一一去讀代碼了。以后看需要再去讀吧?,F(xiàn)在做一些羅列:

1 SimplifyExpr

簡化一些表達式,具體如何進行簡化需要讀代碼了。

2 CombineParallelConv2D

合并多分支并行的conv2d運算,理解是對多個batch的conv2d進行合并。

3 CombineParalleleDense

將多個batch的dense操作合并為一個batch_matmul操作。

4 CombineParallelBatchMatmul

對多個并行的batch_mamul再進行合并。

這幾個combine操作可能是針對GPU器件的一個多數(shù)據(jù)并行性的優(yōu)化。

5 FoldConstant

典型的一個常量合并優(yōu)化。

6 FoldScaleAxis

包含了ForwardFoldScaleAxis和backwardFoldScaleAxis,主要是將scale參數(shù)合并到conv/dense操作的權(quán)重參數(shù)中。

7 CanonicalizeCast

官方解釋是: Canonicalize cast expressions to make operator fusion more efficient。理解是對一些cast操作規(guī)范化,就是讓復(fù)雜的cast操作可以更簡潔。

8 CanonicalizeOps

規(guī)范化一些算子,比如bias_add能夠被表示為expand_dims和broadcast_add操作。

審核編輯 黃昊宇

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報投訴
  • 優(yōu)化
    +關(guān)注

    關(guān)注

    0

    文章

    219

    瀏覽量

    23823
  • TVM
    TVM
    +關(guān)注

    關(guān)注

    0

    文章

    19

    瀏覽量

    3642
收藏 人收藏

    評論

    相關(guān)推薦

    TVM主要的編譯過程解析

    `  TVM主要的編譯過程如下圖:    Import:將tensorflow,onnx,pytorch等構(gòu)建的深度學(xué)習(xí)模型導(dǎo)入,轉(zhuǎn)化成TVM的中間層表示IR?! ower:將高層IR表示轉(zhuǎn)化成
    發(fā)表于 01-07 16:59

    TVM整體結(jié)構(gòu),TVM代碼的基本構(gòu)成

    圖:    Frontend:這個就是將來自不同深度學(xué)習(xí)框架中的神經(jīng)網(wǎng)絡(luò)轉(zhuǎn)化成TVM自己的IR表示。神經(jīng)網(wǎng)絡(luò)模型的輸入是protoBuf文件,比如在tensorflow中就是pbtxt文件,這個文件中
    發(fā)表于 01-07 17:21

    TVM中將計算算符有哪幾種

    TVM中將計算算符分成四種
    發(fā)表于 01-26 06:34

    TVM的編譯流程

    TVM主要的編譯過程
    發(fā)表于 02-23 07:43

    SOPC Builder/Nios 學(xué)習(xí)經(jīng)驗總結(jié)

    SOPC Builder/Nios 學(xué)習(xí)經(jīng)驗總結(jié)
    發(fā)表于 07-22 15:32 ?0次下載
    SOPC Builder/Nios <b class='flag-5'>學(xué)習(xí)</b>經(jīng)驗<b class='flag-5'>總結(jié)</b>

    FPGA學(xué)習(xí)總結(jié)[經(jīng)典推薦]

    單片機(Microcontrollers)學(xué)習(xí),F(xiàn)PGA學(xué)習(xí)總結(jié)[經(jīng)典推薦],感興趣的小伙伴可以瞧一瞧。
    發(fā)表于 11-03 15:15 ?155次下載

    ARM寄存器學(xué)習(xí)總結(jié)

    ARM寄存器學(xué)習(xí)總結(jié)
    發(fā)表于 01-04 15:10 ?0次下載

    TVM用于移動端常見的ARM GPU,提高移動設(shè)備對深度學(xué)習(xí)的支持能力

    的壓力。 TVM是一個端到端的IR堆棧,它可以解決學(xué)習(xí)過程中的資源分配問題,從而輕松實現(xiàn)硬件優(yōu)化。在這篇文章中,我們將展示如何用TVM/NNVM為ARM Mali GPU生成高效kernel
    的頭像 發(fā)表于 01-18 13:38 ?1.1w次閱讀

    什么是波場虛擬機TVM

    TVM與現(xiàn)有的開發(fā)生態(tài)系統(tǒng)無縫連接,并支持 DPoS。 TVM最初與 EVM 環(huán)境兼容,因此開發(fā)人員可以使用Solidity和其他語言在 Remix 環(huán)境中開發(fā),調(diào)試和編譯智能合約,而不是學(xué)習(xí)
    發(fā)表于 05-15 09:46 ?3164次閱讀
    什么是波場虛擬機<b class='flag-5'>TVM</b>

    Linux的基礎(chǔ)學(xué)習(xí)筆記資料總結(jié)

    本文檔的主要內(nèi)容詳細介紹的是Linux的基礎(chǔ)學(xué)習(xí)筆記資料總結(jié)包括了:一、 常用命令,二、 磁盤管理,三、 用戶管理,四、 文件權(quán)限,五、 目錄結(jié)構(gòu),六、 軟件安裝,七、 時間管理,、 啟動引導(dǎo),九
    發(fā)表于 11-13 08:00 ?4次下載

    TVM的編譯流程是什么

    TVM主要的編譯過程如下圖:Import:將tensorflow,onnx,pytorch等構(gòu)建的深度學(xué)習(xí)模型導(dǎo)入,轉(zhuǎn)化成TVM的中間層表示IR。Lower:將高層IR表示轉(zhuǎn)化成低階TIR表示。Codegen:內(nèi)存分配和硬件可執(zhí)
    的頭像 發(fā)表于 02-08 14:51 ?1582次閱讀
    <b class='flag-5'>TVM</b>的編譯流程是什么

    TVM學(xué)習(xí)(三)編譯流程

    TVM主要的編譯過程如下圖:Import:將tensorflow,onnx,pytorch等構(gòu)建的深度學(xué)習(xí)模型導(dǎo)入,轉(zhuǎn)化成TVM的中間層表示IR。Lower:將高層IR表示轉(zhuǎn)化成低階TIR表示。Codegen:內(nèi)存分配和硬件可執(zhí)
    發(fā)表于 01-26 09:23 ?13次下載
    <b class='flag-5'>TVM</b><b class='flag-5'>學(xué)習(xí)</b>(三)編譯流程

    TVM學(xué)習(xí)(二):算符融合

    算符融合將多個計算單元揉進一個計算核中進行,減少了中間數(shù)據(jù)的搬移,節(jié)省了計算時間。TVM中將計算算符分成四種: 1 injective。一一映射函數(shù),比如加法,點乘等。 2 reduction。輸入
    發(fā)表于 02-19 06:50 ?10次下載
    <b class='flag-5'>TVM</b><b class='flag-5'>學(xué)習(xí)</b>(二):算符融合

    FDTD學(xué)習(xí)總結(jié).pdf

    FDTD學(xué)習(xí)總結(jié).pdf
    發(fā)表于 01-17 11:28 ?0次下載

    使用TVM在android中進行Mobilenet SSD部署

    所謂TVM,按照正式說法:就是一種將深度學(xué)習(xí)工作負載部署到硬件的端到端IR(中間表示)堆棧。換一種說法,可以表述為一種把深度學(xué)習(xí)模型...
    發(fā)表于 02-07 12:07 ?0次下載
    使用<b class='flag-5'>TVM</b>在android中進行Mobilenet SSD部署