Lowering converts high-level operators (Relay) into low-level operators (TOPI). It starts from the following code (src/relay/backend/graph_runtime_codegen.cc):
LoweredOutput Codegen(relay::Function func) {
  auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
  storage_device_map_ = (*pf)(func);
  // First we convert all the parameters into input nodes.
  for (auto param : func->params) {
    auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
    var_map_[param.get()] = AddNode(node_ptr, param);
  }
  heads_ = VisitExpr(func->body);
  std::ostringstream os;
  dmlc::JSONWriter writer(&os);
  GetJSON(&writer);
  LoweredOutput ret;
  ret.graph_json = os.str();
  ret.params = params_;
  for (auto& kv : lowered_funcs_) {
    if (ret.lowered_funcs.count(kv.first) == 0) {
      ret.lowered_funcs.Set(kv.first, IRModule());
    }
    auto& mod = ret.lowered_funcs[kv.first];
    mod->Update(kv.second);
    ret.lowered_funcs.Set(kv.first, mod);
  }
  ret.external_mods = compile_engine_->LowerExternalFunctions();
  return ret;
}
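As a side note, the packed functions that Codegen fetches through GetPackedFunc come from TVM's global function registry, so they are visible from Python too. A minimal sketch (assuming this TVM build registers the function under this name):

import tvm

# GetPackedFunc in the C++ code looks up the same global registry that
# tvm.get_global_func queries from Python.
pf = tvm.get_global_func("relay.backend.GraphPlanMemory")
print(pf)  # a tvm.runtime.PackedFunc handle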
After memory planning is done, VisitExpr traverses the graph and lowers each Relay operator. Let us look at how a CallNode is handled. The main code is:
auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;
// Handle external function
if (func->GetAttr<String>(attr::kCompiler).defined()) {
  target = tvm::target::ext_dev();
  CCacheKey key = (*pf0)(func, target);
  CachedFunc ext_func = (*pf1)(compile_engine_, key);

This branch is taken when an external compiler is present and is used to do the lowering. CCacheKey packs the function and the target together, presumably to make the later compiler invocation convenient. The lower function calls LowerInternal in the CompileEngineImpl class (src/relay/backend/compile_engine.cc), which implements both the external-compiler path and the internal one. When an external compiler is involved, the function, target, etc. are packed into a CCacheValue and returned, to be processed by the external compiler in one batch later on. Without an external compiler, TVM converts the Relay operator into a TOPI operator:

CachedFunc lowered_func = (*pf1)(compile_engine_, key);
if (!lowered_funcs_.count(target->str())) {
  lowered_funcs_[target->str()] = IRModule();
}
lowered_funcs_[target->str()]->Update(lowered_func->funcs);
return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name);
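For context, here is a hedged Python sketch of how a Relay function ends up carrying the kCompiler ("Compiler") attribute that selects the external branch above. "ccompiler" is the demo C-source codegen from TVM's BYOC tests, used purely for illustration:

from tvm import relay

x = relay.var("x", shape=(2, 2))
f = relay.Function([x], relay.abs(x))
f = f.with_attr("Compiler", "ccompiler")   # attr::kCompiler in the C++ code
print(f.attrs["Compiler"])                 # -> "ccompiler"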
The internal path likewise goes through LowerInternal, which first builds a schedule:
CachedFunc CreateSchedule(const Function& source_func, const Target& target) {
  return ScheduleGetter(target).Create(source_func);
}
In the Create function, the inputs are first converted into te (tensor expression) operators:
for (Var param : prim_func->params) {
  Array<tvm::te::Tensor> inputs;
  if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
    tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
    cache_node->inputs.push_back(tensor);
    inputs.push_back(tensor);
  } else {
    // flatten tuple of tensor type.
    const auto* tuple_type = param->type_as<TupleTypeNode>();
    for (Type field : tuple_type->fields) {
      const auto* ttype = field.as<TensorTypeNode>();
      // TODO(@icemelon): Allow recursive tuple
      CHECK(ttype != nullptr);
      tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
      cache_node->inputs.push_back(tensor);
      inputs.push_back(tensor);
    }
  }
  memo_[param] = inputs;
}
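To make the te-level picture concrete, here is a small Python sketch of what ScheduleGetter effectively builds for a single add operator: placeholders standing in for the function parameters (exactly as in the C++ loop above), a TOPI compute playing the role of the lowered op, and a schedule that can be lowered to TIR. It assumes a TVM version where topi is importable as tvm.topi:

import tvm
from tvm import te, topi

A = te.placeholder((4, 4), dtype="float32", name="A")
B = te.placeholder((4, 4), dtype="float32", name="B")
C = topi.add(A, B)                 # TOPI compute for the Relay add op
s = te.create_schedule(C.op)       # default schedule over the compute DAG
print(tvm.lower(s, [A, B, C], simple_mode=True))  # the resulting TIR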
Then the remaining nodes are visited to perform the lowering.
Let us again look at the CallNode visitor:
Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
  static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
  static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
  CHECK(flower_call) << "relay.backend.lower_call is not registered.";

  Array<te::Tensor> inputs;
  int count_tuple = 0;
  for (Expr arg : call_node->args) {
    if (arg->checked_type().as<TupleTypeNode>()) {
      ++count_tuple;
    }
    for (te::Tensor tensor : VisitExpr(arg)) {
      inputs.push_back(tensor);
    }
  }
  if (count_tuple) {
    CHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
  }
  CHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
  Op op = Downcast<Op>(call_node->op);

  Array<te::Tensor> outputs;
  OpImplementation impl;
  // Skip fcompute for device copy operators as it is not registered.
  if (op == device_copy_op_) {
    const auto* copy_input = inputs[0].operator->();
    outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
  } else {
    LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
    outputs = lowered_out->outputs;
Here the lowering calls the lower_call function registered from Python, which lives in python/tvm/relay/backend/compile_engine.py. Its most important step is select_implementation.
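Condensed to its core, lower_call looks roughly like the sketch below. This is a paraphrase, not the verbatim function; select_implementation and LoweredOutput are assumed to be importable from compile_engine.py in this TVM version, and tuple-type handling is omitted:

from tvm.relay.backend.compile_engine import select_implementation, LoweredOutput

def lower_call_sketch(call, inputs, target):
    # Pick the best (compute, schedule) implementation for this op on
    # this target, run its compute, and wrap the results.
    op = call.op
    ret_type = call.checked_type
    best_impl, outputs = select_implementation(
        op, call.attrs, inputs, ret_type, target)
    return LoweredOutput(outputs, best_impl)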
select_implementation chooses a TOPI-level implementation for a Relay operator. The same Relay operator can be implemented differently on different targets, and which implementation is used depends on the target's attributes. select_implementation first obtains all registered implementations through get_valid_implementations:
fstrategy = op.get_attr("FTVMStrategy")
assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
with target:
    strategy = fstrategy(attrs, inputs, out_type, target)
analyzer = tvm.arith.Analyzer()
ret = []
for spec in strategy.specializations:
    if spec.condition:
        # check if all the clauses in the specialized condition are true
        flag = True
        for clause in spec.condition.clauses:
            clause = analyzer.canonical_simplify(clause)
            if isinstance(clause, tvm.tir.IntImm) and clause.value:
                continue
            flag = False
            break
        if flag:
            for impl in spec.implementations:
                ret.append(impl)
    else:
        for impl in spec.implementations:
            ret.append(impl)
return ret
fstrategy is the function registered under the op attribute "FTVMStrategy". For example, conv2d registers the following generic strategy:
def conv2d_strategy(attrs, inputs, out_type, target):
    """conv2d generic strategy"""
    logger.warning("conv2d is not optimized for this platform.")
    strategy = _op.OpStrategy()
    data, kernel = inputs
    dilation = get_const_tuple(attrs.dilation)
    groups = attrs.groups
    layout = attrs.data_layout
    kernel_layout = attrs.kernel_layout
    (dilation_h, dilation_w) = dilation
    if dilation_h < 1 or dilation_w < 1:
        raise ValueError("dilation should be positive value")
    if groups == 1:
        if layout == "NCHW":
            assert kernel_layout == "OIHW"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.nn.conv2d_nchw),
                wrap_topi_schedule(topi.generic.schedule_conv2d_nchw),
                name="conv2d_nchw.generic",
            )
            # ... (further layout and grouped-conv branches omitted)
As you can see, conv2d registers multiple strategies even for a single target. add_implementation registers the concrete compute and schedule functions into the strategy. An OpStrategy is the data structure that holds the implementation choices for one Relay operator: it contains a number of OpSpecializations, each of which holds a list of OpImplementations. An OpImplementation pairs a concrete schedule with a concrete compute: the schedule describes how the operator's computation is laid out, while the compute maps to the TOPI operator.
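Assuming a strategy object obtained from fstrategy(attrs, inputs, out_type, target) as in the listing above, the hierarchy can be walked like this (a hypothetical inspection sketch):

for spec in strategy.specializations:    # OpSpecialization, one per condition
    for impl in spec.implementations:    # OpImplementation: compute + schedule
        print(impl.name, impl.plevel)    # e.g. "conv2d_nchw.generic", 10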
Once all valid implementations are collected, one is selected in one of two ways: AutoTVM can automatically search for the best implementation, or, when AutoTVM is not used, the implementation with the highest plevel is chosen. After the implementation is selected, a LoweredOutput instance is built (see src/relay/backend/compile_engine.cc). In effect, lower_call re-expresses every Relay operator in a uniform lower-level abstraction that bundles the operator with its compute and its schedule information, which makes the subsequent schedule optimization straightforward.
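The fallback selection can be sketched as follows (simplified; op, attrs, inputs, out_type, and target are assumed in scope, and get_valid_implementations is the function shown earlier):

# Without AutoTVM tuning records, pick the highest-priority implementation
# and invoke its compute function to build the output tensors.
all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
best_impl = max(all_impls, key=lambda impl: impl.plevel)
outputs = best_impl.compute(attrs, inputs, out_type)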
These LoweredOutputs are then packed into a CachedFuncNode, which serves as the input to the subsequent schedule optimization.
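From the user's perspective, this whole lowering pipeline is driven by relay.build. A minimal end-to-end example (the exact return type varies across TVM versions, from a (graph, lib, params) triple to a graph-executor factory):

import tvm
from tvm import relay

x = relay.var("x", shape=(2, 2), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")  # triggers Codegen + lowering above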