0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

如何針對(duì)涂鴉識(shí)別問題構(gòu)建基于RNN的識(shí)別器

Tensorflowers ? 來源:未知 ? 作者:胡薇 ? 2018-11-27 09:13 ? 次閱讀

Quick, Draw!是一款游戲;在這個(gè)游戲中,玩家要接受一項(xiàng)挑戰(zhàn):繪制幾個(gè)圖形,看看計(jì)算機(jī)能否識(shí)別玩家繪制的是什么。

Quick, Draw!的識(shí)別操作 由一個(gè)分類器執(zhí)行,它接收用戶輸入(用 (x, y) 中的點(diǎn)筆畫序列表示),然后識(shí)別用戶嘗試涂鴉的圖形所屬的類別。

在本教程中,我們將展示如何針對(duì)此問題構(gòu)建基于 RNN 的識(shí)別器。該模型將結(jié)合使用卷積層、LSTM 層和 softmax 輸出層對(duì)涂鴉進(jìn)行分類:

上圖顯示了我們將在本教程中構(gòu)建的模型的結(jié)構(gòu)。輸入為一個(gè)涂鴉,用 (x, y, n) 中的點(diǎn)筆畫序列表示,其中 n 表示點(diǎn)是否為新筆畫的第一個(gè)點(diǎn)。

然后,模型將應(yīng)用一系列一維卷積,接下來,會(huì)應(yīng)用 LSTM 層,并將所有 LSTM 步的輸出之和饋送到 softmax 層,以便根據(jù)我們已知的涂鴉類別來決定涂鴉的分類。

本教程使用的數(shù)據(jù)來自真實(shí)的Quick, Draw!游戲,這些數(shù)據(jù)是公開提供的。此數(shù)據(jù)集包含 5000 萬幅涂鴉,涵蓋 345 個(gè)類別。

運(yùn)行教程代碼

要嘗試本教程的代碼,請(qǐng)執(zhí)行以下操作:

安裝 TensorFlow(如果尚未安裝的話)

下載教程代碼

下載數(shù)據(jù)(TFRecord格式),然后解壓縮。如需詳細(xì)了解如何獲取原始 Quick, Draw!數(shù)據(jù)以及如何將數(shù)據(jù)轉(zhuǎn)換為TFRecord文件,請(qǐng)參閱下文

使用以下命令執(zhí)行教程代碼,以訓(xùn)練本教程中所述的基于 RNN 的模型。請(qǐng)務(wù)必調(diào)整路徑,使其指向第 3 步中下載的解壓縮數(shù)據(jù)

python train_model.py \ --training_data=rnn_tutorial_data/training.tfrecord-?????-of-????? \ --eval_data=rnn_tutorial_data/eval.tfrecord-?????-of-????? \ --classes_file=rnn_tutorial_data/training.tfrecord.classes

教程詳情

下載數(shù)據(jù)

我們將本教程中要使用的數(shù)據(jù)放在了包含TFExamples的TFRecord文件中。您可以從以下位置下載這些數(shù)據(jù):http://download.tensorflow.org/data/quickdraw_tutorial_dataset_v1.tar.gz(大約 1GB)。

或者,您也可以從 Google Cloud 下載ndjson格式的原始數(shù)據(jù),并將這些數(shù)據(jù)轉(zhuǎn)換為包含TFExamples的TFRecord文件,如下一部分中所述。

可選:下載完整的 QuickDraw 數(shù)據(jù)

完整的Quick, Draw!數(shù)據(jù)集可在 Google Cloud Storage 上找到,此數(shù)據(jù)集是按類別劃分的ndjson文件。您可以在 Cloud Console 中瀏覽文件列表。

要下載數(shù)據(jù),我們建議使用gsutil下載整個(gè)數(shù)據(jù)集。請(qǐng)注意,原始 .ndjson 文件需要下載約 22GB 的數(shù)據(jù)。

