0
  • 聊天消息
  • 系統消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內不再提示

【BBuf的CUDA筆記】OpenAI Triton入門筆記一

jf_pmFSk4VX ? 來源:GiantPandaCV ? 2024-01-23 10:00 ? 次閱讀

0x1. OpenAI Triton介紹閱讀

這里來看官方的介紹:https://openai.com/research/triton ,從官方的介紹中我們可以看到OpenAI Triton的產生動機以及它的目標是什么,還可以看到一些經典算法的實現例子展示。

這里的標題是 Introducing Triton: Open-source GPU programming for neural networks ,翻譯就是《介紹 Triton:用于神經網絡的開源 GPU 編程語言》。然后下面的一句話翻譯過來是:我們發(fā)布了 Triton 1.0,這是一種開源的類 Python 編程語言,它使得沒有 CUDA 經驗的研究人員能夠編寫高效的 GPU 代碼——大多數情況下,其效能與專家所能編寫的代碼相當。這里指出了triton的目的,就是讓編寫cuda kernrl變得更簡單。接下來就逐步看一下介紹里的具體內容,為了更加準確這里會截圖對應的原文然后放上我的翻譯或者理解。

1240ae7a-b926-11ee-8b88-92fbcf53809c.png

這里的意思是Triton可以使得用戶用較少的努力就寫出一個達到硬件峰值性能的kernel,比如使用 Triton 可以編寫 FP16 矩陣乘法的核函數,其性能能夠匹配 cuBLAS,并且這個代碼不超過25行。然后研究者已經用Triton開發(fā)了一些高效的實現,和功能相同的Torch實現相比,性能可以達到兩倍提升。后面一段就是強調了使用CUDA來把一些原始的PyTorch實現寫一個算子一般會更加高效,但是這個難度不小,并且目前已有工作也不能很好覆蓋這種情況,所以OpenAI Triton誕生。

12582dde-b926-11ee-8b88-92fbcf53809c.png

這里講的是GPU編程的挑戰(zhàn),現代 GPU 的架構大致可以分為三個主要部分——DRAM、SRAM 和 ALU。在優(yōu)化 CUDA 代碼時,必須考慮到這些組件:

從 DRAM 的內存?zhèn)鬏敱仨毢喜⒊纱笮褪聞?,以利用現代內存接口的大總線寬度(內存合并訪問)。

數據必須在重復使用前手動存儲到 SRAM 中,并進行管理來最小化bank conflict。

計算必須仔細地進行劃分和調度,不僅是在流式多處理器(SMs)之間,還包括在其內部,以促進指令/線程級并行性,并利用專用的 ALU(例如,Tensor Cores)。

1280cc12-b926-11ee-8b88-92fbcf53809c.png1293ca2e-b926-11ee-8b88-92fbcf53809c.png

考慮所有這些因素可能對于擁有多年經驗的資深 CUDA 程序員來說都是一個挑戰(zhàn)。Triton 的目的是完全自動化這些優(yōu)化,以便開發(fā)者能夠更好地專注于他們并行代碼的高層邏輯。Triton 旨在廣泛適用,因此不會自動在流式多處理器(SMs)之間調度工作——留下一些重要的算法考慮(例如,tiling,跨 SM 同步)由開發(fā)者自行決定。

然后給了一個表格展示cuda的編譯器和triton的區(qū)別。

12aadc5a-b926-11ee-8b88-92fbcf53809c.png12c93416-b926-11ee-8b88-92fbcf53809c.png

在所有可用的領域特定語言和即時編譯器中,Triton可能和Numba最相似:kernel被定義為一個裝飾過的函數,并以不同的 program_id 并行啟動在所謂的網格實例上。然而,正如下面的代碼片段所示,相似之處僅此而已:Triton 通過對塊上的操作來暴露實例內部的并行性——這些小數組的尺寸是二的冪次方——而不是單指令多線程(SIMT)執(zhí)行模型。這樣做,Triton 有效地抽象出了所有與 CUDA 線程塊內部并發(fā)相關的問題(例如,內存合并、共享內存同步/沖突、Tensor Cores調度)。

12e5d9b8-b926-11ee-8b88-92fbcf53809c.png13058c90-b926-11ee-8b88-92fbcf53809c.png

wKgZomWvHkeAEQTkAADFqjaD4zM784.jpg

131b4224-b926-11ee-8b88-92fbcf53809c.png

注意,Triton 的即時編譯器將 X 和 Y 視為指針而不是張量;我們認為保留對內存訪問的低級控制對于處理更復雜的數據結構(例如,塊稀疏張量)是重要的。重要的是,這種特定的 softmax 實現在整個標準化過程中將 X 的行保留在 SRAM 中,這在適用時最大化了數據重用(約 <32K 列)。這與 PyTorch 的內部 CUDA 代碼不同,后者使用臨時內存使其更具通用性,但顯著更慢(如下所示)。這里的關鍵不是 Triton 本質上更好,而是它簡化了專用kernel的開發(fā),這些內核可能比在通用庫中找到的內核快得多。

1335ab3c-b926-11ee-8b88-92fbcf53809c.png

Torch(v1.9)JIT編譯器的較低性能凸顯了從高級張量操作序列自動生成 CUDA 代碼的難度。

1347b2dc-b926-11ee-8b88-92fbcf53809c.png13620100-b926-11ee-8b88-92fbcf53809c.png

這里是說Triton大概只需要25行Python代碼就可以實現一個接近峰值的矩陣乘法。(后面有專門的一大節(jié)講這個代碼的原理)代碼如下:

