0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫(xiě)文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

TVM學(xué)習(xí)之從relay到TOPI

ytrwv ? 來(lái)源:ytrwv ? 作者:ytrwv ? 2022-08-02 10:16 ? 次閱讀

Lower操作完成從高級(jí)算子(relay)到低級(jí)算子(TOPI)的轉(zhuǎn)化。Lower開(kāi)始于以下代碼(src/relay/backend/graph_runtime_codegen.cc):

 LoweredOutput Codegen(relay::Function func) {
    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
    storage_device_map_ = (*pf)(func);
    // First we convert all the parameters into input nodes.
    for (auto param : func->params) {
      auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
      var_map_[param.get()] = AddNode(node_ptr, param);
    }
    heads_ = VisitExpr(func->body);
    std::ostringstream os;
    dmlc::JSONWriter writer(&os);
    GetJSON(&writer);
    LoweredOutput ret;
    ret.graph_json = os.str();
    ret.params = params_;


    for (auto& kv : lowered_funcs_) {
      if (ret.lowered_funcs.count(kv.first) == 0) {
        ret.lowered_funcs.Set(kv.first, IRModule());
      }
      auto& mod = ret.lowered_funcs[kv.first];
      mod->Update(kv.second);
      ret.lowered_funcs.Set(kv.first, mod);
    }
    ret.external_mods = compile_engine_->LowerExternalFunctions();
    return ret;
  }

在完成內(nèi)存申請(qǐng)優(yōu)化之后,VisitExpr對(duì)圖進(jìn)行遍歷并lower每個(gè)relay算子。我們來(lái)看CallNode節(jié)點(diǎn)的處理。主要代碼如下:

auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
    auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
    Target target;
    // Handle external function
    if (func->GetAttr(attr::kCompiler).defined()) {
      target = tvm::target::ext_dev();
      CCacheKey key = (*pf0)(func, target);
      CachedFunc ext_func = (*pf1)(compile_engine_, key);
這一步是當(dāng)存在外部compiler的時(shí)候,使用外部compiler進(jìn)行l(wèi)ower。CCacheKey將function和target打包到一起,可能是方便后邊compiler的調(diào)用。而lower函數(shù)會(huì)調(diào)用src/relay/backend/compile_engine.cc中CompileEngineImpl類中的LowerInternal函數(shù),在這個(gè)函數(shù)中實(shí)現(xiàn)了外部編譯器lower和內(nèi)部lower的代碼,如果是有外部compiler參與,其將function,target等打包成CCacheValue返回,等待后邊外部編譯器統(tǒng)一處理。
如果沒(méi)有外部編譯器,那么TVM將對(duì)relay算子轉(zhuǎn)換到TOPI庫(kù)中算子。
CachedFunc lowered_func = (*pf1)(compile_engine_, key);
    if (!lowered_funcs_.count(target->str())) {
      lowered_funcs_[target->str()] = IRModule();
    }
    lowered_funcs_[target->str()]->Update(lowered_func->funcs);
return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name);

同樣會(huì)調(diào)用LowerInternal函數(shù),首先建立schedule:

CachedFunc CreateSchedule(const Function& source_func, const Target& target) {
    return ScheduleGetter(target).Create(source_func);
  }

在Create函數(shù)中,首先將inputs都轉(zhuǎn)換成te的算子表示:

for (Var param : prim_func-> params) {
      Array  inputs;
      if (const auto* ttype = param->checked_type().as< TensorTypeNode>()) {
        tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype-> shape), ttype->dtype);
        cache_node-> inputs.push_back(tensor);
        inputs.push_back(tensor);
      } else {
        // flatten tuple of tensor type.
        const auto* tuple_type = param-> type_as ();
        for (Type field : tuple_type-> fields) {
          const auto* ttype = field.as< TensorTypeNode> ();
          // TODO(@icemelon): Allow recursive tuple
          CHECK(ttype != nullptr);
          tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype-> shape), ttype-> dtype);
          cache_node-> inputs.push_back(tensor);
          inputs.push_back(tensor);
        }
      }
      memo_[param] = inputs;
}

然后遍歷其它node來(lái)實(shí)現(xiàn)lower操作。

我們還是來(lái)看CallNode的訪問(wèn)。

