import os
import glob
from typing import TypedDict, Annotated, Literal
from dotenv import load_dotenv
from xparse_client import XParseClient
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
from langchain_milvus import Milvus
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.chat_models import ChatTongyi
from langchain_core.documents import Document
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
# Load settings (DASHSCOPE_API_KEY, MILVUS_DB_PATH, ...) from a local .env file;
# values already present in the process environment are not overwritten.
load_dotenv(override=False)

# ========== Step 1: build the knowledge base with the xParse SDK ==========
# Shared xParse client that build_knowledge_base() uses to parse source files.
client = XParseClient()
def build_knowledge_base(docs_dir: str = "./knowledge_base", *, drop_old: bool = True) -> None:
    """Parse every file in ``docs_dir`` with xParse and index the chunks in Milvus.

    Each top-level regular file of ``docs_dir`` is converted to Markdown by
    ``client.parse.run``. It is split on ``#``/``##``/``###`` headers and then
    into chunks of at most 1500 characters with 80 characters of overlap. The
    chunks are embedded with DashScope ``text-embedding-v4`` and written to the
    ``agentic_rag_docs`` collection at ``MILVUS_DB_PATH``.

    Args:
        docs_dir: Directory whose files are indexed; subdirectories are skipped.
        drop_old: Replace an existing ``agentic_rag_docs`` collection instead of
            appending to it. Defaults to True so that re-running the build does
            not store a second copy of every chunk.
    """
    print("开始构建知识库...")
    headers_to_split_on = [("#", "header1"), ("##", "header2"), ("###", "header3")]
    header_keys = [key for _, key in headers_to_split_on]
    markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=80)

    all_chunks: list[Document] = []
    # sorted() keeps the insertion order reproducible across runs.
    for file_path in sorted(glob.glob(os.path.join(docs_dir, "*"))):
        if not os.path.isfile(file_path):
            continue
        filename = os.path.basename(file_path)
        with open(file_path, "rb") as f:
            result = client.parse.run(file=f, filename=filename)
        # Guard against a parse result without Markdown content.
        md_docs = markdown_splitter.split_text(result.markdown or "")
        for doc in md_docs:
            # The header splitter only records the headers enclosing each
            # section, so chunks end up with different metadata keys. By default
            # the Milvus vector store builds its schema from the metadata keys,
            # and mismatched keys can make inserts fail or drop values. Give
            # every chunk the same keys.
            for key in header_keys:
                doc.metadata.setdefault(key, "")
            doc.metadata["filename"] = filename
        all_chunks.extend(text_splitter.split_documents(md_docs))

    if not all_chunks:
        # Nothing to embed: skip the embedding/Milvus calls, and with drop_old
        # do not wipe an existing collection.
        print(f"未找到可入库的文档({docs_dir}),跳过知识库构建。")
        return

    embedding = DashScopeEmbeddings(model="text-embedding-v4")
    Milvus.from_documents(
        documents=all_chunks,
        embedding=embedding,
        collection_name="agentic_rag_docs",
        connection_args={"uri": os.getenv("MILVUS_DB_PATH")},
        # Without this, every run appends a duplicate copy of all chunks.
        drop_old=drop_old,
    )
    print("知识库构建完成!")
# ========== Step 2: 初始化向量数据库和大模型 ==========
# Query-side embedding model. It is the same model ("text-embedding-v4") that
# build_knowledge_base() uses, so query and chunk vectors are comparable.
embedding = DashScopeEmbeddings(model="text-embedding-v4")

# Store over the collection that build_knowledge_base() writes to.
_milvus_connection = {"uri": os.getenv("MILVUS_DB_PATH")}
vector_store = Milvus(
    embedding_function=embedding,
    collection_name="agentic_rag_docs",
    connection_args=_milvus_connection,
)

# Qwen chat model shared by the decision, rewrite and answer nodes.
_llm_settings = dict(
    model="qwen-max",
    top_p=0.8,
    dashscope_api_key=os.getenv("DASHSCOPE_API_KEY"),
)
llm = ChatTongyi(**_llm_settings)
# ========== Step 3: 定义状态结构 ==========
class GraphState(TypedDict):
    """State shared by every node of the agentic RAG workflow.

    ``messages`` is merged with the ``add_messages`` reducer. Every other key
    simply holds the latest value a node returned for it.
    """
    messages: Annotated[list[BaseMessage], add_messages]  # conversation history, seeded with the user's HumanMessage
    question: str  # the user's original question
    rewritten_question: str  # LLM-rewritten query; "" until rewrite_question runs
    documents: list[dict]  # retrieved chunks as {"content": ..., "metadata": ...} dicts
    generation: str  # final answer produced by generate
    next: str  # routing key read by the conditional edges
    retrieval_count: int  # retrieve passes so far; grade_documents forces "generate" at 2