@triton.jit
defmatmul(A,B,C,M,N,K,stride_am,stride_ak,
stride_bk,stride_bn,stride_cm,stride_cn,
**META):
#extractmetaparameters
BLOCK_M,GROUP_M=META['BLOCK_M'],META['GROUP_M']
BLOCK_N=META['BLOCK_N']
BLOCK_K=META['BLOCK_K']
#programsaregroupedtogethertoimproveL2hitrate
_pid_m=tl.program_id(0)
_pid_n=tl.program_id(1)
pid_m=_pid_m//GROUP_M
pid_n=(_pid_n*GROUP_M)+(_pid_m%GROUP_M)
#rm(resp.rn)denotesarangeofindices
#forrows(resp.col)ofC
rm=pid_m*BLOCK_M+tl.arange(0,BLOCK_M)
rn=pid_n*BLOCK_N+tl.arange(0,BLOCK_N)
#rkdenotesarangeofindicesforcolumns
#(resp.rows)ofA(resp.B)
rk=tl.arange(0,BLOCK_K)
#thememoryaddressesofelementsinthefirstblockof
#AandBcanbecomputedusingnumpy-stylebroadcasting
A=A+(rm[:,None]*stride_am+rk[None,:]*stride_ak)
B=B+(rk[:,None]*stride_bk+rn[None,:]*stride_bn)
#initializeanditerativelyupdateaccumulator
acc=tl.zeros((BLOCK_M,BLOCK_N),dtype=tl.float32)
forkinrange(K,0,-BLOCK_K):
a=tl.load(A)
b=tl.load(B)
#blocklevelmatrixmultiplication
acc+=tl.dot(a,b)
#incrementpointerssothatthenextblocksofAandB
#areloadedduringthenextiteration
A+=BLOCK_K*stride_ak
B+=BLOCK_K*stride_bk
#fuseleakyReLUifdesired
#acc=tl.where(acc>=0,acc,alpha*acc)
#writebackresult
C=C+(rm[:,None]*stride_cm+rn[None,:]*stride_cn)
mask=(rm[:,None]

手寫矩陣乘法kernel的一個重要優(yōu)勢是,它們可以根據需要定制,以適應輸入(例如,切片)和輸出(例如,LeakyReLU)的融合轉換。如果沒有像 Triton 這樣的系統,沒有出色的 GPU 編程專長的開發(fā)者將無法進行矩陣乘法內核的定制修改。

1385810c-b926-11ee-8b88-92fbcf53809c.png1397ec0c-b926-11ee-8b88-92fbcf53809c.png

這里是說Triton 的良好性能源于一個以 Triton-IR 為中心的模塊化系統架構,Triton-IR 是一個基于 LLVM 的中間表示,在這個系統中,多維值塊(這個是MLIR的概念)是一等公民。GPT

@triton.jit 裝飾器的工作原理是遍歷提供的 Python 函數的抽象語法樹(AST),以便使用常見的 SSA 構建算法即時生成 Triton-IR。然后,編譯器后端會簡化、優(yōu)化并自動并行化所產生的 IR 代碼,再將其轉換為高質量的 LLVM-IR —— 最終生成 PTX —— 以在近期的 NVIDIA GPU 上執(zhí)行。目前不支持 CPUAMD GPU,但我們歡迎社區(qū)貢獻,旨在解決這一限制。

13c1ebd8-b926-11ee-8b88-92fbcf53809c.png

我們發(fā)現,通過 Triton-IR 使用塊級別程序表示,使我們的編譯器能夠自動執(zhí)行各種重要的程序優(yōu)化。例如,可以通過觀察計算密集型塊級操作(例如,tl.dot)的操作數,自動將數據暫存到共享內存中,并使用標準的活性分析技術進行分配和同步。

另一方面,如下所示,Triton 程序可以高效且自動地并行化,既可以(1)通過并發(fā)執(zhí)行不同的kernel實例在流式多處理器(SMs)間并行,也可以(2)通過分析每個塊級操作的迭代空間,并在不同的 SIMD 單元間適當分配,從而在 SMs 內部并行。

13d33b9a-b926-11ee-8b88-92fbcf53809c.png

0x2. 教程1 Vector Addition閱讀

13ec7ab0-b926-11ee-8b88-92fbcf53809c.png

意思是這一節(jié)教程會介紹Triton編程模型定義kernel的基本寫法,此外也會介紹一下怎么實現一個良好的benchmark測試。下面來看計算kernel實現,我把注釋改成中文了:

importtorch

importtriton
importtriton.languageastl

@triton.jit
defadd_kernel(x_ptr,#*指針*,指向第一個輸入向量。
y_ptr,#*指針*,指向第二個輸入向量。
output_ptr,#*指針*,指向輸出向量。
n_elements,#向量的大小。
BLOCK_SIZE:tl.constexpr,#每個程序應處理的元素數量。
#注意:`constexpr`這樣可以被用作形狀值。
):
#這里有多個“程序”處理不同的數據。我們在這里識別我們是哪一個程序:
pid=tl.program_id(axis=0)#我們使用一維啟動網格,所以軸是0。
#該程序將處理從初始數據偏移的輸入。
#例如,如果你有一個長度為256的向量和塊大小為64,那么程序
#將分別訪問元素[0:64,64:128,128:192,192:256]。
#注意偏移量是一個指針列表:
block_start=pid*BLOCK_SIZE
offsets=block_start+tl.arange(0,BLOCK_SIZE)
#創(chuàng)建一個掩碼以防止內存操作越界訪問。
mask=offsets

這里還聲明了一個輔助函數來(1)分配z張量,(2)使用適當的網格/塊大小排隊上面的kernel:

defadd(x:torch.Tensor,y:torch.Tensor):
#我們需要預分配輸出。
output=torch.empty_like(x)
assertx.is_cudaandy.is_cudaandoutput.is_cuda
n_elements=output.numel()
#SPMD啟動網格表示并行運行的kernel實例的數量。
#它類似于CUDA啟動網格。它可以是Tuple[int],也可以是Callable(metaparameters)->Tuple[int]。
#在這種情況下,我們使用一個1D網格,其大小是塊的數量:
grid=lambdameta:(triton.cdiv(n_elements,meta['BLOCK_SIZE']),)
#注意:
#-每個torch.tensor對象都隱式地轉換為指向其第一個元素的指針。
#-使用`triton.jit`裝飾的函數可以用一個啟動網格索引來獲得可調用的GPU內核。
#-不要忘記將元參數作為關鍵字參數傳遞。
add_kernel[grid](x,y,output,n_elements,BLOCK_SIZE=1024)
#我們返回一個指向z的句柄,但是因為`torch.cuda.synchronize()`還沒有被調用,所以這時kernel仍然
#在異步運行。
returnoutput

我們現在可以使用上面定義的函數來計算兩個torch.tensor對象的逐元素求和,并測試其正確性:

torch.manual_seed(0)
size=98432
x=torch.rand(size,device='cuda')
y=torch.rand(size,device='cuda')
output_torch=x+y
output_triton=add(x,y)
print(output_torch)
print(output_triton)
print(f'Themaximumdifferencebetweentorchandtritonis'
f'{torch.max(torch.abs(output_torch-output_triton))}')

輸出:

tensor([1.3713,1.3076,0.4940,...,0.6724,1.2141,0.9733],device='cuda:0')
tensor([1.3713,1.3076,0.4940,...,0.6724,1.2141,0.9733],device='cuda:0')
Themaximumdifferencebetweentorchandtritonis0.0
13fa6076-b926-11ee-8b88-92fbcf53809c.png

我們可以對不同大小的向量進行自定義操作的性能基準測試,以了解它相對于PyTorch的表現如何。為了簡化操作,Triton提供了一系列內置工具,使我們能夠簡潔地繪制出自定義操作在不同問題規(guī)模下的性能圖表。

@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['size'],#用作繪圖x軸的參數名。
x_vals=[2**iforiinrange(12,28,1)],#`x_name`的不同可能值。
x_log=True,#x軸是對數的。
line_arg='provider',#其值對應于圖中不同線條的參數名。
line_vals=['triton','torch'],#`line_arg`的可能值。
line_names=['Triton','Torch'],#線條的標簽名稱。
styles=[('blue','-'),('green','-')],#線條樣式。
ylabel='GB/s',#y軸的標簽名稱。
plot_name='vector-add-performance',#繪圖的名稱。也用作保存繪圖的文件名。
args={},#不在`x_names`和`y_name`中的函數參數的值。
))
defbenchmark(size,provider):
x=torch.rand(size,device='cuda',dtype=torch.float32)
y=torch.rand(size,device='cuda',dtype=torch.float32)
quantiles=[0.5,0.2,0.8]
ifprovider=='torch':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:x+y,quantiles=quantiles)
ifprovider=='triton':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:add(x,y),quantiles=quantiles)
gbps=lambdams:12*size/ms*1e-6
returngbps(ms),gbps(max_ms),gbps(min_ms)

