
Commit 1cafc02

refactor: refact split_chunks
1 parent a2b9459 commit 1cafc02

4 files changed

Lines changed: 49 additions & 34 deletions

File tree

graphgen/graphgen.py
graphgen/operators/__init__.py
graphgen/operators/split/__init__.py
graphgen/operators/split/split_chunks.py

graphgen/graphgen.py

Lines changed: 3 additions & 32 deletions
```diff
@@ -5,7 +5,6 @@
 from typing import Dict, cast
 
 import gradio as gr
-from tqdm.asyncio import tqdm as tqdm_async
 
 from graphgen.bases.base_storage import StorageNameSpace
 from graphgen.bases.datatypes import Chunk
@@ -18,21 +17,20 @@
     TraverseStrategy,
 )
 from graphgen.operators import (
+    chunk_documents,
     extract_kg,
     generate_cot,
     judge_statement,
     quiz,
     read_files,
     search_all,
-    split_chunks,
     traverse_graph_for_aggregated,
     traverse_graph_for_atomic,
     traverse_graph_for_multi_hop,
 )
 from graphgen.utils import (
     async_to_sync_method,
     compute_content_hash,
-    detect_main_language,
     format_generation_results,
     logger,
 )
@@ -110,7 +108,6 @@ async def insert(self):
         """
         insert chunks into the graph
         """
-
         input_file = self.config["read"]["input_file"]
 
         # Step 1: Read files
@@ -138,33 +135,7 @@ async def insert(self):
             return
         logger.info("[New Docs] inserting %d docs", len(new_docs))
 
-        cur_index = 1
-        doc_number = len(new_docs)
-        async for doc_key, doc in tqdm_async(
-            new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
-        ):
-            doc_language = detect_main_language(doc["content"])
-            text_chunks = split_chunks(
-                doc["content"],
-                language=doc_language,
-                chunk_size=self.config["split"]["chunk_size"],
-                chunk_overlap=self.config["split"]["chunk_overlap"],
-            )
-
-            chunks = {
-                compute_content_hash(txt, prefix="chunk-"): {
-                    "content": txt,
-                    "full_doc_id": doc_key,
-                    "length": len(self.tokenizer_instance.encode_string(txt)),
-                    "language": doc_language,
-                }
-                for txt in text_chunks
-            }
-            inserting_chunks.update(chunks)
-
-            if self.progress_bar is not None:
-                self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
-                cur_index += 1
+        inserting_chunks = await chunk_documents(new_docs)
 
         _add_chunk_keys = await self.text_chunks_storage.filter_keys(
             list(inserting_chunks.keys())
@@ -246,7 +217,7 @@ async def search(self):
             ]
         )
         # TODO: fix insert after search
-        await self.async_insert()
+        await self.insert()
 
     @async_to_sync_method
     async def quiz(self):
```
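Note that the new one-line call site relies on `chunk_documents`' defaults (`chunk_size=1024`, `chunk_overlap=100`, no tokenizer, no progress bar) rather than the `self.config["split"]` values and instance helpers the removed inline loop used. A sketch of a call that would keep the pre-refactor inputs, using only names visible in the removed code:

```python
# Sketch, not part of this commit: threading the pre-refactor inputs through
# the new helper's keyword arguments.
inserting_chunks = await chunk_documents(
    new_docs,
    chunk_size=self.config["split"]["chunk_size"],
    chunk_overlap=self.config["split"]["chunk_overlap"],
    tokenizer_instance=self.tokenizer_instance,  # exact token counts for "length"
    progress_bar=self.progress_bar,              # UI progress callback
)
```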

graphgen/operators/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -5,7 +5,7 @@
 from .judge import judge_statement
 from .quiz import quiz
 from .read import read_files
-from .split import split_chunks
+from .split import chunk_documents
 from .traverse_graph import (
     traverse_graph_for_aggregated,
     traverse_graph_for_atomic,
```
graphgen/operators/split/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-from .split_chunks import split_chunks
+from .split_chunks import chunk_documents
```
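Both `__init__.py` edits are re-export renames, so downstream code keeps the same import path but switches to the new name:

```python
# After this commit the operator is exposed under its new name; split_chunks
# itself still exists inside graphgen/operators/split/split_chunks.py but is
# no longer re-exported from the package.
from graphgen.operators import chunk_documents
```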

graphgen/operators/split/split_chunks.py

Lines changed: 44 additions & 0 deletions
```diff
@@ -1,7 +1,10 @@
 from functools import lru_cache
 from typing import Union
 
+from tqdm.asyncio import tqdm as tqdm_async
+
 from graphgen.models import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
+from graphgen.utils import compute_content_hash, detect_main_language
 
 _MAPPING = {
     "en": RecursiveCharacterSplitter,
@@ -26,3 +29,44 @@ def split_chunks(text: str, language: str = "en", **kwargs) -> list:
     )
     splitter = _get_splitter(language, frozenset(kwargs.items()))
     return splitter.split_text(text)
+
+
+async def chunk_documents(
+    new_docs: dict,
+    chunk_size: int = 1024,
+    chunk_overlap: int = 100,
+    tokenizer_instance=None,
+    progress_bar=None,
+) -> dict:
+    inserting_chunks = {}
+    cur_index = 1
+    doc_number = len(new_docs)
+    async for doc_key, doc in tqdm_async(
+        new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
+    ):
+        doc_language = detect_main_language(doc["content"])
+        text_chunks = split_chunks(
+            doc["content"],
+            language=doc_language,
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+        )
+
+        chunks = {
+            compute_content_hash(txt, prefix="chunk-"): {
+                "content": txt,
+                "full_doc_id": doc_key,
+                "length": len(tokenizer_instance.encode_string(txt))
+                if tokenizer_instance
+                else len(txt),
+                "language": doc_language,
+            }
+            for txt in text_chunks
+        }
+        inserting_chunks.update(chunks)
+
+        if progress_bar is not None:
+            progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
+            cur_index += 1
+
+    return inserting_chunks
```
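A minimal usage sketch of the new coroutine, with a hypothetical input dict (each value only needs a `"content"` field). Without a `tokenizer_instance`, the chunk `"length"` falls back to the raw character count:

```python
import asyncio

from graphgen.operators import chunk_documents

# Hypothetical documents keyed by doc id; only "content" is required.
docs = {
    "doc-1": {"content": "GraphGen builds knowledge graphs from raw text corpora."},
}

# With no tokenizer or progress callback supplied, the defaults apply:
# lengths come from len(txt) and progress is reported by tqdm alone.
chunks = asyncio.run(chunk_documents(docs, chunk_size=512, chunk_overlap=64))

for chunk_id, chunk in chunks.items():
    # Keys are content hashes prefixed with "chunk-".
    print(chunk_id, chunk["language"], chunk["length"])
```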