# ========== Step 4: 定义节点函数 ==========
def should_retrieve(state: GraphState) -> GraphState:
    """Ask the LLM whether the question needs knowledge-base retrieval.

    Returns a copy of ``state`` with ``next`` set to ``"retrieve"`` when the
    model's (lower-cased) reply contains "retrieve", otherwise ``"generate"``.
    """
    prompt_lines = [
        "判断以下问题是否需要从知识库中检索信息才能回答。",
        f"问题:{state['question']}",
        '如果问题需要特定的文档、数据或知识库信息才能回答,返回 "retrieve"。',
        '如果问题是一般性对话、问候或不需要特定信息的简单问题,返回 "generate"。',
        '只返回 "retrieve" 或 "generate",不要返回其他内容。',
    ]
    reply = llm.invoke([HumanMessage(content="\n".join(prompt_lines))])
    wants_docs = "retrieve" in reply.content.strip().lower()

    updated = dict(state)
    updated["next"] = "retrieve" if wants_docs else "generate"
    return updated
def retrieve(state: GraphState) -> GraphState:
    """Fetch the 5 most similar chunks for the current query from Milvus.

    Uses ``rewritten_question`` when it is non-empty, otherwise ``question``.
    Each hit is stored as a dict with ``content`` and ``metadata`` keys, and
    ``retrieval_count`` is incremented by one.
    """
    query = state.get("rewritten_question") or state["question"]
    passes_so_far = state.get("retrieval_count", 0)
    hits = vector_store.similarity_search(query, k=5)
    return {
        **state,
        "documents": [
            {"content": hit.page_content, "metadata": hit.metadata} for hit in hits
        ],
        "retrieval_count": passes_so_far + 1,
    }
def grade_documents(state: GraphState) -> GraphState:
    """Decide whether the retrieved chunks are good enough to answer from.

    Writes the routing key to ``next``:

    * ``"generate"`` once ``retrieval_count`` reaches 2, whatever the relevance,
      so the rewrite/retrieve loop always terminates;
    * ``"rewrite"`` when no documents were retrieved;
    * otherwise the LLM judges 300-character previews of the first three
      chunks. The key is ``"generate"`` if its reply contains "generate", else
      ``"rewrite"``.
    """
    query = state.get("rewritten_question") or state["question"]
    retrieved = state["documents"]

    if state.get("retrieval_count", 0) >= 2:
        verdict = "generate"
    elif not retrieved:
        verdict = "rewrite"
    else:
        previews = "\n\n".join(
            f"文档 {idx}:\n{item['content'][:300]}..."
            for idx, item in enumerate(retrieved[:3], start=1)
        )
        prompt = "\n".join([
            "评估以下检索到的文档是否与问题相关。",
            f"问题:{query}",
            "检索到的文档:",
            previews,
            '如果文档与问题高度相关,能够回答问题,返回 "generate"。',
            '如果文档与问题不相关或相关性很低,返回 "rewrite"。',
            '只返回 "generate" 或 "rewrite",不要返回其他内容。',
        ])
        reply = llm.invoke([HumanMessage(content=prompt)])
        verdict = "generate" if "generate" in reply.content.strip().lower() else "rewrite"

    return {**state, "next": verdict}