gbps = lambda ms: 12 * size / ms * 1e-6這里的12表示的是數據讀寫的bit,因為有x和y以及z的存在,所以是3*4=12bit?,F在可以運行上面的裝飾函數了。傳遞 print_data=True 參數來查看性能數據,傳遞 show_plots=True 參數來繪制圖表,和/或傳遞 save_path='/path/to/results/' 參數來將它們連同原始CSV數據一起保存到磁盤上:

benchmark.run(print_data=True,show_plots=True)

140b253c-b926-11ee-8b88-92fbcf53809c.png

可以看到,對于elementwise任務,Triton的性能幾乎和PyTorch持平,但是Triton寫起來很簡單。

0x3. 教程2 Fused Softmax閱讀

在這個教程中,我們將編寫一個融合的softmax操作,這個操作對于特定類型的矩陣來說比PyTorch的原生操作要快得多:那些行的大小可以放入GPU的SRAM中的矩陣。

通過這樣做,我們將學習到:

kernel融合對于帶寬受限操作的好處。

Triton中的reduce操作符。

動機

自定義GPU kernel用于逐元素加法在教育上是有價值的,但在實際應用中可能作用有限。讓我們考慮一個簡單的(數值穩(wěn)定的)softmax操作的情況:

importtorch

importtriton
importtriton.languageastl

@torch.jit.script
defnaive_softmax(x):
"""使用原生pytorch計算X的逐行softmax

我們減去最大元素是為了避免溢出。Softmax對這種偏移是不變的。
"""
#讀取MN個元素;寫入M個元素
x_max=x.max(dim=1)[0]
#讀取MN+M個元素;寫入MN個元素
z=x-x_max[:,None]
#讀取MN個元素;寫入MN個元素
numerator=torch.exp(z)
#讀取MN個元素;寫入M個元素
denominator=numerator.sum(dim=1)
#讀取MN+M個元素;寫入MN個元素
ret=numerator/denominator[:,None]
#總計:讀取5MN+2M個元素;寫入3MN+2M個元素
returnret

1421fea6-b926-11ee-8b88-92fbcf53809c.png

