學(xué)習(xí)有關(guān)圖神經(jīng)網(wǎng)絡(luò)的所有知識(shí),包括 GNN 是什么,不同類(lèi)型的圖神經(jīng)網(wǎng)絡(luò),以及它們的用途。此外,了解如何使用 PyTorch 構(gòu)建圖神經(jīng)網(wǎng)絡(luò)。
為適合中文閱讀習(xí)慣,閱讀更有代入感,原文翻譯后有刪改。轉(zhuǎn)載請(qǐng)注明原文出處,并說(shuō)明由我得學(xué)城翻譯整理。
? ? ? ? Abid Ali Awan| 作者?
1. 什么是圖?
圖是一種包含節(jié)點(diǎn)和邊的數(shù)據(jù)結(jié)構(gòu)。一個(gè)節(jié)點(diǎn)可以是一個(gè)人、地方或物體,邊定義了節(jié)點(diǎn)之間的關(guān)系。邊可以是有向的,也可以是無(wú)向的,基于方向性依賴關(guān)系。
在下面的示例中,藍(lán)色圓圈是節(jié)點(diǎn),箭頭是邊。邊的方向定義了兩個(gè)節(jié)點(diǎn)之間的依賴關(guān)系。
讓我們了解一下復(fù)雜的圖數(shù)據(jù)集:爵士音樂(lè)家網(wǎng)絡(luò)。它包含198個(gè)節(jié)點(diǎn)和2742條邊。
爵士音樂(lè)家網(wǎng)絡(luò)https://datarepository.wolframcloud.com/resources/Jazz-Musicians-Network
在下面的社區(qū)圖中,不同顏色的節(jié)點(diǎn)代表爵士音樂(lè)家的各種社區(qū),邊連接著它們。存在一種協(xié)作網(wǎng)絡(luò),其中單個(gè)音樂(lè)家在社區(qū)內(nèi)外都有關(guān)系。
爵士音樂(lè)家網(wǎng)絡(luò)的社區(qū)圖
圖在處理具有關(guān)系和相互作用的復(fù)雜問(wèn)題方面非常出色。它們?cè)谀J阶R(shí)別、社交網(wǎng)絡(luò)分析、推薦系統(tǒng)和語(yǔ)義分析中得到應(yīng)用。創(chuàng)建基于圖的解決方案是一個(gè)全新的領(lǐng)域,為復(fù)雜且相互關(guān)聯(lián)的數(shù)據(jù)集提供了豐富的見(jiàn)解。
2. 使用 NetworkX 創(chuàng)建圖
在本節(jié)中,我們將學(xué)習(xí)使用NetworkX創(chuàng)建圖。
下面的代碼受到 Daniel Holmberg 在 Python 中的圖神經(jīng)網(wǎng)絡(luò)博客的影響。
創(chuàng)建 networkx 的DiGraph對(duì)象“H”
添加包含不同標(biāo)簽、顏色和大小的節(jié)點(diǎn)
添加邊以創(chuàng)建兩個(gè)節(jié)點(diǎn)之間的關(guān)系。例如,“(0,1)”表示 0 對(duì) 1 有方向性依賴。我們將通過(guò)添加“(1,0)”來(lái)創(chuàng)建雙向關(guān)系
以列表形式提取顏色和大小
使用 networkx 的draw函數(shù)繪制圖
import?networkx?as?nx H?=?nx.DiGraph() #adding?nodes H.add_nodes_from([ ??(0,?{"color":?"blue",?"size":?250}), ??(1,?{"color":?"yellow",?"size":?400}), ??(2,?{"color":?"orange",?"size":?150}), ??(3,?{"color":?"red",?"size":?600}) ]) #adding?edges H.add_edges_from([ ??(0,?1), ??(1,?2), ??(1,?0), ??(1,?3), ??(2,?3), ??(3,0) ]) node_colors?=?nx.get_node_attributes(H,?"color").values() colors?=?list(node_colors) node_sizes?=?nx.get_node_attributes(H,?"size").values() sizes?=?list(node_sizes) #?Plotting?Graph nx.draw(H,?with_labels=True,?node_color=colors,?node_size=sizes)
?
?
在下一步中,我們將使用to_undirected()函數(shù)將數(shù)據(jù)結(jié)構(gòu)從有向圖轉(zhuǎn)換為無(wú)向圖。
#?轉(zhuǎn)換為無(wú)向圖 G?=?H.to_undirected() nx.draw(G,?with_labels=True,?node_color=colors,?node_size=sizes)
3. 為什么分析圖很難?
基于圖的數(shù)據(jù)結(jié)構(gòu)存在一些缺點(diǎn),數(shù)據(jù)科學(xué)家在開(kāi)發(fā)基于圖的解決方案之前必須了解這些缺點(diǎn)。
圖存在于非歐幾里得空間中。它不在 2D 或 3D 空間中存在,這使得解釋數(shù)據(jù)變得更加困難。為了在 2D 空間中可視化結(jié)構(gòu),您必須使用各種降維工具。
圖是動(dòng)態(tài)的;它們沒(méi)有固定的形式??梢源嬖趦蓚€(gè)在視覺(jué)上不同的圖,但它們可能具有相似的鄰接矩陣表示。這使得使用傳統(tǒng)的統(tǒng)計(jì)工具來(lái)分析數(shù)據(jù)變得困難。
對(duì)于人類(lèi)解讀來(lái)說(shuō),圖的規(guī)模和維度會(huì)增加圖的復(fù)雜性。具有多個(gè)節(jié)點(diǎn)和數(shù)千條邊的密集結(jié)構(gòu)更難理解和提取洞察。
4. 什么是圖神經(jīng)網(wǎng)絡(luò)(GNN)?
圖神經(jīng)網(wǎng)絡(luò)是一種特殊類(lèi)型的神經(jīng)網(wǎng)絡(luò),能夠處理圖數(shù)據(jù)結(jié)構(gòu)。它們受到卷積神經(jīng)網(wǎng)絡(luò)(CNNs)和圖嵌入的很大影響。GNNs 用于預(yù)測(cè)節(jié)點(diǎn)、邊和基于圖的任務(wù)。
CNNs 用于圖像分類(lèi)。類(lèi)似地,GNNs 應(yīng)用于圖結(jié)構(gòu)(像素網(wǎng)格)以預(yù)測(cè)一個(gè)類(lèi)。
循環(huán)神經(jīng)網(wǎng)絡(luò)用于文本分類(lèi)。類(lèi)似地,GNNs 應(yīng)用于圖結(jié)構(gòu),其中每個(gè)單詞是句子中的一個(gè)節(jié)點(diǎn)。
GNNs 是在卷積神經(jīng)網(wǎng)絡(luò)由于圖的任意大小和復(fù)雜結(jié)構(gòu)而無(wú)法取得最佳結(jié)果時(shí)引入的。
圖像由 Purvanshi Mehta 提供
輸入圖經(jīng)過(guò)一系列神經(jīng)網(wǎng)絡(luò)。輸入圖結(jié)構(gòu)被轉(zhuǎn)換成圖嵌入,允許我們保留關(guān)于節(jié)點(diǎn)、邊和全局上下文的信息。
然后,節(jié)點(diǎn) A 和 C 的特征向量通過(guò)神經(jīng)網(wǎng)絡(luò)層。它聚合這些特征并將它們傳遞到下一層。
4.1 圖神經(jīng)網(wǎng)絡(luò)的類(lèi)型
有幾種類(lèi)型的神經(jīng)網(wǎng)絡(luò),它們大多數(shù)都有一些卷積神經(jīng)網(wǎng)絡(luò)的變體。在本節(jié)中,我們將學(xué)習(xí)最流行的 GNNs。
圖卷積網(wǎng)絡(luò)(GCNs, Graph Convolutional Networks)類(lèi)似于傳統(tǒng)的 CNNs。它通過(guò)檢查相鄰節(jié)點(diǎn)來(lái)學(xué)習(xí)特征。GNNs 聚合節(jié)點(diǎn)向量,將結(jié)果傳遞給稠密層,并使用激活函數(shù)應(yīng)用非線性。簡(jiǎn)而言之,它包括圖卷積、線性層和非學(xué)習(xí)激活函數(shù)。有兩種主要類(lèi)型的 GCNs:空間卷積網(wǎng)絡(luò)和頻譜卷積網(wǎng)絡(luò)。
圖自編碼器網(wǎng)絡(luò)(Graph Auto-Encoder Networks)使用編碼器學(xué)習(xí)圖表示,并嘗試使用解碼器重建輸入圖。編碼器和解碼器通過(guò)瓶頸層連接。它們通常用于鏈路預(yù)測(cè),因?yàn)樽跃幋a器擅長(zhǎng)處理類(lèi)平衡問(wèn)題。
循環(huán)圖神經(jīng)網(wǎng)絡(luò)(RGNNs, Recurrent Graph Neural Networks)學(xué)習(xí)最佳擴(kuò)散模式,它們可以處理單個(gè)節(jié)點(diǎn)具有多個(gè)關(guān)系的多關(guān)系圖。這種類(lèi)型的圖神經(jīng)網(wǎng)絡(luò)使用正則化器來(lái)增強(qiáng)平滑性并消除過(guò)度參數(shù)化。RGNNs 使用更少的計(jì)算能力產(chǎn)生更好的結(jié)果。它們用于生成文本、機(jī)器翻譯、語(yǔ)音識(shí)別、生成圖像描述、視頻標(biāo)記和文本摘要。
門(mén)控圖神經(jīng)網(wǎng)絡(luò)(GGNNs, Gated Graph Neural Networks)在執(zhí)行具有長(zhǎng)期依賴性的任務(wù)方面優(yōu)于 RGNNs。門(mén)控圖神經(jīng)網(wǎng)絡(luò)通過(guò)在長(zhǎng)期依賴性上添加節(jié)點(diǎn)、邊和時(shí)間門(mén)來(lái)改進(jìn)循環(huán)圖神經(jīng)網(wǎng)絡(luò)。類(lèi)似于門(mén)控循環(huán)單元(GRUs),門(mén)用于在不同狀態(tài)下記住和忘記信息。
4.2 圖神經(jīng)網(wǎng)絡(luò)任務(wù)類(lèi)型
下面,我們列舉了一些圖神經(jīng)網(wǎng)絡(luò)任務(wù)類(lèi)型,并提供了示例:
圖分類(lèi)(Graph Classification:):用于將圖分類(lèi)為不同的類(lèi)別。其應(yīng)用包括社交網(wǎng)絡(luò)分析和文本分類(lèi)。
節(jié)點(diǎn)分類(lèi)(Node Classification:):這個(gè)任務(wù)使用相鄰節(jié)點(diǎn)的標(biāo)簽來(lái)預(yù)測(cè)圖中缺失的節(jié)點(diǎn)標(biāo)簽。
鏈路預(yù)測(cè)(Link Prediction):預(yù)測(cè)圖中具有不完整鄰接矩陣的一對(duì)節(jié)點(diǎn)之間的鏈接。這通常用于社交網(wǎng)絡(luò)。
社區(qū)檢測(cè)(Community Detection):基于邊的結(jié)構(gòu)將節(jié)點(diǎn)劃分為不同的群集。它類(lèi)似地從邊的權(quán)重、距離和圖對(duì)象中學(xué)習(xí)。
圖嵌入(Graph Embedding):將圖映射到向量,保留有關(guān)節(jié)點(diǎn)、邊和結(jié)構(gòu)的相關(guān)信息。
圖生成(Graph Generation:):從樣本圖分布中學(xué)習(xí),以生成一個(gè)新的但相似的圖結(jié)構(gòu)。
圖神經(jīng)網(wǎng)絡(luò)類(lèi)型
4.3 圖神經(jīng)網(wǎng)絡(luò)的缺點(diǎn)
使用 GNNs 存在一些缺點(diǎn)。了解這些缺點(diǎn)將幫助我們確定何時(shí)使用 GNN 以及如何優(yōu)化我們的機(jī)器學(xué)習(xí)模型的性能。
大多數(shù)神經(jīng)網(wǎng)絡(luò)可以深度學(xué)習(xí)以獲得更好的性能,而 GNNs 大多是淺層網(wǎng)絡(luò),主要有三層。這限制了我們?cè)诖笮蛿?shù)據(jù)集上獲得最先進(jìn)性能的能力。
圖結(jié)構(gòu)不斷變化,這使得在其上訓(xùn)練模型變得更加困難。
將模型部署到生產(chǎn)環(huán)境面臨可擴(kuò)展性問(wèn)題,因?yàn)檫@些網(wǎng)絡(luò)在計(jì)算上很昂貴。如果您有一個(gè)龐大且復(fù)雜的圖結(jié)構(gòu),將難以在生產(chǎn)環(huán)境中擴(kuò)展 GNNs。
5. 什么是圖卷積網(wǎng)絡(luò)(GCN)?
大多數(shù) GNNs 都是圖卷積網(wǎng)絡(luò),了解它們?cè)谶M(jìn)入節(jié)點(diǎn)分類(lèi)教程之前很重要。
GCN 中的卷積與卷積神經(jīng)網(wǎng)絡(luò)中的卷積相同。它將神經(jīng)元與權(quán)重(濾波器)相乘,以從數(shù)據(jù)特征中學(xué)習(xí)。
它在整個(gè)圖像上充當(dāng)滑動(dòng)窗口,以從相鄰單元中學(xué)習(xí)特征。該濾波器使用權(quán)重共享在圖像識(shí)別系統(tǒng)中學(xué)習(xí)各種面部特征。
現(xiàn)在將相同的功能轉(zhuǎn)移到圖卷積網(wǎng)絡(luò)中,其中模型從相鄰節(jié)點(diǎn)中學(xué)習(xí)特征。GCN 和 CNN 之間的主要區(qū)別在于,GCN 被設(shè)計(jì)為在非歐幾里得數(shù)據(jù)結(jié)構(gòu)上工作,其中節(jié)點(diǎn)和邊的順序可能變化。
CNN vs GCN
有兩種類(lèi)型的 GCNs:
空間圖卷積網(wǎng)絡(luò)(Spatial Graph Convolutional Networks)使用空間特征從位于空間空間的圖中學(xué)習(xí)。
頻譜圖卷積網(wǎng)絡(luò)(Spectral Graph Convolutional Networks)使用圖拉普拉斯矩陣的特征值分解進(jìn)行節(jié)點(diǎn)間的信息傳播。這些網(wǎng)絡(luò)靈感來(lái)自信號(hào)與系統(tǒng)中的波動(dòng)傳播。
6. 圖神經(jīng)網(wǎng)絡(luò)如何工作?使用 PyTorch 構(gòu)建圖神經(jīng)網(wǎng)絡(luò)
我們將構(gòu)建和訓(xùn)練用于節(jié)點(diǎn)分類(lèi)模型的譜圖卷積。代碼源可在文末獲取,讓您體驗(yàn)并運(yùn)行您的第一個(gè)基于圖的機(jī)器學(xué)習(xí)模型。
6.1 準(zhǔn)備
我們將安裝 Pytorch 軟件包,因?yàn)?pytorch_geometric 是在其基礎(chǔ)上構(gòu)建的。
?
?
!pip?install?-q?torch
?
?
然后,我們將使用 torch 版本安裝 torch-scatter 和 torch-sparse。之后,我們將從 GitHub 安裝 pytorch_geometric 的最新版本。
?
?
%%capture import?os import?torch os.environ['TORCH']?=?torch.__version__ os.environ['PYTHONWARNINGS']?=?"ignore" !pip?install?torch-scatter?-f?https://data.pyg.org/whl/torch-${TORCH}.html !pip?install?torch-sparse?-f?https://data.pyg.org/whl/torch-${TORCH}.html !pip?install?git+https://github.com/pyg-team/pytorch_geometric.git
?
?
6.2 Planetoid Cora 數(shù)據(jù)集
Planetoid 是來(lái)自 Cora、CiteSeer 和 PubMed 的引文網(wǎng)絡(luò)數(shù)據(jù)集。節(jié)點(diǎn)是具有 1433 維詞袋特征向量的文檔,邊是研究論文之間的引文鏈接。有 7 個(gè)類(lèi)別,我們將訓(xùn)練模型以預(yù)測(cè)缺失的標(biāo)簽。
我們將導(dǎo)入 Planetoid Cora 數(shù)據(jù)集,并對(duì)詞袋輸入特征進(jìn)行行標(biāo)準(zhǔn)化。之后,我們將分析數(shù)據(jù)集和第一個(gè)圖對(duì)象。
?
?
from?torch_geometric.datasets?import?Planetoid from?torch_geometric.transforms?import?NormalizeFeatures dataset?=?Planetoid(root='data/Planetoid',?name='Cora',?transform=NormalizeFeatures()) print(f'Dataset:?{dataset}:') print('======================') print(f'Number?of?graphs:?{len(dataset)}') print(f'Number?of?features:?{dataset.num_features}') print(f'Number?of?classes:?{dataset.num_classes}') data?=?dataset[0]??#?Get?the?first?graph?object. print(data)
?
?
Cora 數(shù)據(jù)集有 2708 個(gè)節(jié)點(diǎn)、10,556 條邊、1433 個(gè)特征和 7 個(gè)類(lèi)別。第一個(gè)對(duì)象有 2708 個(gè)訓(xùn)練、驗(yàn)證和測(cè)試掩碼。我們將使用這些掩碼來(lái)訓(xùn)練和評(píng)估模型。
?
?
Dataset:?Cora(): ====================== Number?of?graphs:?1 Number?of?features:?1433 Number?of?classes:?7 Data(x=[2708,?1433],?edge_index=[2,?10556],?y=[2708],?train_mask=[2708],?val_mask=[2708],?test_mask=[2708])
?
?
6.3 使用 GNN 進(jìn)行節(jié)點(diǎn)分類(lèi)
我們將創(chuàng)建一個(gè)包含兩個(gè) GCNConv 層、relu 激活和 0.5 的丟棄率的 GCN 模型結(jié)構(gòu)。該模型包含 16 個(gè)隱藏通道。
GCN 層:
上述方程中的 W(?+1) 是一個(gè)可訓(xùn)練的權(quán)重矩陣,Cw,v 表示每個(gè)邊的固定標(biāo)準(zhǔn)化系數(shù)。
?
?
from?torch_geometric.nn?import?GCNConv import?torch.nn.functional?as?F class?GCN(torch.nn.Module): ????def?__init__(self,?hidden_channels): ????????super().__init__() ????????torch.manual_seed(1234567) ????????self.conv1?=?GCNConv(dataset.num_features,?hidden_channels) ????????self.conv2?=?GCNConv(hidden_channels,?dataset.num_classes) ????def?forward(self,?x,?edge_index): ????????x?=?self.conv1(x,?edge_index) ????????x?=?x.relu() ????????x?=?F.dropout(x,?p=0.5,?training=self.training) ????????x?=?self.conv2(x,?edge_index) ????????return?x model?=?GCN(hidden_channels=16) print(model) >>>?GCN( ????(conv1):?GCNConv(1433,?16) ????(conv2):?GCNConv(16,?7) ??)
?
?
6.4 可視化未經(jīng)訓(xùn)練的 GCN 網(wǎng)絡(luò)
讓我們使用 sklearn.manifold.TSNE 和 matplotlib.pyplot 來(lái)可視化未經(jīng)訓(xùn)練的 GCN 網(wǎng)絡(luò)的節(jié)點(diǎn)嵌入。它將繪制一個(gè)包含 7 個(gè)維度節(jié)點(diǎn)嵌入的 2D 散點(diǎn)圖。
?
?
%matplotlib?inline import?matplotlib.pyplot?as?plt from?sklearn.manifold?import?TSNE def?visualize(h,?color): ????z?=?TSNE(n_components=2).fit_transform(h.detach().cpu().numpy()) ????plt.figure(figsize=(10,10)) ????plt.xticks([]) ????plt.yticks([]) ????plt.scatter(z[:,?0],?z[:,?1],?s=70,?c=color,?cmap="Set2") ????plt.show()
?
?
然后我們將評(píng)估模型,然后將訓(xùn)練數(shù)據(jù)添加到未經(jīng)訓(xùn)練的模型中以可視化各種節(jié)點(diǎn)和類(lèi)別。
?
?
model.eval() out?=?model(data.x,?data.edge_index) visualize(out,?color=data.y)
?
?
6.5 訓(xùn)練 GNN
我們將使用 Adam 優(yōu)化器和交叉熵?fù)p失函數(shù)(Cross-Entropy Loss)對(duì)模型進(jìn)行 100 輪訓(xùn)練。
在訓(xùn)練函數(shù)中,我們有:
清除梯度
執(zhí)行一次前向傳播
使用訓(xùn)練節(jié)點(diǎn)計(jì)算損失
計(jì)算梯度,并更新參數(shù)
在測(cè)試函數(shù)中,我們有:
預(yù)測(cè)節(jié)點(diǎn)類(lèi)別
提取具有最高概率的類(lèi)別標(biāo)簽
檢查有多少個(gè)值被正確預(yù)測(cè)
創(chuàng)建準(zhǔn)確率比,使用正確預(yù)測(cè)的總和除以節(jié)點(diǎn)的總數(shù)。
?
?
model?=?GCN(hidden_channels=16) optimizer?=?torch.optim.Adam(model.parameters(),?lr=0.01,?weight_decay=5e-4) criterion?=?torch.nn.CrossEntropyLoss() def?train(): ??????model.train() ??????optimizer.zero_grad() ??????out?=?model(data.x,?data.edge_index) ??????loss?=?criterion(out[data.train_mask],?data.y[data.train_mask]) ??????loss.backward() ??????optimizer.step() ??????return?loss def?test(): ??????model.eval() ??????out?=?model(data.x,?data.edge_index) ??????pred?=?out.argmax(dim=1) ??????test_correct?=?pred[data.test_mask]?==?data.y[data.test_mask] ??????test_acc?=?int(test_correct.sum())?/?int(data.test_mask.sum()) ??????return?test_acc for?epoch?in?range(1,?101): ????loss?=?train() ????print(f'Epoch:?{epoch:03d},?Loss:?{loss:.4f}') GAT( ??(conv1):?GATConv(1433,?8,?heads=8) ??(conv2):?GATConv(64,?7,?heads=8) ) ..?..?..?.. ..?..?..?.. Epoch:?098,?Loss:?0.5989 Epoch:?099,?Loss:?0.6021 Epoch:?100,?Loss:?0.5799
?
?
6.6 模型評(píng)估
我們將使用測(cè)試函數(shù)在未見(jiàn)過(guò)的數(shù)據(jù)集上評(píng)估模型,如您所見(jiàn),我們?cè)跍?zhǔn)確率上取得了相當(dāng)不錯(cuò)的結(jié)果,為 81.5%。
?
?
test_acc?=?test() print(f'Test?Accuracy:?{test_acc:.4f}')
?
?
輸出:
?
?
>>>?測(cè)試準(zhǔn)確率:0.8150
?
?
現(xiàn)在,我們將可視化經(jīng)過(guò)訓(xùn)練的模型的輸出嵌入以驗(yàn)證結(jié)果。
?
?
model.eval() out?=?model(data.x,?data.edge_index) visualize(out,?color=data.y)
?
?
正如我們所看到的,經(jīng)過(guò)訓(xùn)練的模型為相同類(lèi)別的節(jié)點(diǎn)產(chǎn)生了更好的聚類(lèi)。
6.7 訓(xùn)練 GATConv 模型
在第二個(gè)例子中,我們將使用 GATConv 層替換 GCNConv。圖注意力網(wǎng)絡(luò)使用掩碼的自注意力層來(lái)解決 GCNConv 的缺點(diǎn)并取得最先進(jìn)的結(jié)果。
您還可以嘗試其他 GNN 層,并嘗試不同的優(yōu)化、丟失率和隱藏通道數(shù)量,以獲得更好的性能。
在下面的代碼中,我們只是用具有 8 個(gè)注意力頭的 GATConv 替換了 GCNConv,其中第一層有 8 個(gè)頭,第二層有 1 個(gè)頭。
我們還將設(shè)置:
dropout為 0.6
隱藏通道為 8
學(xué)習(xí)率為 0.005
我們修改了測(cè)試函數(shù)以找到特定掩碼(驗(yàn)證、測(cè)試)的準(zhǔn)確率。這將幫助我們?cè)谀P陀?xùn)練期間打印出驗(yàn)證和測(cè)試分?jǐn)?shù)。我們還將驗(yàn)證和測(cè)試結(jié)果存儲(chǔ)到后面的繪圖線圖中。
?
?
from?torch_geometric.nn?import?GATConv class?GAT(torch.nn.Module): ????def?__init__(self,?hidden_channels,?heads): ????????super().__init__() ????????torch.manual_seed(1234567) ????????self.conv1?=?GATConv(dataset.num_features,?hidden_channels,heads) ????????self.conv2?=?GATConv(heads*hidden_channels,?dataset.num_classes,heads) ????def?forward(self,?x,?edge_index): ????????x?=?F.dropout(x,?p=0.6,?training=self.training) ????????x?=?self.conv1(x,?edge_index) ????????x?=?F.elu(x) ????????x?=?F.dropout(x,?p=0.6,?training=self.training) ????????x?=?self.conv2(x,?edge_index) ????????return?x model?=?GAT(hidden_channels=8,?heads=8) print(model) optimizer?=?torch.optim.Adam(model.parameters(),?lr=0.005,?weight_decay=5e-4) criterion?=?torch.nn.CrossEntropyLoss() def?train(): ??????model.train() ??????optimizer.zero_grad() ??????out?=?model(data.x,?data.edge_index) ??????loss?=?criterion(out[data.train_mask],?data.y[data.train_mask]) ??????loss.backward() ??????optimizer.step() ??????return?loss def?test(mask): ??????model.eval() ??????out?=?model(data.x,?data.edge_index) ??????pred?=?out.argmax(dim=1) ??????correct?=?pred[mask]?==?data.y[mask] ??????acc?=?int(correct.sum())?/?int(mask.sum()) ??????return?acc val_acc_all?=?[] test_acc_all?=?[] for?epoch?in?range(1,?101): ????loss?=?train() ????val_acc?=?test(data.val_mask) ????test_acc?=?test(data.test_mask) ????val_acc_all.append(val_acc) ????test_acc_all.append(test_acc) ????print(f'Epoch:?{epoch:03d},?Loss:?{loss:.4f},?Val:?{val_acc:.4f},?Test:?{test_acc:.4f}') ..?..?..?.. ..?..?..?.. Epoch:?098,?Loss:?1.1283,?Val:?0.7960,?Test:?0.8030 Epoch:?099,?Loss:?1.1352,?Val:?0.7940,?Test:?0.8050 Epoch:?100,?Loss:?1.1053,?Val:?0.7960,?Test:?0.8040
?
?
正如我們所觀察到的,我們的模型并沒(méi)有比 GCNConv 表現(xiàn)得更好。它需要進(jìn)行超參數(shù)優(yōu)化或更多輪次的訓(xùn)練才能取得最先進(jìn)的結(jié)果。
6.8 模型評(píng)估
在評(píng)估部分,我們使用 matplotlib.pyplot 的折線圖可視化驗(yàn)證和測(cè)試分?jǐn)?shù)。
?
?
import?numpy?as?np plt.figure(figsize=(12,8)) plt.plot(np.arange(1,?len(val_acc_all)?+?1),?val_acc_all,?label='Validation?accuracy',?c='blue') plt.plot(np.arange(1,?len(test_acc_all)?+?1),?test_acc_all,?label='Testing?accuracy',?c='red') plt.xlabel('Epochs') plt.ylabel('Accurarcy') plt.title('GATConv') plt.legend(loc='lower?right',?fontsize='x-large') plt.savefig('gat_loss.png') plt.show()
?
?
經(jīng)過(guò) 60 輪次,驗(yàn)證和測(cè)試準(zhǔn)確率達(dá)到了穩(wěn)定的值,約為 0.8+/-0.02。
再次,讓我們可視化 GATConv 模型的節(jié)點(diǎn)聚類(lèi)。
?
?
model.eval() out?=?model(data.x,?data.edge_index) visualize(out,?color=data.y)
?
?
正如我們所見(jiàn),GATConv 層在相同類(lèi)別的節(jié)點(diǎn)上產(chǎn)生了相同的聚類(lèi)結(jié)果。
我們可以通過(guò)添加第二個(gè)驗(yàn)證數(shù)據(jù)集來(lái)減少過(guò)擬合,并通過(guò)嘗試來(lái)自 pytoch_geometric 的各種 GCN 層來(lái)提高模型性能。
GNN 常見(jiàn)問(wèn)題
圖神經(jīng)網(wǎng)絡(luò)(GNN)用于什么?
圖神經(jīng)網(wǎng)絡(luò)直接應(yīng)用于圖數(shù)據(jù)集,您可以訓(xùn)練它們以預(yù)測(cè)節(jié)點(diǎn)、邊緣和與圖相關(guān)的任務(wù)。它用于圖和節(jié)點(diǎn)分類(lèi)、鏈路預(yù)測(cè)、圖聚類(lèi)和生成,以及圖像和文本分類(lèi)。
在圖神經(jīng)網(wǎng)絡(luò)中,什么是圖?
在圖神經(jīng)網(wǎng)絡(luò)中,圖是一種包含節(jié)點(diǎn)和節(jié)點(diǎn)之間連接(稱(chēng)為邊)的數(shù)據(jù)結(jié)構(gòu)。邊可以是有向的或無(wú)向的。它具有動(dòng)態(tài)形狀和多維結(jié)構(gòu)。例如,在社交媒體中,節(jié)點(diǎn)可以是您朋友群中的人,而邊則是您與每個(gè)人之間的關(guān)系。
圖神經(jīng)網(wǎng)絡(luò)有多強(qiáng)大?
在圖像和節(jié)點(diǎn)分類(lèi)方面,圖神經(jīng)網(wǎng)絡(luò)優(yōu)于典型的卷積神經(jīng)網(wǎng)絡(luò)(CNN)。許多圖神經(jīng)網(wǎng)絡(luò)的變體在節(jié)點(diǎn)和圖分類(lèi)任務(wù)中取得了最先進(jìn)的結(jié)果 - openreview.net。
神經(jīng)網(wǎng)絡(luò)是否使用圖論?
是的,神經(jīng)網(wǎng)絡(luò)與設(shè)計(jì)用于處理非歐幾里得數(shù)據(jù)的圖論密切相關(guān)。其中一些神經(jīng)網(wǎng)絡(luò)本身就是圖,或者輸出圖。
什么是圖卷積網(wǎng)絡(luò)?
圖卷積網(wǎng)絡(luò)類(lèi)似于用于圖數(shù)據(jù)集的卷積神經(jīng)網(wǎng)絡(luò)。它包括圖卷積、線性層和非線性激活。GNN 通過(guò)圖上的濾波器,檢查可用于對(duì)數(shù)據(jù)中的節(jié)點(diǎn)進(jìn)行分類(lèi)的節(jié)點(diǎn)和邊。
在深度學(xué)習(xí)中,什么是圖?
圖深度學(xué)習(xí)也被稱(chēng)為幾何深度學(xué)習(xí)。它使用多個(gè)神經(jīng)網(wǎng)絡(luò)層以實(shí)現(xiàn)更好的性能。這是一個(gè)活躍的研究領(lǐng)域,科學(xué)家們正試圖在不影響性能的情況下增加層數(shù)。
審核編輯:黃飛
評(píng)論
查看更多