然后,使用以下命令檢查 gsutil 安裝是否成功以及您是否可以訪問數(shù)據(jù)存儲(chǔ)分區(qū):

gsutil ls -r "gs://quickdraw_dataset/full/simplified/*"

系統(tǒng)會(huì)輸出一長串文件,如下所示:

gs://quickdraw_dataset/full/simplified/The Eiffel Tower.ndjsongs://quickdraw_dataset/full/simplified/The Great Wall of China.ndjsongs://quickdraw_dataset/full/simplified/The Mona Lisa.ndjsongs://quickdraw_dataset/full/simplified/aircraft carrier.ndjson...

之后,創(chuàng)建一個(gè)文件夾并在其中下載數(shù)據(jù)集。

mkdir rnn_tutorial_datacd rnn_tutorial_datagsutil -m cp "gs://quickdraw_dataset/full/simplified/*" .

下載過程需要花費(fèi)一段時(shí)間,且下載的數(shù)據(jù)量略超 23GB。

可選:轉(zhuǎn)換數(shù)據(jù)

要將ndjson文件轉(zhuǎn)換為TFRecord文件(包含tf.train.Example樣本),請(qǐng)運(yùn)行以下命令。

python create_dataset.py --ndjson_path rnn_tutorial_data \ --output_path rnn_tutorial_data

此命令會(huì)將數(shù)據(jù)存儲(chǔ)在TFRecord文件的 10 個(gè)分片中,每個(gè)類別有 10000 項(xiàng)用于訓(xùn)練數(shù)據(jù),有 1000 項(xiàng)用于評(píng)估數(shù)據(jù)。

下文詳細(xì)說明了該轉(zhuǎn)換過程。

原始 QuickDraw 數(shù)據(jù)的格式為ndjson文件,其中每行包含一個(gè)如下所示的 JSON 對(duì)象:

{"word":"cat","countrycode":"VE","timestamp":"2017-03-02 23:25:10.07453 UTC","recognized":true,"key_id":"5201136883597312","drawing":[ [ [130,113,99,109,76,64,55,48,48,51,59,86,133,154,170,203,214,217,215,208,186,176,162,157,132], [72,40,27,79,82,88,100,120,134,152,165,184,189,186,179,152,131,114,100,89,76,0,31,65,70] ],[ [76,28,7], [136,128,128] ],[ [76,23,0], [160,164,175] ],[ [87,52,37], [175,191,204] ],[ [174,220,246,251], [134,132,136,139] ],[ [175,255], [147,168] ],[ [171,208,215], [164,198,210] ],[ [130,110,108,111,130,139,139,119], [129,134,137,144,148,144,136,130] ],[ [107,106], [96,113] ]]}

在構(gòu)建我們的分類器時(shí),我們只關(guān)注 “word” 和 “drawing” 字段。在解析 ndjson 文件時(shí),我們使用一個(gè)函數(shù)逐行處理它們,該函數(shù)可將drawing字段中的筆畫轉(zhuǎn)換為大小為[number of points, 3](包含連續(xù)點(diǎn)的差異)的張量。此函數(shù)還會(huì)以字符串形式返回類別名稱。

def parse_line(ndjson_line): """Parse an ndjson line and return ink (as np array) and classname.""" sample = json.loads(ndjson_line) class_name = sample["word"] inkarray = sample["drawing"] stroke_lengths = [len(stroke[0]) for stroke in inkarray] total_points = sum(stroke_lengths) np_ink = np.zeros((total_points, 3), dtype=np.float32) current_t = 0 for stroke in inkarray: for i in [0, 1]: np_ink[current_t:(current_t + len(stroke[0])), i] = stroke[i] current_t += len(stroke[0]) np_ink[current_t - 1, 2] = 1 # stroke_end # Preprocessing. # 1. Size normalization. lower = np.min(np_ink[:, 0:2], axis=0) upper = np.max(np_ink[:, 0:2], axis=0) scale = upper - lower scale[scale == 0] = 1 np_ink[:, 0:2] = (np_ink[:, 0:2] - lower) / scale # 2. Compute deltas. np_ink = np_ink[1:, 0:2] - np_ink[0:-1, 0:2] return np_ink, class_name