wKgaomWvHqyAFI-mAACokcxGs7Y255.jpg

計算kernel

我們的softmax kernel的工作方式如下:每個程序加載輸入矩陣X的一行,對其進行歸一化處理,然后將結果寫回到輸出Y中。需要注意的是,Triton的一個重要限制是每個塊必須包含2的冪次方個元素,因此如果我們想處理任何可能的輸入形狀,我們需要在內部對每行進行“pad”以及對內存訪問操作進行保護(也就是防止越界):

@triton.jit
defsoftmax_kernel(output_ptr,input_ptr,input_row_stride,output_row_stride,n_cols,BLOCK_SIZE:tl.constexpr):
#softmax的各行是獨立的,所以我們在這些行上進行并行處理
row_idx=tl.program_id(0)
#步長代表我們需要增加多少指針來前進1行
row_start_ptr=input_ptr+row_idx*input_row_stride
#塊大小是大于n_cols的下一個2的冪次,因此我們可以將每一行放入單個塊中
col_offsets=tl.arange(0,BLOCK_SIZE)
input_ptrs=row_start_ptr+col_offsets
#將行加載到SRAM中,使用掩碼因為BLOCK_SIZE可能大于n_cols
row=tl.load(input_ptrs,mask=col_offsets

解析來創(chuàng)建一個輔助函數,該函數為任何給定的輸入張量排隊執(zhí)行kernel并且設置了啟動參數。

defsoftmax(x):
n_rows,n_cols=x.shape
#塊大小是大于`x`中列數的最小2的冪
BLOCK_SIZE=triton.next_power_of_2(n_cols)
#我們可以使用的另一個技巧是要求編譯器通過增加每行分布的warp數(`num_warps`)來使用更多的線程。
#在下一個教程中,你將看到如何以更自然的方式自動調整這個值,這樣你就不必自己想出手動啟發(fā)式方法。
num_warps=4
ifBLOCK_SIZE>=2048:
num_warps=8
ifBLOCK_SIZE>=4096:
num_warps=16
#分配輸出
y=torch.empty_like(x)
#排隊執(zhí)行內核。一維啟動網格很簡單:我們有每行一個內核實例
#輸入矩陣
softmax_kernel[(n_rows,)](
y,
x,
x.stride(0),
y.stride(0),
n_cols,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
returny

14385156-b926-11ee-8b88-92fbcf53809c.png

這里是驗證Triton實現的fuse softmax和PyTorch的naive實現等價,顯然他們是等價的。

BenchMark

1449f802-b926-11ee-8b88-92fbcf53809c.png

這里設定矩陣的行數為固定的4096來做benchmark。

@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'],#用作繪圖x軸的參數名
x_vals=[128*iforiinrange(2,100)],#`x_name`的不同可能值
line_arg='provider',#其值對應于圖中不同線條的參數名
line_vals=[
'triton',
'torch-native',
'torch-jit',
],#`line_arg`的可能值
line_names=[
"Triton",
"Torch(原生)",
"Torch(jit)",
],#線條的標簽名稱
styles=[('blue','-'),('green','-'),('green','--')],#線條樣式
ylabel="GB/s",#y軸的標簽名稱
plot_name="softmax-performance",#繪圖的名稱。也用作保存繪圖的文件名。
args={'M':4096},#不在`x_names`和`y_name`中的函數參數的值
))
defbenchmark(M,N,provider):
x=torch.randn(M,N,device='cuda',dtype=torch.float32)
quantiles=[0.5,0.2,0.8]
ifprovider=='torch-native':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:torch.softmax(x,axis=-1),quantiles=quantiles)
ifprovider=='triton':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:softmax(x),quantiles=quantiles)
ifprovider=='torch-jit':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:naive_softmax(x),quantiles=quantiles)
gbps=lambdams:2*x.nelement()*x.element_size()*1e-9/(ms*1e-3)
returngbps(ms),gbps(max_ms),gbps(min_ms)

benchmark.run(show_plots=True,print_data=True)

14cc8aa6-b926-11ee-8b88-92fbcf53809c.png14de2f68-b926-11ee-8b88-92fbcf53809c.png

這里提到雖然Triton實現的softmax性能更好并且易于理解和維護,但PyTorch的torch.softmax則更加通用。

0x4. 教程3 Matrix Multiply閱讀

14fc667c-b926-11ee-8b88-92fbcf53809c.png

首先教程指出這里就是要寫一個Block級別的矩陣乘法,然后這里會涉及到多維度的指針操作,程序重排以更好的命中l(wèi)2 cache以及自動調優(yōu)。

動機

矩陣乘法是大多數現代高性能計算系統的關鍵構建塊。它們眾所周知難以優(yōu)化,因此它們的實現通常由硬件供應商自己作為所謂的“內核庫”(例如,cuBLAS)的一部分來完成。不幸的是,這些庫通常是專有的,無法輕易地定制以適應現代深度學習工作負載的需求(例如,融合激活函數)。在這個教程中,你將學習如何使用Triton自己實現高效的矩陣乘法,這種方法易于定制和擴展。

大致來說,我們將要編寫的內核將實現以下塊級算法來乘以一個 (M, K) 矩陣和一個 (K, N) 矩陣:

