diff --git a/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py b/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py index f1e0c6c18..e8eb663f6 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py @@ -78,6 +78,7 @@ class BasePromptConfig: text2gql_graph_schema: str = "" gremlin_generate_prompt: str = "" doc_input_text: str = "" + graph_extract_split_type: str = "document" _language_generated: str = "" generate_extract_prompt_template: str = "" @@ -136,6 +137,7 @@ def to_literal(val): "keywords_extract_prompt": to_literal(self.keywords_extract_prompt), "gremlin_generate_prompt": to_literal(self.gremlin_generate_prompt), "doc_input_text": to_literal(self.doc_input_text), + "graph_extract_split_type": to_literal(self.graph_extract_split_type), "_language_generated": str(self.llm_settings.language).lower().strip(), "generate_extract_prompt_template": to_literal(self.generate_extract_prompt_template), } diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py index 5dbd87c46..addcbc19e 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py @@ -44,12 +44,17 @@ ) -def store_prompt(doc, schema, example_prompt): - # update env variables: doc, schema and example_prompt - if prompt.doc_input_text != doc or prompt.graph_schema != schema or prompt.extract_graph_prompt != example_prompt: +def store_prompt(doc, schema, example_prompt, graph_extract_split_type="document"): + if ( + prompt.doc_input_text != doc + or prompt.graph_schema != schema + or prompt.extract_graph_prompt != example_prompt + or prompt.graph_extract_split_type != graph_extract_split_type + ): prompt.doc_input_text = doc prompt.graph_schema = schema prompt.extract_graph_prompt = example_prompt + prompt.graph_extract_split_type = graph_extract_split_type prompt.update_yaml_file() @@ -270,6 +275,12 @@ def create_vector_graph_block(): graph_data_btn0 = gr.Button("Clear Graph Data", size="sm") vector_import_bt = gr.Button("Import into Vector", variant="primary") + graph_split_type = gr.Dropdown( + choices=["document", "paragraph", "sentence"], + value=prompt.graph_extract_split_type, + label="Graph Extraction Split Type", + info=("document keeps the current behavior; paragraph/sentence split long docs before extraction."), + ) graph_extract_bt = gr.Button("Extract Graph Data (1)", variant="primary") graph_loading_bt = gr.Button("Load into GraphDB (2)", interactive=True) graph_index_rebuild_bt = gr.Button("Update Vid Embedding") @@ -300,48 +311,54 @@ def create_vector_graph_block(): vector_index_btn0.click(get_vector_index_info, outputs=out).then( store_prompt, - inputs=[input_text, input_schema, info_extract_template], + inputs=[input_text, input_schema, info_extract_template, graph_split_type], ) vector_index_btn1.click(clean_vector_index).then( store_prompt, - inputs=[input_text, input_schema, info_extract_template], + inputs=[input_text, input_schema, info_extract_template, graph_split_type], ) vector_import_bt.click(build_vector_index, inputs=[input_file, input_text], outputs=out).then( store_prompt, - inputs=[input_text, input_schema, info_extract_template], + inputs=[input_text, input_schema, info_extract_template, graph_split_type], ) graph_index_btn0.click(get_graph_index_info, outputs=out).then( store_prompt, - inputs=[input_text, input_schema, info_extract_template], + inputs=[input_text, input_schema, info_extract_template, graph_split_type], ) graph_index_btn1.click(clean_all_graph_index).then( store_prompt, - inputs=[input_text, input_schema, info_extract_template], + inputs=[input_text, input_schema, info_extract_template, graph_split_type], ) graph_data_btn0.click(clean_all_graph_data).then( store_prompt, - inputs=[input_text, input_schema, info_extract_template], + inputs=[input_text, input_schema, info_extract_template, graph_split_type], ) graph_index_rebuild_bt.click(update_vid_embedding, outputs=out).then( store_prompt, - inputs=[input_text, input_schema, info_extract_template], + inputs=[input_text, input_schema, info_extract_template, graph_split_type], ) # origin_out = gr.Textbox(visible=False) graph_extract_bt.click( extract_graph, - inputs=[input_file, input_text, input_schema, info_extract_template], + inputs=[ + input_file, + input_text, + input_schema, + info_extract_template, + graph_split_type, + ], outputs=[out], ).then( store_prompt, - inputs=[input_text, input_schema, info_extract_template], + inputs=[input_text, input_schema, info_extract_template, graph_split_type], ) graph_loading_bt.click(import_graph_data, inputs=[out, input_schema], outputs=[out]).then( update_vid_embedding ).then( store_prompt, - inputs=[input_text, input_schema, info_extract_template], + inputs=[input_text, input_schema, info_extract_template, graph_split_type], ) # TODO: we should store the examples after the user changed them. @@ -355,6 +372,7 @@ def create_vector_graph_block(): input_text, input_schema, info_extract_template, + graph_split_type, ], # TODO: Store the updated examples ) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py index 0057f2b71..13629e2be 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py @@ -21,6 +21,10 @@ from hugegraph_llm.nodes.document_node.chunk_split import ChunkSplitNode from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode from hugegraph_llm.nodes.llm_node.extract_info import ExtractNode +from hugegraph_llm.operators.document_op.chunk_split import ( + SPLIT_TYPE_DOCUMENT, + VALID_SPLIT_TYPES, +) from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from hugegraph_llm.utils.log import log @@ -37,22 +41,43 @@ def prepare( texts, example_prompt, extract_type, + split_type=SPLIT_TYPE_DOCUMENT, language="zh", **kwargs, ): # prepare input data prepared_input.texts = texts prepared_input.language = language - prepared_input.split_type = "document" + if split_type not in VALID_SPLIT_TYPES: + raise ValueError("split_type must be document, paragraph, or sentence") + + prepared_input.split_type = split_type prepared_input.example_prompt = example_prompt prepared_input.schema = schema prepared_input.extract_type = extract_type - def build_flow(self, schema, texts, example_prompt, extract_type, language="zh", **kwargs): + def build_flow( + self, + schema, + texts, + example_prompt, + extract_type, + split_type=SPLIT_TYPE_DOCUMENT, + language="zh", + **kwargs, + ): pipeline = GPipeline() prepared_input = WkFlowInput() # prepare input data - self.prepare(prepared_input, schema, texts, example_prompt, extract_type, language) + self.prepare( + prepared_input, + schema, + texts, + example_prompt, + extract_type, + split_type, + language, + ) pipeline.createGParam(prepared_input, "wkflow_input") pipeline.createGParam(WkFlowState(), "wkflow_state") @@ -70,6 +95,8 @@ def post_deal(self, pipeline=None, **kwargs): res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() vertices = res.get("vertices", []) edges = res.get("edges", []) + chunk_count = len(res.get("chunks", [])) + log.info("Graph extraction chunk_count: %s", chunk_count) if not vertices and not edges: log.info("Please check the schema.(The schema may not match the Doc)") return json.dumps( diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py b/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py index a22e4de88..83a1d4bbe 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py @@ -16,6 +16,7 @@ # under the License. +import re from typing import Any, Dict, List, Literal, Optional, Union from langchain_text_splitters import RecursiveCharacterTextSplitter @@ -26,6 +27,16 @@ SPLIT_TYPE_DOCUMENT = "document" SPLIT_TYPE_PARAGRAPH = "paragraph" SPLIT_TYPE_SENTENCE = "sentence" +VALID_SPLIT_TYPES = ( + SPLIT_TYPE_DOCUMENT, + SPLIT_TYPE_PARAGRAPH, + SPLIT_TYPE_SENTENCE, +) + + +def _split_sentence_boundaries(text: str) -> list[str]: + sentence_pattern = re.compile(r"[^.!?\u3002\uff01\uff1f\uff1b;]+[.!?\u3002\uff01\uff1f\uff1b;]*") + return [sentence.strip() for sentence in sentence_pattern.findall(text) if sentence.strip()] class ChunkSplit: @@ -56,8 +67,8 @@ def _get_text_splitter(self, split_type: str): chunk_size=500, chunk_overlap=30, separators=self.separators ).split_text if split_type == SPLIT_TYPE_SENTENCE: - return RecursiveCharacterTextSplitter(chunk_size=50, chunk_overlap=0, separators=self.separators).split_text - raise ValueError("Type must be paragraph, sentence, html or markdown") + return _split_sentence_boundaries + raise ValueError("split_type must be document, paragraph, or sentence") def run(self, context: Optional[Dict[str, Any]]) -> Dict[str, Any]: all_chunks = [] diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py index 423526ea4..78e9030d4 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py @@ -24,6 +24,10 @@ from hugegraph_llm.flows import FlowName from hugegraph_llm.flows.scheduler import SchedulerSingleton +from hugegraph_llm.operators.document_op.chunk_split import ( + SPLIT_TYPE_DOCUMENT, + VALID_SPLIT_TYPES, +) from ..config import huge_settings from .hugegraph_utils import clean_hg_data @@ -77,14 +81,28 @@ def clean_all_graph_data(): gr.Info("Clear graph data successfully!") -def extract_graph(input_file, input_text, schema, example_prompt) -> str: +def extract_graph( + input_file, + input_text, + schema, + example_prompt, + split_type=SPLIT_TYPE_DOCUMENT, +) -> str: texts = read_documents(input_file, input_text) scheduler = SchedulerSingleton.get_instance() if not schema: return "ERROR: please input with correct schema/format." - + if split_type not in VALID_SPLIT_TYPES: + raise gr.Error("split_type must be document, paragraph, or sentence") try: - return scheduler.schedule_flow(FlowName.GRAPH_EXTRACT, schema, texts, example_prompt, "property_graph") + return scheduler.schedule_flow( + FlowName.GRAPH_EXTRACT, + schema, + texts, + example_prompt, + "property_graph", + split_type=split_type, + ) except Exception as e: # pylint: disable=broad-exception-caught log.error(e) raise gr.Error(str(e)) diff --git a/hugegraph-llm/src/tests/document/test_graph_extract_configurable_split.py b/hugegraph-llm/src/tests/document/test_graph_extract_configurable_split.py new file mode 100644 index 000000000..4e5078bb9 --- /dev/null +++ b/hugegraph-llm/src/tests/document/test_graph_extract_configurable_split.py @@ -0,0 +1,238 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from types import SimpleNamespace + +import gradio as gr +import pytest + +from hugegraph_llm.config.models import base_prompt_config +from hugegraph_llm.config.models.base_prompt_config import BasePromptConfig +from hugegraph_llm.flows import FlowName +from hugegraph_llm.flows.graph_extract import GraphExtractFlow +from hugegraph_llm.operators.document_op.chunk_split import ChunkSplit +from hugegraph_llm.state.ai_state import WkFlowInput +from hugegraph_llm.utils import graph_index_utils + + +class DummyScheduler: + def __init__(self): + self.calls = [] + self.kwargs = [] + + def schedule_flow(self, *args, **kwargs): + self.calls.append(args) + self.kwargs.append(kwargs) + return "scheduled" + + +class DummyPipelineState: + def to_json(self): + return { + "chunks": ["chunk one", "chunk two"], + "vertices": [{"id": "person:alice"}], + "edges": [], + } + + +class DummyPipeline: + def getGParamWithNoEmpty(self, name): + assert name == "wkflow_state" + return DummyPipelineState() + + +class CapturePipeline: + def __init__(self): + self.params = {} + + def createGParam(self, value, name): + self.params[name] = value + + def registerGElement(self, *args): + return None + + +def test_graph_extract_prepare_preserves_default_document_split_type(): + prepared_input = WkFlowInput() + + GraphExtractFlow().prepare( + prepared_input, + "{}", + ["first document"], + "extract prompt", + "property_graph", + ) + + assert prepared_input.split_type == "document" + + +def test_graph_extract_prepare_accepts_non_default_split_type(): + prepared_input = WkFlowInput() + + GraphExtractFlow().prepare( + prepared_input, + "{}", + ["first paragraph\n\nsecond paragraph"], + "extract prompt", + "property_graph", + "paragraph", + ) + + assert prepared_input.split_type == "paragraph" + + +def test_graph_extract_prepare_rejects_invalid_split_type(): + prepared_input = WkFlowInput() + + with pytest.raises(ValueError, match="split_type must be document"): + GraphExtractFlow().prepare( + prepared_input, + "{}", + ["first document"], + "extract prompt", + "property_graph", + "invalid", + ) + + +def test_graph_extract_build_flow_passes_non_default_split_type_to_workflow_input( + monkeypatch, +): + monkeypatch.setattr( + "hugegraph_llm.flows.graph_extract.GPipeline", + CapturePipeline, + ) + + pipeline = GraphExtractFlow().build_flow( + "{}", + ["first paragraph\n\nsecond paragraph"], + "extract prompt", + "property_graph", + "paragraph", + ) + + assert pipeline.params["wkflow_input"].split_type == "paragraph" + + +def test_chunk_split_non_default_types_produce_multiple_chunks(): + paragraph_text = ("Alpha " * 120) + "\n\n" + ("Beta " * 120) + sentence_text = "Alpha sentence. Beta sentence. Gamma sentence. Delta sentence. Epsilon sentence. Zeta sentence." + + paragraph_chunks = ChunkSplit(paragraph_text, "paragraph", "en").run(None)["chunks"] + sentence_chunks = ChunkSplit(sentence_text, "sentence", "en").run(None)["chunks"] + + assert len(paragraph_chunks) > 1 + assert len(sentence_chunks) > 1 + + +def test_extract_graph_helper_forwards_selected_split_type(monkeypatch): + scheduler = DummyScheduler() + monkeypatch.setattr( + graph_index_utils, + "read_documents", + lambda input_file, input_text: ["graph extraction text"], + ) + monkeypatch.setattr( + graph_index_utils.SchedulerSingleton, + "get_instance", + lambda: scheduler, + ) + + result = graph_index_utils.extract_graph( + [], + "", + "{}", + "extract prompt", + "sentence", + ) + + assert result == "scheduled" + assert scheduler.calls == [ + ( + FlowName.GRAPH_EXTRACT, + "{}", + ["graph extraction text"], + "extract prompt", + "property_graph", + ) + ] + assert scheduler.kwargs == [{"split_type": "sentence"}] + + +def test_extract_graph_helper_rejects_invalid_split_type(monkeypatch): + monkeypatch.setattr( + graph_index_utils, + "read_documents", + lambda input_file, input_text: ["graph extraction text"], + ) + monkeypatch.setattr( + graph_index_utils.SchedulerSingleton, + "get_instance", + lambda: DummyScheduler(), + ) + + with pytest.raises(gr.Error, match="split_type must be document"): + graph_index_utils.extract_graph( + [], + "", + "{}", + "extract prompt", + "invalid", + ) + + +def test_graph_extract_post_deal_logs_chunk_count(monkeypatch): + log_calls = [] + monkeypatch.setattr( + "hugegraph_llm.flows.graph_extract.log.info", + lambda message, *args: log_calls.append((message, args)), + ) + + result = GraphExtractFlow().post_deal(DummyPipeline()) + result_data = json.loads(result) + + assert result_data["vertices"] == [{"id": "person:alice"}] + assert any(message == "Graph extraction chunk_count: %s" and args == (2,) for message, args in log_calls) + + +def test_sentence_split_returns_punctuation_delimited_sentences(): + chunks = ChunkSplit( + "Alpha sentence one. Beta sentence two? Gamma sentence three!", + "sentence", + "en", + ).run(None)["chunks"] + + assert chunks == [ + "Alpha sentence one.", + "Beta sentence two?", + "Gamma sentence three!", + ] + + +def test_prompt_config_round_trips_graph_extract_split_type(monkeypatch, tmp_path): + prompt_path = tmp_path / "config_prompt.yaml" + monkeypatch.setattr(base_prompt_config, "yaml_file_path", str(prompt_path)) + + config = BasePromptConfig() + config.llm_settings = SimpleNamespace(language="en") + config.graph_extract_split_type = "sentence" + config.save_to_yaml() + + reloaded = BasePromptConfig() + reloaded.llm_settings = SimpleNamespace(language="en") + reloaded.ensure_yaml_file_exists() + + assert reloaded.graph_extract_split_type == "sentence"