# Router demo: an LLM chooses between a vector index and a summary index per query.
import json
from dataclasses import dataclass
from pathlib import Path

from llama_index.core import PromptTemplate, Settings, SummaryIndex, VectorStoreIndex
from llama_index.core.base.response.schema import Response
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core.types import BaseOutputParser
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.readers.file import PyMuPDFReader
from pydantic import Field
# Configure the process-wide embedding model: nomic-embed-text served by a
# local Ollama instance.  The long timeout accommodates slow local inference.
Settings.embed_model = OllamaEmbedding(
    model_name="nomic-embed-text:latest",
    base_url="http://localhost:11434",
    request_timeout=360.0
)
# Configure the process-wide LLM: qwen2.5 7B on the same Ollama server.
# temperature=0.0 keeps routing decisions deterministic; num_ctx caps the
# context window at 2048 tokens.
Settings.llm = Ollama(
    model="qwen2.5:7b",
    base_url="http://localhost:11434",
    request_timeout=360.0,
    additional_kwargs={"temperature": 0.0, "num_ctx": 2048},
)
# Locate the source document and fail fast if it is missing.
# NOTE(review): PyMuPDFReader is pointed at a .txt file ("flag.txt") —
# confirm the reader handles plain text, or swap in a text reader.
loader = PyMuPDFReader()
pdf_path = Path("D:\\LlamaIndex\\LLmaRouter\\data\\flag.txt")
if not pdf_path.exists():
    raise FileNotFoundError(f"文件不存在:{pdf_path.absolute()}")
# Parse the file into Documents and prepare a sentence-aware splitter
# (1024-token chunks) shared by both indexes below.
documents = loader.load(file_path=str(pdf_path))
splitter = SentenceSplitter(chunk_size=1024)
# Build the two candidate indexes the router chooses between:
#  - VectorStoreIndex: embedding-based retrieval for factual questions
#  - SummaryIndex: whole-document traversal for summary questions
vector_index = VectorStoreIndex.from_documents(
    documents,
    transformations=[splitter],
    embed_model=Settings.embed_model
)
summary_index = SummaryIndex.from_documents(
    documents,
    transformations=[splitter]
)
@dataclass(frozen=True)
class Answer:
    """Parsed routing decision returned by the LLM.

    Attributes:
        choice: 1-based index into ``route_choices``.
        reason: The LLM's justification for picking that index.
    """

    # frozen=True: a routing decision is a value object — it is only ever
    # constructed by the parser and read by the router, so freezing guards
    # against accidental mutation without changing the constructor interface.
    choice: int
    reason: str
# Human-readable routing options shown to the LLM; format_choices() applies
# 1-based numbering.  Choice 1 → vector_index, choice 2 → summary_index.
route_choices = [
    'Useful for specific/factual questions (e.g., "What is Llama2\'s context window?") → use VectorStoreIndex',
    'Useful for summary/general questions (e.g., "Summarize Llama2\'s key features") → use SummaryIndex',
]
def format_choices(choices):
    """Number the routing options (1-based) and join them with blank lines.

    Prints the formatted option list for debugging and returns it so it can
    be embedded in the router prompt.
    """
    numbered = [f"{number}. {text}" for number, text in enumerate(choices, start=1)]
    choices_str = "\n\n".join(numbered)
    print(f"路由选项:{choices_str}")
    return choices_str
JSON_FORMAT_STR = """ The output should be formatted as a JSON instance that conforms to the JSON schema below. Do NOT add any extra text, explanation, or comments—only output the JSON string.
JSON Schema: { "type": "object", "properties": { "choice": {"type": "integer"}, "reason": {"type": "string"} }, "required": ["choice", "reason"], "additionalProperties": false } """
class RouterOutputParser(BaseOutputParser):
    """Parses the router LLM's JSON reply into an :class:`Answer`.

    Accepts a bare JSON object and also tolerates the common failure mode
    where the model wraps the JSON in a Markdown code fence (``` / ```json),
    despite the prompt forbidding extra text.
    """

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        # Remove a surrounding ``` / ```json fence so json.loads sees the
        # raw object; plain JSON passes through untouched.
        if text.startswith("```"):
            newline_at = text.find("\n")
            text = text[newline_at + 1:] if newline_at != -1 else text.lstrip("`")
            text = text.rstrip()
            if text.endswith("```"):
                text = text[:-3]
        return text.strip()

    def parse(self, output: str) -> Answer:
        """Validate and convert the raw LLM output.

        Raises:
            RuntimeError: if the output is not valid JSON, is missing a
                required field, has a wrong field type, or the choice is
                outside 1..len(route_choices).
        """
        try:
            output_dict = json.loads(self._strip_code_fence(output.strip()))
            print(f"LLM路由原始输出:{output_dict}")
            if not isinstance(output_dict.get("choice"), int):
                raise ValueError(f"choice必须是整数,实际是{type(output_dict.get('choice'))}")
            if not isinstance(output_dict.get("reason"), str):
                raise ValueError(f"reason必须是字符串,实际是{type(output_dict.get('reason'))}")
            choice_num = output_dict["choice"]
            if choice_num < 1 or choice_num > len(route_choices):
                raise ValueError(f"选择的编号{choice_num}超出选项范围(1-{len(route_choices)})")
            return Answer(choice=choice_num, reason=output_dict["reason"])
        except (json.JSONDecodeError, KeyError, ValueError) as e:
            # Chain the cause (`from e`) so the original JSON/validation
            # traceback is preserved for debugging.
            raise RuntimeError(f"解析LLM输出失败:{str(e)},原始响应:{output}") from e
