LangChain.js + OpenAI API(GPT-4)で、ベクターストアを使わずに擬似RAGを作ってみた
この記事は「LangChain Advent Calendar 2023」25日目の記事です。ベクターストア無しでRAG(retrieval augmented generation)を作成する方法について紹介しています。具体的には、Embeddingを使わずにベクトルデータを生成し、WordPressのAPIを用いて投稿を検索する方法を解説しています。また、RemoteLangChainRetrieverをカスタマイ
目次
この記事は「LangChain Advent Calendar 2023」25日目の記事です。最終日ですので、かなりマニアックなものを投入してみようかなと思います。
ベクターストア無しでRAGは作れるのか?
質問に対する回答を、LLMが学習していないデータから作るRAG(retrieval augmented generation)。これを作るには、一般的にはEmbeddingを利用したベクトルデータの生成と、それを検索可能な形で保存するベクターストアが必要です。
個人的な興味から、このベクターストアを持たずにRAGまたはそれに準ずるものが作れないかに挑戦してみました。
前提: Chat GPT Plusに契約して、GPT-4を利用できるようにする
試してみたところ、ベクトルデータを使わない場合トークン数がかなり大きくなり、すぐに上限に引っかかります。そのため、ClaudeやGPT-4など比較的上限値が大きいモデルを使うことになります。
const model = new OpenAI({
openAIApiKey: process.env.openaiApiKey,
modelName: "gpt-4",
})
今回はGPT-4を利用しました。
RemoteLangChainRetriever
でベクトル検索の代わりになる仕組みを作る
ベクターストアは主に検索部分で利用します。今回はその代替手段として、RemoteLangChainRetriever
を採用しました。このRetriever
を利用することで、回答文生成で利用するDocumentを任意のPOST APIから取得できます。
ということで今回はWordPressのAPIで投稿を検索する形にしてみました。
const app = new Hono()
app.post('/retrieve', async c => {
const { query } = await c.req.json()
const response = await fetch(`https://YOUR_WP_SITE_URL/wp-json/wp/v2/posts?search=${query.output}&limit=2`)
const posts = await response.json()
return c.json({
data: posts.map((post: any) => ({
content: post.content.rendered,
metadata: {
id: post.id,
}
}))
})
})
トークン数対策で、TextSplitterを追加してみる
記事本文を送ると、すぐにトークン上限に到達してしまいました。コードを見る限り、RemoteLangChainRetriever
はTextSplitterを用いずにDocumentを作成していましたので、あまり大きいデータを扱うのには向いていなかったかもしれません。今回はそれでも使いたかったので、RemoteLangChainRetriever
をextends
して、Documentを作成する処理にTextSplitterを差し込む形でoverwriteしました。コードが@ts-expect-error
していますが、これはprocessJsonResponse
は本来非同期関数として定義されていないことが原因です。実行コードを見る限り、Promise
を使っても動きそうでしたので、強引に書き換えてみました。
class CustomRemoteLangChainRetriever extends RemoteLangChainRetriever {
// @ts-expect-error
async processJsonResponse(json: RemoteRetrieverValues): Promise<Document[]> {
const splitter = RecursiveCharacterTextSplitter.fromLanguage("html", {
chunkSize: 1000,
chunkOverlap: 0,
});
const transformer = new HtmlToTextTransformer();
const sequence = splitter.pipe(transformer);
const documents: Array<Document> = []
for await (const r of json[this.responseKey]) {
documents.push(new Document({
pageContent: r[this.pageContentKey],
metadata: r[this.metadataKey],
}))
}
const newDocuments = await sequence.invoke(documents);
return newDocuments
}
}
Retrieverを作ったので、呼び出すAPIのURLなどを指定します。
const retriever = new CustomRemoteLangChainRetriever({
url: "http://0.0.0.0:3000/remote",
auth: false, // Replace with your own auth.
inputKey: "query",
pageContentKey: "content",
responseKey: "data",
verbose: true
});
キーワード検索のため、検索キーワードを抽出するChainを作る
検索方法がEmbeddingを使ったベクトル検索ではなく、キーワードベースの検索に変わります。そのため、入力された質問文からキーワードを抽出する処理を作りましょう。
const reviewTemplate = `あなたは入力された文章から、検索キーワードを抽出する作業を行っています。入力されたメッセージから、キーワードをスペース区切りのテキストで出力してください。
出力するテキストには、「教えて」などの動詞は含めないでください。
input: {query}
{format_instructions}
`;
const outputParser = StructuredOutputParser.fromNamesAndDescriptions({
output: "Suggested search keyword",
})
const reviewPromptTemplate = new PromptTemplate({
template: reviewTemplate,
inputVariables: ["query"],
partialVariables: {
format_instructions: outputParser.getFormatInstructions()
},
outputParser: outputParser,
});
const queryChain = new LLMChain({
llm: model,
prompt: reviewPromptTemplate,
});
Chainにしておくことで、Retrieverとつなげやすくしておきます。
RetrievalQAChainで回答文を作るChainを追加する
先ほど作成したRetrieverを利用して回答文を作るステップを用意します。この辺りは通常のQAChainを作るケースとかなり似ています。
const template = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}\n\nQuestion: {question}\nHelpful Answer:";
const QA_CHAIN_PROMPT = new PromptTemplate({
inputVariables: ["context", "question"],
template,
});
const qaChain = new RetrievalQAChain({
combineDocumentsChain: loadQAStuffChain(model, { prompt: QA_CHAIN_PROMPT }),
retriever,
returnSourceDocuments: false
})
SimpleSequentialChainで順番にChainを実行する
今回は「検索ワードの抽出」後に「検索と回答文の生成」を行う必要があります。そのため、SimpleSequentialChain
を利用して順番にChainを実行するように構成しました。
最後に質問文を設定して実行します。
const overallChain = new SimpleSequentialChain({
chains: [queryChain, qaChain],
});
// Call the chain with a query.
const res = await overallChain.run( "momento SDKの使い方を教えて", {
callbacks: []
});
Chainを実行してみた結果
実行してみたところ、APIエラーなどが起きることなく回答文を得ることができました。ただし日本語での質問ですが、英語で回答が来ています。これはRetrievalQAChain
のプロンプトが英語だからかもしれません。
"The article discusses how to use the Momento SDK with Cloudflare Workers (Hono). It outlines details about how to introduce the Momento SDK into Cloudflare Workers, including the need to install a polyfill for XMLHttpRequest, how to read in the API key as a string, and the differences in the way the client instances are generated. The article also shares a structured code sample for a full-picture view. The writer goes on to discuss using Momento under Hono + Wrangler and mentions the Limitations of using the Node SDK. The latter part of the article delves into diving into the vector store function, creating an index from the GUI, obtaining the API key, the installation and set up of the SDK. It discusses using the SDK to create an index, upsert data and perform searches."
ちなみにGoogle翻訳にかけるとこんな文章になります。唐突にCloudflareが出てくるあたり、自分の記事を参考にして回答文を作った感があるなーと思います。この辺りは検索APIの精度に左右される部分も強いかもしれません。
「この記事では、Cloudflare Workers (Hono) で Momento SDK を使用する方法について説明します。この記事では、XMLHttpRequest のポリフィルをインストールする必要性、API キーを 文字列、およびクライアント インスタンスの生成方法の違いについても説明します。この記事では、全体像を示す構造化されたコード サンプルも共有しています。ライターは続けて Hono + Wrangler での Momento の使用について説明し、Node SDK の使用の制限についても言及しています 記事の後半では、ベクター ストア関数の詳細、GUI からのインデックスの作成、API キーの取得、SDK のインストールとセットアップについて詳しく説明しており、SDK を使用してインデックスを作成し、データを更新/挿入する方法について説明します。 検索を実行します。」
また、文章によっては回答が生成できないケースも見受けられました。
"申し訳ありませんが、質問が不適切な形式で提出されたため、正確な回答を提供できません。質問を明確にしていただけますと幸いです。どうもありがとう!"
やってみての感想
「検索まではなんとかなるが、文章生成がとにかくつらい」というのがやってみての感想です。記事の文章をLLMに平文で投げる形になるため、トークンの消費量が大きくなりがちで、GPT-3.5などではかなりの確率でAPIエラーが発生します。
結論としては「理論上は可能。ただし精度や安定性を求めるならば、おとなしくベクターストアを使おう」といったところでしょうか。逆に考えると、トークンさえ抑えることができれば、この方法も使い道が出てくるかもしれません。
全コード
参考までに、今回試したコード全文です。
import { Hono } from 'hono'
import { Bedrock } from "langchain/llms/bedrock";
import { OpenAI } from "langchain/llms/openai";
import { RetrievalQAChain, loadQAStuffChain } from "langchain/chains";
import { RemoteLangChainRetriever, RemoteRetrieverValues } from "langchain/retrievers/remote";
import { SimpleSequentialChain, LLMChain } from "langchain/chains";
import { PromptTemplate } from 'langchain/prompts';
import { StructuredOutputParser } from "langchain/output_parsers";
import { Document } from 'langchain/document';
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { HtmlToTextTransformer } from "langchain/document_transformers/html_to_text";
const app = new Hono()
app.post('/remote', async c => {
const { query } = await c.req.json()
console.log({query})
const response = await fetch(`https://YOUR_WP_SITE_URL/wp-json/wp/v2/posts?search=${query.output}&limit=2`)
const posts = await response.json()
return c.json({
data: posts.map((post: any) => ({
content: post.content.rendered,
metadata: {
id: post.id,
}
}))
})
})
class CustomRemoteLangChainRetriever extends RemoteLangChainRetriever {
// @ts-expect-error
async processJsonResponse(json: RemoteRetrieverValues): Promise<Document[]> {
const splitter = RecursiveCharacterTextSplitter.fromLanguage("html", {
chunkSize: 1000,
chunkOverlap: 0,
});
const transformer = new HtmlToTextTransformer();
const sequence = splitter.pipe(transformer);
const documents: Array<Document> = []
for await (const r of json[this.responseKey]) {
documents.push(new Document({
pageContent: r[this.pageContentKey],
metadata: r[this.metadataKey],
}))
}
const newDocuments = await sequence.invoke(documents);
return newDocuments
}
}
const bedrockApp = new Hono()
bedrockApp.get('/chat', async c => {
const model = new OpenAI({
openAIApiKey: env.openai.apiKey,
modelName: "gpt-4",
})
// Initialize the remote retriever.
const retriever = new CustomRemoteLangChainRetriever({
url: "http://0.0.0.0:3000/remote",
auth: false, // Replace with your own auth.
inputKey: "query",
pageContentKey: "content",
responseKey: "data",
verbose: true
});
// Create a chain that uses the OpenAI LLM and remote retriever.
const reviewTemplate = `あなたは入力された文章から、検索キーワードを抽出する作業を行っています。入力されたメッセージから、キーワードをスペース区切りのテキストで出力してください。
出力するテキストには、「教えて」などの動詞は含めないでください。
input: {query}
{format_instructions}
`;
const outputParser = StructuredOutputParser.fromNamesAndDescriptions({
output: "Suggested search keyword",
})
const reviewPromptTemplate = new PromptTemplate({
template: reviewTemplate,
inputVariables: ["query"],
partialVariables: {
format_instructions: outputParser.getFormatInstructions()
},
outputParser: outputParser,
});
const queryChain = new LLMChain({
llm: model,
prompt: reviewPromptTemplate,
});
const template = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}\n\nQuestion: {question}\nHelpful Answer:";
const QA_CHAIN_PROMPT = new PromptTemplate({
inputVariables: ["context", "question"],
template,
});
const qaChain = new RetrievalQAChain({
combineDocumentsChain: loadQAStuffChain(model, { prompt: QA_CHAIN_PROMPT }),
retriever,
returnSourceDocuments: false
})
const overallChain = new SimpleSequentialChain({
chains: [queryChain, qaChain],
});
// Call the chain with a query.
const res = await overallChain.run( "momento SDKの使い方を教えて", {
callbacks: []
});
return c.json(res)
})