#Doinparallel
forminrange(0,M,BLOCK_SIZE_M):
#Doinparallel
forninrange(0,N,BLOCK_SIZE_N):
acc=zeros((BLOCK_SIZE_M,BLOCK_SIZE_N),dtype=float32)
forkinrange(0,K,BLOCK_SIZE_K):
a=A[m:m+BLOCK_SIZE_M,k:k+BLOCK_SIZE_K]
b=B[k:k+BLOCK_SIZE_K,n:n+BLOCK_SIZE_N]
acc+=dot(a,b)
C[m:m+BLOCK_SIZE_M,n:n+BLOCK_SIZE_N]=acc

其中,雙重嵌套的for循環(huán)的每次迭代都由一個專用的Triton program實例執(zhí)行。

計算kernel

上述算法實際上在Triton中相當容易實現。主要的難點來自于在內循環(huán)中計算必須讀取A和B塊的內存位置。為此,我們需要多維指針運算。

指針運算

對于一個2D Tensor X,X[i, j]的內存位置為&X[i, j] = X + i*stride_xi + j*stride_xj。因此,對于A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]和B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]的塊指針可以用下面的偽代碼定義:

&A[m:m+BLOCK_SIZE_M,k:k+BLOCK_SIZE_K]=a_ptr+(m:m+BLOCK_SIZE_M)[:,None]*A.stride(0)+(k:k+BLOCK_SIZE_K)[None,:]*A.stride(1);
&B[k:k+BLOCK_SIZE_K,n:n+BLOCK_SIZE_N]=b_ptr+(k:k+BLOCK_SIZE_K)[:,None]*B.stride(0)+(n:n+BLOCK_SIZE_N)[None,:]*B.stride(1);

這意味著A和B塊的指針可以在Triton中初始化,比如 k=0 如下代碼所示。另外注意,我們需要一個額外的模運算來處理M不是BLOCK_SIZE_M的倍數或N不是BLOCK_SIZE_N的倍數的情況,在這種情況下,我們可以用一些無用的值填充數據,這些值不會對結果產生影響。對于K維度,我們稍后將使用掩碼加載語義來處理。

offs_am=(pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M))%M
offs_bn=(pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N))%N
offs_k=tl.arange(0,BLOCK_SIZE_K)
a_ptrs=a_ptr+(offs_am[:,None]*stride_am+offs_k[None,:]*stride_ak)
b_ptrs=b_ptr+(offs_k[:,None]*stride_bk+offs_bn[None,:]*stride_bn)

然后在內循環(huán)中按如下方式更新:

a_ptrs+=BLOCK_SIZE_K*stride_ak;
b_ptrs+=BLOCK_SIZE_K*stride_bk;

如上所述,每個program實例計算一個 [BLOCK_SIZE_M, BLOCK_SIZE_N] 大小的C矩陣塊。重要的是要記住,這些塊的計算順序是很重要的,因為它會影響我們程序的L2緩存命中率,不幸的是,一個簡單的行優(yōu)先順序是不夠的。

pid=triton.program_id(0);
grid_m=(M+BLOCK_SIZE_M-1)//BLOCK_SIZE_M;
grid_n=(N+BLOCK_SIZE_N-1)//BLOCK_SIZE_N;
pid_m=pid/grid_n;
pid_n=pid%grid_n;

L2 Cache優(yōu)化

如上所述,每個程序實例計算一個 [BLOCK_SIZE_M, BLOCK_SIZE_N] 大小的C矩陣塊。重要的是要記住,這些塊的計算順序很重要,因為它會影響我們程序的L2緩存命中率,不幸的是,一個簡單的行主序排序是不夠的。

一個可能的解決方案是以一種促進數據重用的順序啟動塊。這可以通過在切換到下一列之前將塊在GROUP_M行的super group中分組來實現:

#程序ID
pid=tl.program_id(axis=0)
#沿M軸的程序ID數量
num_pid_m=tl.cdiv(M,BLOCK_SIZE_M)
#沿N軸的程序ID數量
num_pid_n=tl.cdiv(N,BLOCK_SIZE_N)
#組中的程序數量
num_pid_in_group=GROUP_SIZE_M*num_pid_n
#該程序所在組的ID
group_id=pid//num_pid_in_group
#組中第一個程序的行ID
first_pid_m=group_id*GROUP_SIZE_M
#如果`num_pid_m`不能被`GROUP_SIZE_M`整除,最后一個組更小
group_size_m=min(num_pid_m-first_pid_m,GROUP_SIZE_M)
#*在組內*,程序按列主序排列
#程序在*啟動網格*中的行ID
pid_m=first_pid_m+(pid%group_size_m)
#程序在*啟動網格*中的列ID
pid_n=(pid%num_pid_in_group)//group_size_m

例如,在下面的矩陣乘法中,每個矩陣由9個塊乘以9個塊組成,我們可以看到,如果我們按行主序計算輸出,我們需要將90個塊加載到SRAM中以計算前9個輸出塊,但如果我們按grouped ordering進行計算,我們只需要加載54個塊。

15242360-b926-11ee-8b88-92fbcf53809c.png

在實際應用中,這可以在某些硬件架構上提高我們矩陣乘法內核的性能超過10%(例如,在A100上從220提升到245 TFLOPS)。

L2 Cache優(yōu)化原理補充講解

上面的group oredering的訪問代碼比較難理解,這里來更詳細的解析一下。

wKgaomWvHvSAfbS3AAB0PF-ZBkw397.jpg

#程序ID
pid=tl.program_id(axis=0)
#沿M軸的程序ID數量
num_pid_m=tl.cdiv(M,BLOCK_SIZE_M)
#沿N軸的程序ID數量
num_pid_n=tl.cdiv(N,BLOCK_SIZE_N)

這里的num_pid_m和num_pid_n就是求分別要在M和N方向循環(huán)多少次。