由于我們希望數(shù)據(jù)在寫入時(shí)進(jìn)行隨機(jī)處理,因此我們以隨機(jī)順序從每個(gè)類別文件中讀取數(shù)據(jù)并寫入隨機(jī)分片。

對(duì)于訓(xùn)練數(shù)據(jù),我們讀取每個(gè)類別的前 10000 項(xiàng);對(duì)于評(píng)估數(shù)據(jù),我們讀取每個(gè)類別接下來的 1000 項(xiàng)。

然后,將這些數(shù)據(jù)變形為[num_training_samples, max_length, 3]形狀的張量。接下來,我們用屏幕坐標(biāo)確定原始涂鴉的邊界框并標(biāo)準(zhǔn)化涂鴉的尺寸,使涂鴉具有單位高度。

最后,我們計(jì)算連續(xù)點(diǎn)之間的差異,并將它們存儲(chǔ)為VarLenFeature(位于tensorflow.Example中的ink鍵下)。另外,我們將class_index存儲(chǔ)為單一條目FixedLengthFeature,將ink的shape存儲(chǔ)為長度為 2 的FixedLengthFeature。

定義模型

要定義模型,我們需要?jiǎng)?chuàng)建一個(gè)新的Estimator。如需詳細(xì)了解 Estimator,建議您閱讀此教程。

要構(gòu)建模型,我們需要執(zhí)行以下操作:

將輸入調(diào)整回原始形狀,其中小批次通過填充達(dá)到其內(nèi)容的最大長度。除了 ink 數(shù)據(jù)之外,我們還擁有每個(gè)樣本的長度和目標(biāo)類別。這可通過函數(shù)_get_input_tensors實(shí)現(xiàn)

將輸入傳遞給_add_conv_layers中的一系列卷積層

將卷積的輸出傳遞到_add_rnn_layers中的一系列雙向 LSTM 層。最后,將每個(gè)時(shí)間步的輸出相加,針對(duì)輸入生成一個(gè)固定長度的緊湊嵌入

在_add_fc_layers中使用 softmax 層對(duì)此嵌入進(jìn)行分類

代碼如下所示:

inks, lengths, targets = _get_input_tensors(features, targets)convolved = _add_conv_layers(inks)final_state = _add_rnn_layers(convolved, lengths)logits =_add_fc_layers(final_state)

_get_input_tensors

要獲得輸入特征,我們先從特征字典獲得形狀,然后創(chuàng)建大小為[batch_size](包含輸入序列的長度)的一維張量。ink 作為稀疏張量存儲(chǔ)在特征字典中,我們將其轉(zhuǎn)換為密集張量,然后變形為[batch_size, ?, 3]。最后,如果傳入目標(biāo),我們需要確保它們存儲(chǔ)為大小為[batch_size]的一維張量。

代碼如下所示:

shapes = features["shape"]lengths = tf.squeeze( tf.slice(shapes, begin=[0, 0], size=[params["batch_size"], 1]))inks = tf.reshape( tf.sparse_tensor_to_dense(features["ink"]), [params["batch_size"], -1, 3])if targets is not None: targets = tf.squeeze(targets)

_add_conv_layers

您可以通過params字典中的參數(shù)num_conv和conv_len配置所需的卷積層數(shù)量和過濾器長度。

輸入是一個(gè)每個(gè)點(diǎn)維數(shù)都是 3 的序列。我們將使用一維卷積,將 3 個(gè)輸入特征視為通道。這意味著輸入為[batch_size, length, 3]張量,而輸出為[batch_size, length, number_of_filters]張量。