Array VisitExpr_(const CallNode* call_node) final {
    static auto fpattern = Op::GetAttrMap("TOpPattern");
    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
    CHECK(flower_call) << "relay.backend.lower_call is not registered.";


    Array inputs;
    int count_tuple = 0;
    for (Expr arg : call_node->args) {
      if (arg->checked_type().as()) {
        ++count_tuple;
      }
      for (te::Tensor tensor : VisitExpr(arg)) {
        inputs.push_back(tensor);
      }
    }
    if (count_tuple) {
      CHECK_EQ(call_node-> args.size(), 1U) << "Only allow function with a single tuple input";
    }


    CHECK(call_node->op.as>OpNode> ()) >> "Primitive function only allows call into primitive ops";
    Op op = Downcast>Op>(call_node-> op);


    Array>te::Tensor> outputs;
    OpImplementation impl;
    // Skip fcompute for device copy operators as it is not registered.
    if (op == device_copy_op_) {
      const auto* copy_input = inputs[0].operator->();
      outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
    } else {
      LoweredOutput lowered_out = (*flower_call)(GetRef>Call>(call_node), inputs, target_);
      outputs = lowered_out->outputs;

這里lower操作會(huì)去調(diào)用python中注冊(cè)的lower_call函數(shù),這個(gè)函數(shù)位于python/tvm/relay/backend/compile_engine.py中。在這個(gè)函數(shù)中最主要的是select_implementation。

Select_implementation是去選擇relay算子的一個(gè)TOPI層級(jí)的實(shí)現(xiàn)方式。同一個(gè)relay算子在不同target上有不同實(shí)現(xiàn)方式,具體采用哪種方式要依據(jù)target的屬性。在select_implementation中首先通過(guò)gat_valid_implementation獲得所有已經(jīng)注冊(cè)的實(shí)現(xiàn)方式。

fstrategy = op.get_attr("FTVMStrategy")
    assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
    with target:
        strategy = fstrategy(attrs, inputs, out_type, target)
    analyzer = tvm.arith.Analyzer()
    ret = []
    for spec in strategy.specializations:
        if spec.condition:
            # check if all the clauses in the specialized condition are true
            flag = True
            for clause in spec.condition.clauses:
                clause = analyzer.canonical_simplify(clause)
                if isinstance(clause, tvm.tir.IntImm) and clause.value:
                    continue
                flag = False
                break
            if flag:
                for impl in spec.implementations:
                    ret.append(impl)
        else:
            for impl in spec.implementations:
                ret.append(impl)
return ret

fstrategy指向的是op attr的"FTVMStrategy"對(duì)應(yīng)的函數(shù)。比如con2d注冊(cè)的策略有:

def conv2d_strategy(attrs, inputs, out_type, target):
    """conv2d generic strategy"""
    logger.warning("conv2d is not optimized for this platform.")
    strategy = _op.OpStrategy()
    data, kernel = inputs
    dilation = get_const_tuple(attrs.dilation)
    groups = attrs.groups
    layout = attrs.data_layout
    kernel_layout = attrs.kernel_layout
    (dilation_h, dilation_w) = dilation
    if dilation_h > 1 or dilation_w > 1:
        raise ValueError("dilation should be positive value")


    if groups == 1:
        if layout == "NCHW":
            assert kernel_layout == "OIHW"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.nn.conv2d_nchw),
                wrap_topi_schedule(topi.generic.schedule_conv2d_nchw),
                nam)

可見(jiàn)一個(gè)conv2d即使同一個(gè)target也會(huì)注冊(cè)不同的策略。Add_implementation將會(huì)把compute,schedule的具體函數(shù)注冊(cè)到strategy中。Strategy是一個(gè)包含了一個(gè)relay算子implementation方式的數(shù)據(jù)結(jié)構(gòu)。它包含了很多OpSpecialization,每個(gè)OpSpecialization中包含一些列OpImplementation,OpImplementation中就對(duì)應(yīng)著schedule和compute的具體方式,schedule是一個(gè)算子計(jì)算的排布,compute是對(duì)應(yīng)了TOPI庫(kù)算子。