然后上面圖中的黑色數字其實就可以理解為program id,我們可以看到program id增加的方向其實就代表了遍歷的ordering,對于row major來說就是在行方向上順序遍歷,而對于group ordering來說就是按照一個BLOCK_SIZE_M*BLOCK_SIZE_N這么大的一個小組來遍歷。其實這段代碼就是完成group ordering的遍歷:

num_pid_in_group=GROUP_SIZE_M*num_pid_n
group_id=pid//num_pid_in_group
first_pid_m=group_id*GROUP_SIZE_M
group_size_m=min(num_pid_m-first_pid_m,GROUP_SIZE_M)
pid_m=first_pid_m+(pid%group_size_m)
pid_n=(pid%num_pid_in_group)//group_size_m

以上面圖來看,num_pid_m=3,num_pid_n=3,num_pid_in_group=group_id * GROUP_SIZE_M=9*3=27,也就是下面的紅色框里面的program個數,從名字也可以看出來這個紅色框劃分的區(qū)域也是一個group。

1539a208-b926-11ee-8b88-92fbcf53809c.png

group_id 就表示當前的這次 "循環(huán)", 是在第幾個紅色框里,以program 0為例,這里為group_id = pid // num_pid_in_group=0//27=0。而first_pid_m 代表當前 group 中的第一個黃色program在全局的M維度上是第幾個program ,這里為first_pid_m = group_id * GROUP_SIZE_M=0,group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)這里是考慮到最后一個group可能占不滿數據(存在padding),所以就做一個截斷處理。

pid_m=first_pid_m+(pid%group_size_m)
pid_n=(pid%num_pid_in_group)//group_size_m

這兩行代碼計算當前的program處理的黃色小塊坐標([pid_m, pid_n]),pid_m這行是在行方向上移動,pid_n這行則是保證在上面的紅色框里面一定是一列一列來訪問的。

作為對比,在Row-major的方法中,訪問方式應該是這樣的:

pid_m=pid//num_pid_n
pid_n=pid%num_pid_n

計算最后的結果

有了上面的鋪墊,我們就可以計算最終的結果了,下面的代碼展示了完整的Triton 矩陣乘法kernel實現。

#使用`triton.jit`裝飾的函數可以通過`triton.autotune`裝飾器進行自動調優(yōu),該裝飾器包括:
#-一系列定義不同配置的`triton.Config`對象,
#這些配置涉及元參數(例如`BLOCK_SIZE_M`)和編譯選項(例如`num_warps`)的不同設置
#-一個自動調優(yōu)*關鍵字*,其值的變化將觸發(fā)對所有
#提供的配置的評估
@triton.autotune(
configs=[
#每個Config定義了一組特定的配置參數和編譯選項
triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':256,'BLOCK_SIZE_K':64,'GROUP_SIZE_M':8},num_stages=3,
num_warps=8),
triton.Config({'BLOCK_SIZE_M':64,'BLOCK_SIZE_N':256,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':128,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':64,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M':64,'BLOCK_SIZE_N':128,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':32,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M':64,'BLOCK_SIZE_N':32,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=5,
num_warps=2),
triton.Config({'BLOCK_SIZE_M':32,'BLOCK_SIZE_N':64,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=5,
num_warps=2),
],
key=['M','N','K'],#自動調優(yōu)關鍵字
)
@triton.jit
defmatmul_kernel(
#指向矩陣的指針
a_ptr,b_ptr,c_ptr,
#矩陣維度
M,N,K,
#步長變量表示在特定維度上移動1個元素時指針增加的量。
#例如`stride_am`是將`a_ptr`增加多少以獲取下一行的元素(A有M行)。
stride_am,stride_ak,#A矩陣的步長
stride_bk,stride_bn,#B矩陣的步長
stride_cm,stride_cn,#C矩陣的步長
#元參數
BLOCK_SIZE_M:tl.constexpr,BLOCK_SIZE_N:tl.constexpr,BLOCK_SIZE_K:tl.constexpr,#
GROUP_SIZE_M:tl.constexpr,#
ACTIVATION:tl.constexpr#激活函數
):
"""用于計算矩陣乘法C=AxB的內核。
A的形狀為(M,K),B的形狀為(K,N),C的形狀為(M,N)。
"""
#-----------------------------------------------------------
#將程序ID`pid`映射到它應該計算的C矩陣的塊。
#這是以groupedordering完成的,以促進L2數據重用。
#詳細解釋看一節(jié)
pid=tl.program_id(axis=0)
num_pid_m=tl.cdiv(M,BLOCK_SIZE_M)
num_pid_n=tl.cdiv(N,BLOCK_SIZE_N)
num_pid_in_group=GROUP_SIZE_M*num_pid_n
group_id=pid//num_pid_in_group
first_pid_m=group_id*GROUP_SIZE_M
group_size_m=min(num_pid_m-first_pid_m,GROUP_SIZE_M)
pid_m=first_pid_m+(pid%group_size_m)
pid_n=(pid%num_pid_in_group)//group_size_m

#----------------------------------------------------------
#為A和B的第一個塊創(chuàng)建指針。
#我們將在K方向移動時推進這個指針并累加
#`a_ptrs`是[BLOCK_SIZE_M,BLOCK_SIZE_K]塊的指針
#`b_ptrs`是[BLOCK_SIZE_K,BLOCK_SIZE_N]塊的指針
#有關詳細信息,請參閱上方“指針算術”部分
offs_am=(pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M))%M
offs_bn=(pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N))%N
offs_k=tl.arange(0,BLOCK_SIZE_K)
a_ptrs=a_ptr+(offs_am[:,None]*stride_am+offs_k[None,:]*stride_ak)
b_ptrs=b_ptr+(offs_k[:,None]*stride_bk+offs_bn[None,:]*stride_bn)