class CustomRouterQueryEngine(CustomQueryEngine):
    """Query engine that asks the LLM which index suits the question,
    then executes the query against the chosen index.

    Choice 1 routes to the vector index (factual lookups); choice 2 routes
    to the summary index with tree summarization.  Any failure is converted
    into an error Response rather than an exception.
    """

    vector_index: VectorStoreIndex = Field(description="向量索引,用于事实性查询")
    summary_index: SummaryIndex = Field(description="摘要索引,用于总结性查询")
    llm: Ollama = Field(description="Ollama LLM实例")
    parser: RouterOutputParser = Field(description="路由输出解析器")
    route_choices: list = Field(description="路由选项列表")

    def custom_query(self, query_str: str) -> Response:
        """Route `query_str` to the best index and return its answer.

        Returns a Response whose metadata records the routing decision
        (`route_choice`, `route_reason`, `used_index`, `error`), or an
        error Response with `error: True` if anything fails.
        """
        try:
            # Build the routing prompt: question + numbered options + the
            # strict JSON output contract.  {{query_str}} stays a template
            # variable; it is filled by router_prompt.format below.
            formatted_choices = format_choices(self.route_choices)
            router_prompt = PromptTemplate(
                f"""
你的任务是根据用户问题类型,选择最匹配的索引,并严格按照指定JSON格式返回结果。

【用户问题】:{{query_str}}

【可选索引】:
{formatted_choices}

{JSON_FORMAT_STR}
"""
            )
            filled_prompt = router_prompt.format(query_str=query_str)
            # Ask the LLM for a routing decision and parse its JSON reply.
            llm_response = self.llm.complete(filled_prompt)
            route_answer = self.parser.parse(llm_response.text.strip())
            print(f"\n=== 路由决策 ===")
            print(f"选中索引编号:{route_answer.choice}")
            print(f"路由理由:{route_answer.reason}")
            # Dispatch to the index matching the parsed 1-based choice.
            if route_answer.choice == 1:
                query_engine = self.vector_index.as_query_engine(llm=Settings.llm)
                used_index = "VectorStoreIndex(事实查询)"
            elif route_answer.choice == 2:
                # Summaries use TreeSummarize for hierarchical condensation.
                summarizer = TreeSummarize(llm=Settings.llm)
                query_engine = self.summary_index.as_query_engine(
                    response_synthesizer=summarizer,
                    llm=Settings.llm
                )
                used_index = "SummaryIndex(总结查询)"
            else:
                # Defensive: parser already range-checks, but guard anyway.
                raise ValueError(f"无效的索引编号:{route_answer.choice}")
            final_response = query_engine.query(query_str)
            print(f"\n=== 查询执行 ===")
            print(f"使用索引:{used_index}")
            return Response(
                response=str(final_response),
                metadata={
                    "route_choice": route_answer.choice,
                    "route_reason": route_answer.reason,
                    "used_index": used_index,
                    "error": False
                }
            )
        except Exception as e:
            # Broad catch is deliberate: routing/query failures degrade to
            # an error Response instead of crashing the caller.
            error_msg = f"路由查询失败:{str(e)}"
            print(error_msg)
            return Response(
                response=error_msg,
                metadata={"error": True, "error_detail": str(e)}
            )
if __name__ == "__main__":
    # Demo driver: wire the router together and push each test query
    # through it, printing the answer and the routing metadata.
    router_parser = RouterOutputParser()
    router_query_engine = CustomRouterQueryEngine(
        vector_index=vector_index,
        summary_index=summary_index,
        llm=Settings.llm,
        parser=router_parser,
        route_choices=route_choices
    )
    test_queries = [
        "28.7Blog其网址域名是多少?他的大学是什么专业?目前他学习那个方向?",
    ]
    for idx, test_query in enumerate(test_queries, 1):
        print(f"\n===================== 测试查询 {idx} =====================")
        print(f"查询问题:{test_query}")
        response = router_query_engine.query(test_query)
        print(f"\n=== 最终回答 ===")
        print(response.response)
        print(f"\n=== 路由元信息 ===")
        print(f"选中索引:{response.metadata.get('used_index')}")
        print(f"路由理由:{response.metadata.get('route_reason')}")
        print(f"是否出错:{response.metadata.get('error')}")