獲得了所有有效implementation之后,會(huì)依據(jù)兩種方式選擇,一種是通過(guò)auto TVM來(lái)自動(dòng)化搜索最優(yōu)的實(shí)現(xiàn)方式,另外一種在不適用auto TVM工具情況下,會(huì)選擇plevel最大的implementation。選擇好了implementation之后,就調(diào)用src/relay/backend/compile_engine.cc中的LoweredOutput類建立一個(gè)實(shí)例??梢钥闯?,lower_call實(shí)現(xiàn)了將relay算子統(tǒng)一用更底層的的抽象進(jìn)行了表示。這種表示中包含了relay算子,以及這個(gè)算子的計(jì)算方式以及schedule信息。這樣就方便后邊對(duì)其進(jìn)行schedule優(yōu)化了。

然后將這些LoweredOutput進(jìn)行打包成CachedFuncNode。CachedFuncNode會(huì)作為后邊schedule優(yōu)化的入?yún)ⅰ?br />
審核編輯:湯梓紅

聲明:本文內(nèi)容及配圖由入駐作者撰寫(xiě)或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問(wèn)題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • TVM
    TVM
    +關(guān)注

    關(guān)注

    0

    文章

    19

    瀏覽量

    3648
  • relay
    +關(guān)注

    關(guān)注

    0

    文章

    1

    瀏覽量

    4441