convolved = inksfor i in range(len(params.num_conv)): convolved_input = convolved if params.batch_norm: convolved_input = tf.layers.batch_normalization( convolved_input, training=(mode == tf.estimator.ModeKeys.TRAIN)) # Add dropout layer if enabled and not first convolution layer. if i > 0 and params.dropout: convolved_input = tf.layers.dropout( convolved_input, rate=params.dropout, training=(mode == tf.estimator.ModeKeys.TRAIN)) convolved = tf.layers.conv1d( convolved_input, filters=params.num_conv[i], kernel_size=params.conv_len[i], activation=None, strides=1, padding="same", name="conv1d_%d" % i)return convolved, lengths

_add_rnn_layers

我們將卷積的輸出傳遞給雙向 LSTM 層,對(duì)此我們使用 contrib 的輔助函數(shù)。

outputs, _, _ = contrib_rnn.stack_bidirectional_dynamic_rnn( cells_fw=[cell(params.num_nodes) for _ in range(params.num_layers)], cells_bw=[cell(params.num_nodes) for _ in range(params.num_layers)], inputs=convolved, sequence_length=lengths, dtype=tf.float32, scope="rnn_classification")

請(qǐng)參閱代碼以了解詳情以及如何使用CUDA加速實(shí)現(xiàn)。

要?jiǎng)?chuàng)建一個(gè)固定長度的緊湊嵌入,我們需要將 LSTM 的輸出相加。我們首先將其中的序列不含數(shù)據(jù)的批次區(qū)域設(shè)為 0。

mask = tf.tile( tf.expand_dims(tf.sequence_mask(lengths, tf.shape(outputs)[1]), 2), [1, 1, tf.shape(outputs)[2]])zero_outside = tf.where(mask, outputs, tf.zeros_like(outputs))outputs = tf.reduce_sum(zero_outside, axis=1)

_add_fc_layers

將輸入的嵌入傳遞至全連接層,之后將此層用作 softmax 層。

tf.layers.dense(final_state, params.num_classes)

損失、預(yù)測(cè)和優(yōu)化器

最后,我們需要添加一個(gè)損失函數(shù)、一個(gè)訓(xùn)練操作和預(yù)測(cè)來創(chuàng)建ModelFn:

cross_entropy = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=targets, logits=logits))# Add the optimizer.train_op = tf.contrib.layers.optimize_loss( loss=cross_entropy, global_step=tf.train.get_global_step(), learning_rate=params.learning_rate, optimizer="Adam", # some gradient clipping stabilizes training in the beginning. clip_gradients=params.gradient_clipping_norm, summaries=["learning_rate", "loss", "gradients", "gradient_norm"])predictions = tf.argmax(logits, axis=1)return model_fn_lib.ModelFnOps( mode=mode, predictions={"logits": logits, "predictions": predictions}, loss=cross_entropy, train_op=train_op, eval_metric_ops={"accuracy": tf.metrics.accuracy(targets, predictions)})

訓(xùn)練和評(píng)估模型

要訓(xùn)練和評(píng)估模型,我們可以借助EstimatorAPI 的功能,并使用ExperimentAPI 輕松運(yùn)行訓(xùn)練和評(píng)估操作:

estimator = tf.estimator.Estimator( model_fn=model_fn, model_dir=output_dir, config=config, params=model_params) # Train the model. tf.contrib.learn.Experiment( estimator=estimator, train_input_fn=get_input_fn( mode=tf.contrib.learn.ModeKeys.TRAIN, tfrecord_pattern=FLAGS.training_data, batch_size=FLAGS.batch_size), train_steps=FLAGS.steps, eval_input_fn=get_input_fn( mode=tf.contrib.learn.ModeKeys.EVAL, tfrecord_pattern=FLAGS.eval_data, batch_size=FLAGS.batch_size), min_eval_frequency=1000)

