Author: 賈世聞, 京東科技 (JD Technology)
RAG (Retrieval-Augmented Generation) plays a crucial role in the AI ecosystem, particularly in improving the accuracy and broadening the applicability of large language models (LLMs). RAG combines retrieval with LLM prompting: relevant information is retrieved from various data sources and combined with the user's question to generate accurate and informative answers. This mechanism is especially well suited to scenarios where information keeps changing, because the parametric knowledge a large model relies on is inherently static.
The strength of RAG is that it can draw on external knowledge bases and reference large amounts of information, producing deeper, more accurate, and more valuable answers and improving the reliability of the generated text. In addition, the retrieval store can be updated on its own, so knowledge can be refreshed immediately without retraining the model, which is an advantage in applications with strict freshness requirements.
Building a RAG system today is not a particularly difficult task. With a mature framework such as LangChain, a demo can be put together in a hundred or so lines of code. So can the current Rust ecosystem be used to build a simple RAG? Let's find out. This post walks through building a RAG with Rust.
Building the Knowledge Base
The knowledge base is essentially a model plus a vector store. To keep every component in the system built in Rust, we choose Qdrant, a vector database written in pure Rust, as the vector store.
The most important step in building the knowledge base is the embedding process.
The process is as follows:
Load the model
Tokenize the text
Obtain the text embedding from the model
The following sections describe each step and its implementation in detail.
Model Loading
The following code loads the model and the tokenizer:
async fn build_model_and_tokenizer(model_config: &ConfigModel) -> Result<(BertModel, Tokenizer)> {
    let device = Device::new_cuda(0)?;
    let repo = Repo::with_revision(
        model_config.model_id.clone(),
        RepoType::Model,
        model_config.revision.clone(),
    );
    let (config_filename, tokenizer_filename, weights_filename) = {
        let api = ApiBuilder::new().build()?;
        let api = api.repo(repo);
        let config = api.get("config.json").await?;
        let tokenizer = api.get("tokenizer.json").await?;
        let weights = if model_config.use_pth {
            api.get("pytorch_model.bin").await?
        } else {
            api.get("model.safetensors").await?
        };
        (config, tokenizer, weights)
    };
    let config = std::fs::read_to_string(config_filename)?;
    let mut config: Config = serde_json::from_str(&config)?;
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let vb = if model_config.use_pth {
        VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
    } else {
        unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
    };
    if model_config.approximate_gelu {
        config.hidden_act = HiddenAct::GeluApproximate;
    }
    let model = BertModel::load(vb, &config)?;
    Ok((model, tokenizer))
}
The model and tokenizer are used frequently throughout the system, so to avoid repeated loading they are held in a static global built with OnceCell:
pub static GLOBAL_EMBEDDING_MODEL: OnceCell<Arc<(BertModel, Tokenizer)>> = OnceCell::const_new();

pub async fn init_model_and_tokenizer() -> Arc<(BertModel, Tokenizer)> {
    let config = get_config().unwrap();
    let (m, t) = build_model_and_tokenizer(&config.model).await.unwrap();
    Arc::new((m, t))
}
The model is loaded when the system starts up:
GLOBAL_RUNTIME.block_on(async {
    log::info!("global runtime start!");
    // load the embedding model and tokenizer
    GLOBAL_EMBEDDING_MODEL
        .get_or_init(init_model_and_tokenizer)
        .await;
});
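GLOBAL_RUNTIME is referenced above but its definition is not shown in the post. A minimal sketch, assuming a lazily initialized multi-threaded Tokio runtime held in a once_cell Lazy (the name and setup are illustrative, not from the original project):

use once_cell::sync::Lazy;
use tokio::runtime::{Builder, Runtime};

// Hypothetical definition of the global runtime used at startup.
pub static GLOBAL_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
    Builder::new_multi_thread()
        .enable_all()
        .build()
        .expect("failed to build the global tokio runtime")
});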
The embedding step is implemented mainly by the following function:
pub async fn embedding_setence(content: &str) -> Result<Vec<Vec<f32>>> {
    let m_t = GLOBAL_EMBEDDING_MODEL.get().unwrap();
    let tokens = m_t
        .1
        .encode(content, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let token_ids = Tensor::new(&tokens[..], &m_t.0.device)?.unsqueeze(0)?;
    let token_type_ids = token_ids.zeros_like()?;
    let sequence_output = m_t.0.forward(&token_ids, &token_type_ids)?;
    let (_n_sentence, n_tokens, _hidden_size) = sequence_output.dims3()?;
    let embeddings = (sequence_output.sum(1)? / (n_tokens as f64))?;
    let embeddings = normalize_l2(&embeddings)?;
    let encodings = embeddings.to_vec2::<f32>()?;
    Ok(encodings)
}
The function encodes the input text with the tokenizer, runs the token ids through the model to obtain a three-dimensional tensor, averages it over the token dimension, and finally L2-normalizes the resulting tensor.
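A quick usage sketch (not from the original post): for m3e-large the returned matrix is expected to hold one 1024-dimensional row per input sentence, matching the collection dimension created later.

// Illustrative call; assumes it runs in an async context returning anyhow::Result
// and that the global model and tokenizer have already been initialized.
let vectors = embedding_setence("如何遷移云主機(jī)?").await?;
println!("sentences: {}, dim: {}", vectors.len(), vectors[0].len()); // expected: 1, 1024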
Loading Data into the Vector Store
Building the knowledge base means vectorizing the texts to be retrieved and storing them in the vector database.
This time JD Cloud product documentation is used as the source text, preprocessed into the following format. The preprocessing itself is not described here.
{ "content": "# 服務(wù)計(jì)費(fèi)nn主機(jī)遷移服務(wù)自身為免費(fèi)服務(wù),但是遷移目標(biāo)為云主機(jī)鏡像時(shí),遷移過程依賴系統(tǒng)自動(dòng)創(chuàng)建的 中轉(zhuǎn)資源的配合,這些資源中涉及部分付費(fèi)資源,會(huì)產(chǎn)生相應(yīng)費(fèi)用。nn遷移過程涉及的中轉(zhuǎn)資付費(fèi)資源配置及計(jì)費(fèi)說明如下(單個(gè)遷移任務(wù)):nn| | 云主機(jī) | 云硬盤 | 彈性公網(wǎng)IP |n| --- | --- | --- | ------ |n| 計(jì)費(fèi)類型 | 按配置 | 按配置 | 按用量 |n| 規(guī)格配置 | 2C4G (c.n2.large或c.n3.large或c.n1.large) | 系統(tǒng)盤:40G 通用型SSD 數(shù)據(jù)盤:通用型SSD,數(shù)量及容量取決于源服務(wù)器系統(tǒng)盤及數(shù)據(jù)盤情況 | 30Mbps |n| 費(fèi)用預(yù)估 | 云主機(jī)規(guī)格每小時(shí)價(jià)格\\*遷移時(shí)長 | 云硬盤規(guī)格每小時(shí)價(jià)格\\*遷移時(shí)長 | 彈性公網(wǎng)IP每小時(shí)保有費(fèi)\\*遷移時(shí)長 僅使用彈性公網(wǎng)IP入方向流量,只涉及IP保有用,不涉及流量費(fèi)用 |nn> 提示:n>n> * 遷移時(shí)長取決于源服務(wù)器遷數(shù)據(jù)量以及源服務(wù)器公網(wǎng)出方向帶寬,公網(wǎng)連接順暢且源服務(wù)器公網(wǎng)出方向帶寬不低于22.5Mbps的情況下(主機(jī)遷移為單線程傳輸,京東云云主機(jī)在單流傳輸下實(shí)際帶寬為帶寬上限的75%左右),實(shí)際數(shù)據(jù)容量為5GB的磁盤遷移時(shí)長在30分鐘左右。n> * 中轉(zhuǎn)實(shí)例實(shí)例綁定的安全組出方向默認(rèn)拒絕所有流量,因此默認(rèn)情況下降不會(huì)產(chǎn)生任何公網(wǎng)出方向收費(fèi)流量,但此配置也影響了云主機(jī)部分監(jiān)控指標(biāo)的上報(bào),如需要監(jiān)控中轉(zhuǎn)實(shí)例的全部監(jiān)控?cái)?shù)據(jù),可自行調(diào)整安全組規(guī)則方向出方向443端口。", "title": "服務(wù)計(jì)費(fèi)說明", "product": "云主機(jī) CVM", "url": "https://docs.jdcloud.com/cn/virtual-machines/server-migration-service/billing" }
The complete ingestion code is as follows:
use anyhow::Error as E;
use anyhow::Result;
use candle_core::Device;
use candle_core::Tensor;
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
use hf_hub::{api::tokio::Api, Repo, RepoType};
use qdrant_client::qdrant::CollectionExistsRequest;
use qdrant_client::qdrant::CreateCollectionBuilder;
use qdrant_client::qdrant::DeleteCollection;
use qdrant_client::qdrant::Distance;
use qdrant_client::qdrant::UpsertPointsBuilder;
use qdrant_client::qdrant::VectorParamsBuilder;
use qdrant_client::Payload;
use qdrant_client::{
    qdrant::{
        CollectionOperationResponse, CreateCollection, PointStruct, PointsOperationResponse,
        UpsertPoints,
    },
    Qdrant,
};
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use std::fs;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::OnceCell;
use uuid::Uuid;
use walkdir::WalkDir;

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Doc {
    pub content: String,
    pub title: String,
    pub product: String,
    pub url: String,
}

#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
pub struct ModelConfig {
    #[serde(default = "ModelConfig::model_id_default")]
    pub model_id: String,
    #[serde(default = "ModelConfig::revision_default")]
    pub revision: String,
    #[serde(default = "ModelConfig::use_pth_default")]
    pub use_pth: bool,
    #[serde(default = "ModelConfig::approximate_gelu_default")]
    pub approximate_gelu: bool,
}

impl Default for ModelConfig {
    fn default() -> Self {
        Self {
            model_id: Self::model_id_default(),
            revision: Self::revision_default(),
            use_pth: Self::use_pth_default(),
            approximate_gelu: Self::approximate_gelu_default(),
        }
    }
}

impl ModelConfig {
    fn model_id_default() -> String {
        "moka-ai/m3e-large".to_string()
    }
    fn revision_default() -> String {
        "main".to_string()
    }
    fn use_pth_default() -> bool {
        true
    }
    fn approximate_gelu_default() -> bool {
        false
    }
}

pub static GLOBAL_MODEL: OnceCell<Arc<BertModel>> = OnceCell::const_new();
pub static GLOBAL_TOKEN: OnceCell<Arc<Tokenizer>> = OnceCell::const_new();

pub async fn init_model() -> Arc<BertModel> {
    let config = ModelConfig::default();
    let (m, _) = build_model_and_tokenizer(&config).await.unwrap();
    Arc::new(m)
}

pub async fn init_tokenizer() -> Arc<Tokenizer> {
    let config = ModelConfig::default();
    let (_, t) = build_model_and_tokenizer(&config).await.unwrap();
    Arc::new(t)
}

async fn build_model_and_tokenizer(model_config: &ModelConfig) -> Result<(BertModel, Tokenizer)> {
    let device = Device::new_cuda(0)?;
    let repo = Repo::with_revision(
        model_config.model_id.clone(),
        RepoType::Model,
        model_config.revision.clone(),
    );
    let (config_filename, tokenizer_filename, weights_filename) = {
        let api = Api::new()?;
        let api = api.repo(repo);
        let config = api.get("config.json").await?;
        let tokenizer = api.get("tokenizer.json").await?;
        let weights = if model_config.use_pth {
            api.get("pytorch_model.bin").await?
        } else {
            api.get("model.safetensors").await?
        };
        (config, tokenizer, weights)
    };
    let config = std::fs::read_to_string(config_filename)?;
    let mut config: Config = serde_json::from_str(&config)?;
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let vb = if model_config.use_pth {
        VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
    } else {
        unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
    };
    if model_config.approximate_gelu {
        config.hidden_act = HiddenAct::GeluApproximate;
    }
    let model = BertModel::load(vb, &config)?;
    Ok((model, tokenizer))
}

pub async fn embedding_setence(content: &str) -> Result<Vec<Vec<f32>>> {
    let m = GLOBAL_MODEL.get().unwrap();
    let t = GLOBAL_TOKEN.get().unwrap();
    let tokens = t.encode(content, true).map_err(E::msg)?.get_ids().to_vec();
    let token_ids = Tensor::new(&tokens[..], &m.device)?.unsqueeze(0)?;
    let token_type_ids = token_ids.zeros_like()?;
    let sequence_output = m.forward(&token_ids, &token_type_ids)?;
    let (_n_sentence, n_tokens, _hidden_size) = sequence_output.dims3()?;
    let embeddings = (sequence_output.sum(1).unwrap() / (n_tokens as f64))?;
    let embeddings = normalize_l2(&embeddings).unwrap();
    let encodings = embeddings.to_vec2::<f32>()?;
    Ok(encodings)
}

pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}

pub struct QdrantClient {
    client: Qdrant,
}

impl QdrantClient {
    pub async fn create_collection(
        &self,
        request: impl Into<CreateCollection>,
    ) -> Result<CollectionOperationResponse> {
        let resp = self.client.create_collection(request).await?;
        Ok(resp)
    }

    pub async fn delete_collection(
        &self,
        request: impl Into<DeleteCollection>,
    ) -> Result<CollectionOperationResponse> {
        let resp = self.client.delete_collection(request).await?;
        Ok(resp)
    }

    pub async fn collection_exists(
        &self,
        request: impl Into<CollectionExistsRequest>,
    ) -> Result<bool> {
        let resp = self.client.collection_exists(request).await?;
        Ok(resp)
    }

    pub async fn load_dir(&self, path: &str, collection_name: &str) {
        let mut points = vec![];
        for entry in WalkDir::new(path)
            .into_iter()
            .filter_map(Result::ok)
            .filter(|e| !e.file_type().is_dir() && e.file_name().to_str().is_some())
        {
            if let Some(p) = entry.path().to_str() {
                let id = Uuid::new_v4();
                let content = match fs::read_to_string(p) {
                    Ok(c) => c,
                    Err(_) => continue,
                };
                let doc = match from_str::<Doc>(content.as_str()) {
                    Ok(d) => d,
                    Err(_) => continue,
                };
                let mut payload = Payload::new();
                payload.insert("content", doc.content);
                payload.insert("title", doc.title);
                payload.insert("product", doc.product);
                payload.insert("url", doc.url);
                let vector_contens = embedding_setence(content.as_str()).await.unwrap();
                let ps = PointStruct::new(id.to_string(), vector_contens[0].clone(), payload);
                points.push(ps);
                if points.len().eq(&100) {
                    let p = points.clone();
                    self.client
                        .upsert_points(UpsertPointsBuilder::new(collection_name, p).wait(true))
                        .await
                        .unwrap();
                    points.clear();
                    println!("batch finish");
                }
            }
        }
        if points.len().gt(&0) {
            self.client
                .upsert_points(UpsertPointsBuilder::new(collection_name, points).wait(true))
                .await
                .unwrap();
        }
    }
}

#[tokio::main]
async fn main() {
    // load the embedding model and tokenizer
    GLOBAL_MODEL.get_or_init(init_model).await;
    GLOBAL_TOKEN.get_or_init(init_tokenizer).await;

    let collection_name = "default_collection";
    // The Rust client uses Qdrant's GRPC interface
    let qdrant = Qdrant::from_url("http://localhost:6334").build().unwrap();
    let qdrant_client = QdrantClient { client: qdrant };

    if !qdrant_client
        .collection_exists(collection_name)
        .await
        .unwrap()
    {
        qdrant_client
            .create_collection(
                CreateCollectionBuilder::new(collection_name)
                    .vectors_config(VectorParamsBuilder::new(1024, Distance::Dot)),
            )
            .await
            .unwrap();
    }

    qdrant_client
        .load_dir("/root/jd_docs", collection_name)
        .await;

    println!("{:?}", qdrant_client.client.health_check().await);
}
The tasks accomplished by the code above are as follows:
Load the embedding model and tokenizer into global statics at startup.
Connect to Qdrant over its gRPC interface and create the collection (1024-dimensional vectors, dot-product distance) if it does not already exist.
Walk the document directory, parse each JSON file into a Doc, embed its content, and attach content, title, product, and url as the payload.
Upsert the resulting points into Qdrant in batches of 100.
Inference Service
The inference service uses mistral.rs, an inference server built in Rust.
Since accessing Hugging Face from inside China is inconvenient, the model is first downloaded locally through the mirror https://hf-mirror.com/. The Qwen model is used this time:
HF_ENDPOINT="https://hf-mirror.com" huggingface-cli download --repo-type model --resume-download Qwen/Qwen2-7B --local-dir /root/Qwen2-7B
Start mistralrs-server:
git clone https://github.com/EricLBuehler/mistral.rs
cd mistral.rs
cargo run --bin mistralrs-server --features cuda -- --port 3333 plain -m /root/Qwen2-7B -a qwen2
Calling the Inference Service
mistral.rs exposes an OpenAI-compatible API, so it can be called with openai-api-rs. Inference takes a fairly long time, so the timeout should be set generously; if it is too short, the request may be forcibly timed out before the result comes back.
pub static GLOBAL_OPENAI_CLIENT: Lazy<Arc<OpenAIClient>> = Lazy::new(|| {
    let mut client =
        OpenAIClient::new_with_endpoint("http://10.0.0.7:3333/v1".to_string(), "EMPTY".to_string());
    client.timeout = Some(30);
    Arc::new(client)
});

pub async fn inference(content: &str, max_len: i64) -> Result<Option<String>> {
    let req = ChatCompletionRequest::new(
        "".to_string(),
        vec![chat_completion::ChatCompletionMessage {
            role: chat_completion::MessageRole::user,
            content: chat_completion::Content::Text(content.to_string()),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        }],
    )
    .max_tokens(max_len);
    let cr = GLOBAL_OPENAI_CLIENT.chat_completion(req).await?;
    Ok(cr.choices[0].message.content.clone())
}
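A usage sketch (illustrative, not from the original post; the question text and token limit are made up):

// Hypothetical call site, inside an async context returning anyhow::Result.
if let Some(reply) = inference("京東云主機(jī)如何計(jì)費(fèi)?", 1024).await? {
    println!("{}", reply);
}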
Integrating the Retriever with the Inference Service
pub async fn answer(question: &str, max_len: i64) -> Result<Option<String>> {
    let retriver = retriever(question, 1).await?;
    let mut context = "".to_string();
    for sp in retriver.result {
        let payload = sp.payload;
        let product = payload.get("product").unwrap().to_string();
        let title = payload.get("title").unwrap().to_string();
        let content = payload.get("content").unwrap().to_string();
        context.push_str(&product);
        context.push_str(&title);
        context.push_str(&content);
    }
    let prompt = format!(
        "你是一個(gè)云技術(shù)專家, 使用以下檢索到的Context回答問題。用中文回答問題。
        Question: {}
        Context: {}
        ",
        question, context
    );
    log::info!("{}", prompt);
    let req = ChatCompletionRequest::new(
        "".to_string(),
        vec![chat_completion::ChatCompletionMessage {
            role: chat_completion::MessageRole::user,
            content: chat_completion::Content::Text(prompt),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        }],
    )
    .max_tokens(max_len);
    let cr = GLOBAL_OPENAI_CLIENT.chat_completion(req).await?;
    Ok(cr.choices[0].message.content.clone())
}
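The retriever function called above is not shown in the post. A minimal sketch of what it might look like, assuming the qdrant-client 1.x SearchPointsBuilder API, the default_collection created earlier, and a global Qdrant client (GLOBAL_QDRANT_CLIENT is a hypothetical name):

use anyhow::Result;
use qdrant_client::qdrant::{SearchPointsBuilder, SearchResponse};

pub async fn retriever(question: &str, limit: u64) -> Result<SearchResponse> {
    // Embed the question with the same model used for the documents.
    let vectors = embedding_setence(question).await?;
    // Query Qdrant for the nearest documents, returning payloads for prompt assembly.
    let resp = GLOBAL_QDRANT_CLIENT
        .search_points(
            SearchPointsBuilder::new("default_collection", vectors[0].clone(), limit)
                .with_payload(true),
        )
        .await?;
    Ok(resp)
}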
后記
完整工程地址[embedding_server]https://github.com/jiashiwen/embedding_server
后續(xù)工程問題,多卡推理,多機(jī)推理,推理加速
資源對(duì)比
GPU 型號(hào)
|=========================================+========================+======================|
|   0  NVIDIA A30                     Off | 00000000:00:07.0   Off |                    0 |
| N/A   30C    P0             29W / 165W  |      0MiB / 24576MiB   |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
Embedding resource usage
m3e-large
vllm
+-----------------------------------------------------------------------------------------+
| Processes:                                                                               |
|  GPU   GI   CI        PID   Type   Process name                             GPU Memory  |
|        ID   ID                                                              Usage       |
|=========================================================================================|
|    0   N/A  N/A    822789      C   ...iprojects/rag_demo/.venv/bin/python      1550MiB  |
+-----------------------------------------------------------------------------------------+
candle
+-----------------------------------------------------------------------------------------+
| Processes:                                                                               |
|  GPU   GI   CI        PID   Type   Process name                             GPU Memory  |
|        ID   ID                                                              Usage       |
|=========================================================================================|
|    0   N/A  N/A    823261      C   target/debug/embedding_server               1484MiB  |
+-----------------------------------------------------------------------------------------+
Inference resource usage
Qwen1.5-1.8B-Chat
vllm
|=========================================================================================|
|    0   N/A  N/A    822437      C   /usr/bin/python3                           20440MiB  |
+-----------------------------------------------------------------------------------------+
mistral.rs
|=========================================================================================|
|    0   N/A  N/A    822174      C   target/debug/mistralrs-server              22134MiB  |
+-----------------------------------------------------------------------------------------+
Qwen2-7B
vllm: out of GPU memory
[rank0]: OutOfMemoryError: CUDA out of memory. Tried to allocate 9.25 GiB. GPU
mistral.rs
|=========================================================================================|
|    0   N/A  N/A    656923      C   target/debug/mistralrs-server              22006MiB  |
+-----------------------------------------------------------------------------------------+
From these measurements, the Rust candle framework uses slightly less GPU memory for the embedding model, while vllm uses slightly less for the Qwen1.5-1.8B-Chat inference model. With Qwen2-7B, vllm simply runs out of GPU memory.
Pitfalls
Most frameworks use hf-hub through synchronous calls, and it does not support a mirror inside China, so a hands-on modification is needed in:
src/api/tokio.rs
impl ApiBuilder {
    /// Set endpoint example 'https://hf-mirror.com'
    pub fn with_endpoint(mut self, endpoint: &str) -> Self {
        self.endpoint = endpoint.to_string();
        self
    }
}
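With that patch in place, the mirror can be selected when constructing the async API client, for example (a sketch; with_endpoint is the method added above, and the endpoint value is the mirror mentioned earlier):

use hf_hub::api::tokio::ApiBuilder;

// Build an hf-hub API client that downloads model files through the domestic mirror.
let api = ApiBuilder::new()
    .with_endpoint("https://hf-mirror.com")
    .build()?;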