收藏 人收藏

    評(píng)論

    相關(guān)推薦

    TVM主要的編譯過(guò)程解析

    低階TIR表示?! odegen:內(nèi)存分配和硬件可執(zhí)行程序生成?! D導(dǎo)入  通過(guò)一個(gè)tensorflow的reception網(wǎng)絡(luò)來(lái)熟悉編譯過(guò)程,其它深度學(xué)習(xí)框架也具有類似過(guò)程。TVM官網(wǎng)可以
    發(fā)表于 01-07 16:59

    TVM整體結(jié)構(gòu),TVM代碼的基本構(gòu)成

    ,編譯器包括了前端和后端,前端主要實(shí)現(xiàn)從tensorflow等深度學(xué)習(xí)框架描述的網(wǎng)絡(luò)結(jié)構(gòu)形式新表示的轉(zhuǎn)化,后端完成編譯器中間表示硬件可執(zhí)行程序的轉(zhuǎn)化。前端對(duì)硬件應(yīng)該是透明的,它的主要挑戰(zhàn)在于如何設(shè)計(jì)出
    發(fā)表于 01-07 17:21

    TVM中將計(jì)算算符有哪幾種

    TVM中將計(jì)算算符分成四種
    發(fā)表于 01-26 06:34

    TVM的編譯流程

    TVM主要的編譯過(guò)程
    發(fā)表于 02-23 07:43

    什么是frame relay,frame relay概念

    什么是frame relay,frame relay概念 物理拓?fù)?交換機(jī),存取鏈路,Trunks,CSU/DSU
    發(fā)表于 06-11 09:21 ?3138次閱讀
    什么是frame <b class='flag-5'>relay</b>,frame <b class='flag-5'>relay</b>概念

    《HTML 5 入門精通》-中文學(xué)習(xí)教程

    《HTML 5 入門精通》-中文學(xué)習(xí)教程.pdf 《HTML 5 入門精通》-中文學(xué)習(xí)
    發(fā)表于 11-02 17:45 ?0次下載

    TVM用于移動(dòng)端常見(jiàn)的ARM GPU,提高移動(dòng)設(shè)備對(duì)深度學(xué)習(xí)的支持能力

    的壓力。 TVM是一個(gè)端端的IR堆棧,它可以解決學(xué)習(xí)過(guò)程中的資源分配問(wèn)題,從而輕松實(shí)現(xiàn)硬件優(yōu)化。在這篇文章中,我們將展示如何用TVM/NNVM為ARM Mali GPU生成高效
    的頭像 發(fā)表于 01-18 13:38 ?1.1w次閱讀

    什么是波場(chǎng)虛擬機(jī)TVM

    TVM與現(xiàn)有的開(kāi)發(fā)生態(tài)系統(tǒng)無(wú)縫連接,并支持 DPoS。 TVM最初與 EVM 環(huán)境兼容,因此開(kāi)發(fā)人員可以使用Solidity和其他語(yǔ)言在 Remix 環(huán)境中開(kāi)發(fā),調(diào)試和編譯智能合約,而不是學(xué)習(xí)
    發(fā)表于 05-15 09:46 ?3209次閱讀
    什么是波場(chǎng)虛擬機(jī)<b class='flag-5'>TVM</b>

    TVM的編譯流程是什么

    TVM主要的編譯過(guò)程如下圖:Import:將tensorflow,onnx,pytorch等構(gòu)建的深度學(xué)習(xí)模型導(dǎo)入,轉(zhuǎn)化成TVM的中間層表示IR。Lower:將高層IR表示轉(zhuǎn)化成低階TIR表示。Codegen:內(nèi)存分配和硬件可執(zhí)
    的頭像 發(fā)表于 02-08 14:51 ?1654次閱讀
    <b class='flag-5'>TVM</b>的編譯流程是什么

    TVM學(xué)習(xí)(三)編譯流程

    TVM主要的編譯過(guò)程如下圖:Import:將tensorflow,onnx,pytorch等構(gòu)建的深度學(xué)習(xí)模型導(dǎo)入,轉(zhuǎn)化成TVM的中間層表示IR。Lower:將高層IR表示轉(zhuǎn)化成低階TIR表示。Codegen:內(nèi)存分配和硬件可執(zhí)
    發(fā)表于 01-26 09:23 ?13次下載
    <b class='flag-5'>TVM</b><b class='flag-5'>學(xué)習(xí)</b>(三)編譯流程

    TVM學(xué)習(xí)(四)codegen

    接著上一章繼續(xù)深入代碼,在BuildRelay中會(huì)調(diào)用Codegen函數(shù)。這個(gè)函數(shù)實(shí)現(xiàn)在src/relay/backend/graph_runtime_codegen.cc中。Codegen實(shí)現(xiàn)了內(nèi)存的分配,IR節(jié)點(diǎn)到TIR節(jié)點(diǎn)的轉(zhuǎn)換,tir圖節(jié)點(diǎn)的一個(gè)調(diào)度優(yōu)化。
    發(fā)表于 01-27 06:43 ?8次下載
    <b class='flag-5'>TVM</b><b class='flag-5'>學(xué)習(xí)</b>(四)codegen

    TVM學(xué)習(xí)(二):算符融合

    算符融合將多個(gè)計(jì)算單元揉進(jìn)一個(gè)計(jì)算核中進(jìn)行,減少了中間數(shù)據(jù)的搬移,節(jié)省了計(jì)算時(shí)間。TVM中將計(jì)算算符分成四種: 1 injective。一一映射函數(shù),比如加法,點(diǎn)乘等。 2 reduction。輸入
    發(fā)表于 02-19 06:50 ?10次下載
    <b class='flag-5'>TVM</b><b class='flag-5'>學(xué)習(xí)</b>(二):算符融合

    使用TVM在android中進(jìn)行Mobilenet SSD部署

    所謂TVM,按照正式說(shuō)法:就是一種將深度學(xué)習(xí)工作負(fù)載部署硬件的端端IR(中間表示)堆棧。換一種說(shuō)法,可以表述為一種把深度學(xué)習(xí)模型...
    發(fā)表于 02-07 12:07 ?0次下載
    使用<b class='flag-5'>TVM</b>在android中進(jìn)行Mobilenet SSD部署

    TVM學(xué)習(xí)(八)pass總結(jié)

    Pass是TVM中基于relay IR進(jìn)行的優(yōu)化,目的是去除冗余算子,進(jìn)行硬件友好的算子轉(zhuǎn)換,最終能夠提高硬件運(yùn)行效率。由tensorflow等深度學(xué)習(xí)框架生成的圖機(jī)構(gòu)中,含有很多可以優(yōu)化的算子
    的頭像 發(fā)表于 08-02 09:43 ?1879次閱讀
    <b class='flag-5'>TVM</b><b class='flag-5'>學(xué)習(xí)</b>(八)pass總結(jié)

    PyTorch教程7.1全連接層卷積

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程7.1全連接層卷積.pdf》資料免費(fèi)下載
    發(fā)表于 06-05 11:50 ?0次下載
    PyTorch教程7.1<b class='flag-5'>之</b><b class='flag-5'>從</b>全連接層<b class='flag-5'>到</b>卷積