#-----------------------------------------------------------
#迭代以計算C矩陣的一個塊。
#我們將累加到一個`[BLOCK_SIZE_M,BLOCK_SIZE_N]`塊
#的fp32值以獲得更高的精度。
#`accumulator`在循環(huán)后會轉換回fp16。
accumulator=tl.zeros((BLOCK_SIZE_M,BLOCK_SIZE_N),dtype=tl.float32)
forkinrange(0,tl.cdiv(K,BLOCK_SIZE_K)):
#LoadthenextblockofAandB,generateamaskbycheckingtheKdimension.
#Ifitisoutofbounds,setitto0.
a=tl.load(a_ptrs,mask=offs_k[None,:]=0,x,0.01*x)

我們現在可以創(chuàng)建一個方便的封裝函數,它只需要兩個輸入張量,并且會:(1)檢查任何形狀約束;(2)分配輸出;(3)啟動上述kernel。

defmatmul(a,b,activation=""):
#Checkconstraints.
asserta.shape[1]==b.shape[0],"Incompatibledimensions"
asserta.is_contiguous(),"MatrixAmustbecontiguous"
assertb.is_contiguous(),"MatrixBmustbecontiguous"
M,K=a.shape
K,N=b.shape
#Allocatesoutput.
c=torch.empty((M,N),device=a.device,dtype=a.dtype)
#1Dlaunchkernelwhereeachblockgetsitsownprogram.
grid=lambdaMETA:(triton.cdiv(M,META['BLOCK_SIZE_M'])*triton.cdiv(N,META['BLOCK_SIZE_N']),)
matmul_kernel[grid](
a,b,c,#
M,N,K,#
a.stride(0),a.stride(1),#
b.stride(0),b.stride(1),#
c.stride(0),c.stride(1),#
ACTIVATION=activation#
)
returnc

計算過程的補充說明

上面的《L2 Cache優(yōu)化原理補充講解》這一節(jié)明確了kernel的group ordering的訪問方式以及實現,現在來看對于當前的program實例具體是怎么計算的。現在以計算C中的第一個Block的(0, 0)為例子,它需要從A和B分別加載9個黃色的小塊數據相乘并累加最后得到C中的(0, 0)位置結果。如下圖所示:

154fc970-b926-11ee-8b88-92fbcf53809c.png

下面的代碼先把program實例當前要處理A和B的第一個Block加載上來:

#----------------------------------------------------------
#為A和B的第一個塊創(chuàng)建指針。
#我們將在K方向移動時推進這個指針并累加
#`a_ptrs`是[BLOCK_SIZE_M,BLOCK_SIZE_K]塊的指針
#`b_ptrs`是[BLOCK_SIZE_K,BLOCK_SIZE_N]塊的指針
#有關詳細信息,請參閱上方“指針算術”部分
offs_am=(pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M))%M
offs_bn=(pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N))%N
offs_k=tl.arange(0,BLOCK_SIZE_K)
a_ptrs=a_ptr+(offs_am[:,None]*stride_am+offs_k[None,:]*stride_ak)
b_ptrs=b_ptr+(offs_k[:,None]*stride_bk+offs_bn[None,:]*stride_bn)

這里的a_ptr 是整個 A 矩陣第一個元素的地址,offs_am和offs_bn表示當前的program id在M維度和K維度的坐標,這個坐標是一個list,用tl.arange(0, BLOCK_SIZE_K)來獲取。

得到 M 維度 和 K 維度的坐標后, 就可以讓它們各自和 M 維度 和 K 維度的 stride 相乘, 然后和 a_ptr 相加, 就可以得到 A 矩陣 9 個 block 中第一個 block 中每個元素的地址了。 b_ptr也是同理。

最后一部分就是累加了,這里會在K維度上進行累加,每次計算輸出的一個塊。

#迭代以計算C矩陣的一個塊。
#我們將累加到一個`[BLOCK_SIZE_M,BLOCK_SIZE_N]`塊
#的fp32值以獲得更高的精度。
#`accumulator`在循環(huán)后會轉換回fp16。
accumulator=tl.zeros((BLOCK_SIZE_M,BLOCK_SIZE_N),dtype=tl.float32)
forkinrange(0,tl.cdiv(K,BLOCK_SIZE_K)):
#LoadthenextblockofAandB,generateamaskbycheckingtheKdimension.
#Ifitisoutofbounds,setitto0.
a=tl.load(a_ptrs,mask=offs_k[None,:]

這行代碼a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)考慮到 K 可能不能被 BLOCK_SIZE_K 整除, 到每一行最后一個 block 的時候, 實際大小是不足 BLOCK_SIZE_K 的,所以需要把超出的那部分元素mask掉。

最后這部分代碼是把當前的算子和LeakyReLU激活函數進行融合:

#當累加器仍然是FP32時,可以融合任意激活函數
ifACTIVATION=="leaky_relu":
accumulator=leaky_relu(accumulator)
c=accumulator.to(tl.float16)

單元測試

155dcdb8-b926-11ee-8b88-92fbcf53809c.png

Benchmark

這里使用一個方陣來對比Triton實現的matmul kernel和cublas的matmul kernel的性能。