請(qǐng)注意,本教程只是用一個(gè)相對(duì)較小的數(shù)據(jù)集進(jìn)行簡單演示,目的是讓您熟悉遞歸神經(jīng)網(wǎng)絡(luò)和 Estimator 的 API。如果在大型數(shù)據(jù)集上嘗試,這些模型可能會(huì)更強(qiáng)大。

當(dāng)模型完成 100 萬個(gè)訓(xùn)練步后,分?jǐn)?shù)最高的候選項(xiàng)的準(zhǔn)確率預(yù)計(jì)會(huì)達(dá)到 70% 左右。請(qǐng)注意,這種程度的準(zhǔn)確率足以構(gòu)建 Quick, Draw! 游戲,由于該游戲的動(dòng)態(tài)特性,用戶可以在系統(tǒng)準(zhǔn)備好識(shí)別之前調(diào)整涂鴉。此外,如果目標(biāo)類別顯示的分?jǐn)?shù)高于固定閾值,該游戲不會(huì)僅使用分?jǐn)?shù)最高的候選項(xiàng),而且會(huì)將某個(gè)涂鴉視為正確的涂鴉。

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴

原文標(biāo)題:Quick, Draw! 涂鴉分類遞歸神經(jīng)網(wǎng)絡(luò)

文章出處:【微信號(hào):tensorflowers,微信公眾號(hào):Tensorflowers】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦

    輪掃按鍵識(shí)別問

    大俠出來相求,每一個(gè)按鍵都可以唯一被識(shí)別嗎?機(jī)理是什么?
    發(fā)表于 07-27 16:58

    2812識(shí)別問

    用2812+cpld采集圖像然后再用2812識(shí)別,這個(gè)圖像識(shí)別很簡單,只是識(shí)別圖像中有幾條條紋??梢宰鰡??求解。
    發(fā)表于 04-16 11:08

    語音識(shí)別問

    各位大神,我想完成用SPCE061A來實(shí)現(xiàn)非特定人的語音識(shí)別技術(shù),并能夠使得發(fā)出的命令能在LCD上顯示,不知有沒有能夠指導(dǎo)一下的,大概的框架和模塊,拜托各位了。。。
    發(fā)表于 01-06 22:47

    請(qǐng)教 LD3320 語音識(shí)別問

    在X寶買了一塊LD3320 模塊,用的是并行通訊,讀寫寄存都正常,啟動(dòng)識(shí)別后有中斷, 識(shí)別結(jié)果寄存(0xBA)一直是0 . 是什么問題呀? 有沒人做成功的.分享下經(jīng)驗(yàn)!!! 謝謝
    發(fā)表于 03-28 13:43

    stm32 ov***的黑線識(shí)別問

    求一個(gè)基于stm32f103和ov***的黑線識(shí)別程序,串口顯示,液晶顯示都行。一直有問題出不來啊
    發(fā)表于 03-13 15:47

    OCR識(shí)別問

    我用圖像助手訓(xùn)練的時(shí)候能識(shí)別數(shù)字,但是訓(xùn)練完后還是不能識(shí)別?為什么~~求大神告知一下下
    發(fā)表于 12-07 11:21

    DHCP識(shí)別問題如何解決

    我有一些DHCP服務(wù)不使用和諧網(wǎng)絡(luò)棧來識(shí)別單元的問題??磥恚绻疫B接一臺(tái)筆記本電腦到服務(wù),它是公認(rèn)的罰款。是否有少量信息用于識(shí)別?我注意到,筆記本電腦與和聲棧相比,發(fā)送了很多東西
    發(fā)表于 05-11 13:21

    如何解決網(wǎng)絡(luò)無法識(shí)別問

    網(wǎng)絡(luò)問題分類網(wǎng)絡(luò)無法識(shí)別問題還是比較好排查,但是如果涉及到網(wǎng)絡(luò)丟包牽扯的環(huán)節(jié)太多了比如交換芯片是否異常,對(duì)方的工作模式是否正常、網(wǎng)絡(luò)隔離變壓是否正常、CPU占用率、設(shè)備中斷影響先排除網(wǎng)絡(luò)環(huán)境和對(duì)方設(shè)備、在確認(rèn)設(shè)備問題比如phy的時(shí)鐘是否重疊、phy的流控是否開啟等等..
    發(fā)表于 12-23 06:08

    離線語音識(shí)別和控制的工作原理及應(yīng)用

    :   1.信號(hào)采集   離線語音識(shí)別系統(tǒng)的第一步是信號(hào)采集。聲音信號(hào)通過麥克風(fēng)(傳感)以電信號(hào)的形式被捕捉到,這是后續(xù)處理的基礎(chǔ)。   2.預(yù)處理   預(yù)處理階段包括去除噪聲、回聲消除、降噪等處理
    發(fā)表于 11-07 18:01

    USB硬盤的系統(tǒng)識(shí)別問

      1、 如果系統(tǒng)裝的是win98,如不能被正確識(shí)別(即使安裝了USB2.0通用驅(qū)動(dòng)也識(shí)別不了),這種情況下要檢查一下你的移動(dòng)硬盤是否供電不足,如果供電不足就會(huì)出現(xiàn)“咳咳”的聲
    發(fā)表于 08-31 17:19 ?1017次閱讀

    貼片電容壞了怎么識(shí)別

    貼片電容如何識(shí)別?識(shí)別方法有哪些?,最近網(wǎng)上出現(xiàn)很多的貼片電容識(shí)別問題,很多人因?yàn)閷?duì)貼片電容的容值識(shí)別不了解,導(dǎo)致失誤的機(jī)率提高。下面小編分享一下貼片電容的
    發(fā)表于 05-10 14:48 ?1.2w次閱讀

    USB智能識(shí)別IC可解決傳統(tǒng)USB口的識(shí)別問

    USB智能識(shí)別IC(PL515,PL513),適用于車充,充電器,移動(dòng)電源等 USB口輸出供電方案。 USB智能識(shí)別IC,是用來解決傳統(tǒng)USB口的識(shí)別電阻,識(shí)別電阻做的
    的頭像 發(fā)表于 10-15 14:20 ?6364次閱讀
    USB智能<b class='flag-5'>識(shí)別</b>IC可解決傳統(tǒng)USB口的<b class='flag-5'>識(shí)別問</b>題

    HID_CDC復(fù)合設(shè)備在WIN10的識(shí)別問

    HID_CDC復(fù)合設(shè)備在WIN10的識(shí)別問題(電源技術(shù)發(fā)展綜述)-本文以STM32F405為例,詳細(xì)說明上HID_CDC復(fù)合設(shè)備在WIN10的識(shí)別問題。
    發(fā)表于 08-04 18:23 ?20次下載
    HID_CDC復(fù)合設(shè)備在WIN10的<b class='flag-5'>識(shí)別問</b>題

    STM32F0的USART波特率自動(dòng)識(shí)別問

    電子發(fā)燒友網(wǎng)站提供《STM32F0的USART波特率自動(dòng)識(shí)別問題.pdf》資料免費(fèi)下載
    發(fā)表于 08-01 11:00 ?2次下載
    STM32F0的USART波特率自動(dòng)<b class='flag-5'>識(shí)別問</b>題

    Purple Pi OH固件的芯片信息識(shí)別問題說明

    開源鴻蒙硬件方案領(lǐng)跑者觸覺智能本文適用于在PurplePiOH固件的芯片信息識(shí)別問題說明。觸覺智能的PurplePiOH鴻蒙開源主板,是華為Laval官方社區(qū)主薦的一款鴻蒙開發(fā)主板。該主板主要針對(duì)
    的頭像 發(fā)表于 06-26 08:32 ?152次閱讀
    Purple Pi OH固件的芯片信息<b class='flag-5'>識(shí)別問</b>題說明