def rewrite_question(state: GraphState) -> GraphState:
    """Ask the LLM to reformulate the user's question for a better retrieval hit.

    When documents were retrieved, 200-character previews of the first two are
    included in the prompt, so the model sees what the knowledge base returned.
    Otherwise the question is rewritten without that context. The stripped
    reply is stored in ``rewritten_question``, which ``retrieve`` prefers over
    ``question`` when non-empty.
    """
    question = state["question"]
    documents = state.get("documents", [])

    # Rewrite instructions shared by both prompt variants.
    guidance = (
        "重写时应该:\n"
        "1. 保持问题的核心意图\n"
        "2. 使用更具体的关键词\n"
        "3. 考虑知识库可能使用的术语\n"
        "只返回重写后的问题,不要返回其他内容。"
    )

    if documents:
        docs_summary = "\n".join(
            f"- {doc['content'][:200]}..." for doc in documents[:2]
        )
        prompt = (
            f"原始问题:{question}\n"
            "当前检索到的文档摘要:\n"
            f"{docs_summary}\n"
            "这些文档与问题不够相关。请重写问题,使其能够更好地匹配知识库中的内容。\n"
            f"{guidance}"
        )
    else:
        prompt = (
            f"原始问题:{question}\n"
            "请重写这个问题,使其更具体、更清晰,便于在知识库中检索相关信息。\n"
            f"{guidance}"
        )

    response = llm.invoke([HumanMessage(content=prompt)])
    return {
        **state,
        "rewritten_question": response.content.strip(),
    }
def generate(state: GraphState) -> GraphState:
    """Produce the final answer and append it to the conversation history.

    With retrieved documents, the prompt contains every document's filename
    (``"未知"`` if missing) and content, and asks the model to answer from them
    and cite sources. Without documents, the question is sent alone. The answer
    is always built from the original ``question``, not the rewritten one.

    Returns a copy of ``state`` with ``generation`` set to the answer and an
    ``AIMessage`` holding that answer appended to ``messages``.
    """
    question = state["question"]
    documents = state.get("documents", [])

    if documents:
        context = "\n\n".join(
            f"文档来源:{doc['metadata'].get('filename', '未知')}\n内容:{doc['content']}"
            for doc in documents
        )
        prompt = "\n".join([
            "基于以下文档内容回答用户问题。",
            "文档内容:",
            context,
            f"用户问题:{question}",
            "请基于文档内容回答问题。如果文档中没有相关信息,请说明。",
            "在回答中引用具体的文档来源。",
        ])
    else:
        prompt = f"回答以下问题:{question}"

    response = llm.invoke([HumanMessage(content=prompt)])
    answer = response.content
    return {
        **state,
        "generation": answer,
        # Existing messages are passed back unchanged (add_messages matches them
        # by id), and the reply is appended so the history records the answer.
        "messages": [*state.get("messages", []), AIMessage(content=answer)],
    }
# ========== Step 5: 构建 LangGraph 工作流 ==========
def _route_on_next(state: GraphState) -> str:
    """Return the routing key a decision node stored in ``next`` (default "generate")."""
    return state.get("next", "generate")


workflow = StateGraph(GraphState)
for node_name, node_fn in (
    ("should_retrieve", should_retrieve),
    ("retrieve", retrieve),
    ("grade_documents", grade_documents),
    ("rewrite_question", rewrite_question),
    ("generate", generate),
):
    workflow.add_node(node_name, node_fn)

workflow.set_entry_point("should_retrieve")
# should_retrieve -> retrieve | generate
workflow.add_conditional_edges(
    "should_retrieve",
    _route_on_next,
    {"retrieve": "retrieve", "generate": "generate"},
)
# Every retrieval is graded before deciding what to do next.
workflow.add_edge("retrieve", "grade_documents")
# grade_documents -> generate | rewrite_question
workflow.add_conditional_edges(
    "grade_documents",
    _route_on_next,
    {"generate": "generate", "rewrite": "rewrite_question"},
)
# A rewritten question loops back into retrieval.
workflow.add_edge("rewrite_question", "retrieve")
workflow.add_edge("generate", END)
app = workflow.compile()
# ========== Step 6: 使用示例 ==========
def ask_question(question: str) -> str:
    """Run the compiled RAG graph on one question and return the generated answer."""
    blank_fields = dict(
        rewritten_question="",
        documents=[],
        generation="",
        next="",
        retrieval_count=0,
    )
    final_state = app.invoke({
        "messages": [HumanMessage(content=question)],
        "question": question,
        **blank_fields,
    })
    return final_state["generation"]
if __name__ == "__main__":
    # Index ./knowledge_base, then run a few demo questions through the graph.
    build_knowledge_base()
    demo_questions = (
        "如何配置数据库连接池?",
        "产品的定价策略是什么?",
        "你好",
    )
    separator = "=" * 60
    for q in demo_questions:
        print(separator)
        print(f"问题:{q}")
        print(separator)
        reply = ask_question(q)
        print(f"回答:{reply}\n")