@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M','N','K'],#用作圖表x軸的參數名
x_vals=[128*iforiinrange(2,33)],#`x_name`的不同可能值
line_arg='provider',#其值對應于圖表中不同線條的參數名
#`line_arg`的可能值
line_vals=['cublas','triton'],
#線條的標簽名稱
line_names=["cuBLAS","Triton"],
#線條樣式
styles=[('green','-'),('blue','-')],
ylabel="TFLOPS",#y軸的標簽名稱
plot_name="matmul-performance",#圖表的名稱,也用作保存圖表的文件名。
args={},#其他參數
))
defbenchmark(M,N,K,provider):
#初始化張量
a=torch.randn((M,K),device='cuda',dtype=torch.float16)
b=torch.randn((K,N),device='cuda',dtype=torch.float16)
quantiles=[0.5,0.2,0.8]#分位數
#如果提供者是cublas
ifprovider=='cublas':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:torch.matmul(a,b),quantiles=quantiles)
#如果提供者是triton
ifprovider=='triton':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:matmul(a,b),quantiles=quantiles)
#性能計算函數
perf=lambdams:2*M*N*K*1e-12/(ms*1e-3)
returnperf(ms),perf(max_ms),perf(min_ms)

#運行基準測試,展示圖表和打印數據
benchmark.run(show_plots=True,print_data=True)

15816296-b926-11ee-8b88-92fbcf53809c.png

可以看到基于Triton實現的矩陣乘kernel性能大體可以和高度優(yōu)化的cuBlas持平。





審核編輯:劉清

聲明:本文內容及配圖由入駐作者撰寫或者入駐合作網站授權轉載。文章觀點僅代表作者本人,不代表電子發(fā)燒友網立場。文章及其配圖僅供工程師學習之用,如有內容侵權或者其他違規(guī)問題,請聯系本站處理。 舉報投訴
  • sram
    +關注

    關注

    6

    文章

    757

    瀏覽量

    114450
  • 多處理器
    +關注

    關注

    0

    文章

    22

    瀏覽量

    8895
  • Cache
    +關注

    關注

    0

    文章

    128

    瀏覽量

    28188
  • python
    +關注

    關注

    53

    文章

    4753

    瀏覽量

    84080
  • OpenAI
    +關注

    關注

    9

    文章

    988

    瀏覽量

    6252

原文標題:【BBuf的CUDA筆記】十三,OpenAI Triton 入門筆記一

文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關注!文章轉載請注明出處。

收藏 人收藏

    評論

    相關推薦

    嵌入式linux入門筆記

    嵌入式linux入門筆記
    發(fā)表于 08-13 16:06

    嵌入式linux入門筆記

    嵌入式linux入門筆記
    發(fā)表于 08-20 20:53

    嵌入式入門筆記

    嵌入式入門筆記。。。
    發(fā)表于 10-31 23:13

    嵌入式入門筆記

    嵌入式入門筆記 ,初學者可以學習下,很值得借鑒
    發(fā)表于 05-21 12:54

    圖書推薦:IAR-for-AVR入門學習筆記

    IAR-for-AVR入門學習筆記
    發(fā)表于 06-12 13:46

    求CSS入門的學習筆記

    CSS入門 學習筆記4
    發(fā)表于 06-04 15:15

    什么是CUDA

    的時間盡可能清晰的了解這個深度學習賴以實現的基礎概念。本文在以下資料的基礎上整理完成,感謝以下前輩提供的資料:CUDA——“從入門到放棄”我的CUDA學習之旅——啟程介紹篇不錯的
    發(fā)表于 07-26 06:28

    筆記本如何與投影機鏈接入門應用小技巧

    筆記本如何與投影機鏈接入門應用小技巧 、投影機連接筆記本電腦,無輸出影像?   答:筆記本電腦外接
    發(fā)表于 01-18 09:50 ?551次閱讀

    嵌入式入門筆記

    本文提供了嵌入式入門筆記,希望對你的學習有所幫助!
    發(fā)表于 06-07 16:57 ?0次下載
    嵌入式<b class='flag-5'>入門</b><b class='flag-5'>筆記</b>

    英飛凌MCU新手入門應用筆記(中文版)

    英飛凌MCU新手入門應用筆記(中文版)
    發(fā)表于 06-25 12:04 ?0次下載
    英飛凌MCU新手<b class='flag-5'>入門</b>應用<b class='flag-5'>筆記</b>(中文版)

    ARM入門調試筆記

    ARM入門調試筆記
    發(fā)表于 10-13 14:26 ?11次下載
    ARM<b class='flag-5'>入門</b>調試<b class='flag-5'>筆記</b>

    CUDA學習筆記篇:個基本的CUDA C程序

    1、CUDA的簡介 2、GPU架構和CUDA介紹3、CUDA架構4、開發(fā)環(huán)境說明和配置5、開始第個Hello CUDA程序????5.1、
    的頭像 發(fā)表于 12-14 23:40 ?807次閱讀

    Xilinx_Vivado_zynq7000入門筆記

    Xilinx_Vivado_zynq7000入門筆記說明。
    發(fā)表于 04-08 11:48 ?71次下載

    RT-Thread Nano入門學習筆記

    RT-Thread Nano入門學習筆記
    發(fā)表于 11-26 12:36 ?20次下載
    RT-Thread Nano<b class='flag-5'>入門</b>學習<b class='flag-5'>筆記</b>

    入門級微波電路(MMIC)的筆記-S 參數

    入門級微波電路(MMIC)的上課筆記-S 參數
    的頭像 發(fā)表于 07-05 10:13 ?584次閱讀
    <b class='flag-5'>入門</b>級微波電路(MMIC)的<b class='flag-5'>筆記</b>-S 參數