Skip to content

Commit 6a6cb34

Browse files
feat(webui): update webui with splitter config
1 parent 797781d commit 6a6cb34

10 files changed

Lines changed: 97 additions & 55 deletions

File tree

graphgen/configs/__init__.py

Whitespace-only changes.

graphgen/configs/aggregated_config.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
1+
read:
2+
input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
3+
split:
4+
chunk_size: 1024 # chunk size for text splitting
5+
chunk_overlap: 100 # chunk overlap for text splitting
26
output_data_type: aggregated # atomic, aggregated, multi_hop, cot
37
output_data_format: ChatML # Alpaca, Sharegpt, ChatML
48
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path

graphgen/configs/atomic_config.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv. See resources/input_examples for examples
1+
read:
2+
input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv. See resources/input_examples for examples
3+
split:
4+
chunk_size: 1024 # chunk size for text splitting
5+
chunk_overlap: 100 # chunk overlap for text splitting
26
output_data_type: atomic # atomic, aggregated, multi_hop, cot
37
output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
48
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path

graphgen/configs/cot_config.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt. See resources/input_examples for examples
1+
read:
2+
input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt. See resources/input_examples for examples
3+
split:
4+
chunk_size: 1024 # chunk size for text splitting
5+
chunk_overlap: 100 # chunk overlap for text splitting
26
output_data_type: cot # atomic, aggregated, multi_hop, cot
37
output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
48
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path

graphgen/configs/multi_hop_config.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt. See resources/input_examples for examples
1+
read:
2+
input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt. See resources/input_examples for examples
3+
split:
4+
chunk_size: 1024 # chunk size for text splitting
5+
chunk_overlap: 100 # chunk overlap for text splitting
26
output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
37
output_data_format: ChatML # Alpaca, Sharegpt, ChatML
48
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path

graphgen/graphgen.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Tokenizer,
1818
TraverseStrategy,
1919
read_file,
20+
split_chunks,
2021
)
2122

2223
from .operators import (
@@ -32,6 +33,7 @@
3233
from .utils import (
3334
compute_content_hash,
3435
create_event_loop,
36+
detect_main_language,
3537
format_generation_results,
3638
logger,
3739
)
@@ -50,11 +52,6 @@ class GraphGen:
5052
synthesizer_llm_client: OpenAIModel = None
5153
trainee_llm_client: OpenAIModel = None
5254

53-
# text chunking
54-
# TODO: make it configurable
55-
chunk_size: int = 1024
56-
chunk_overlap_size: int = 100
57-
5855
# search
5956
search_config: dict = field(
6057
default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
@@ -136,14 +133,22 @@ async def async_split_chunks(self, data: List[Union[List, Dict]]) -> dict:
136133
async for doc_key, doc in tqdm_async(
137134
new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
138135
):
136+
doc_language = detect_main_language(doc["content"])
137+
text_chunks = split_chunks(
138+
doc["content"],
139+
language=doc_language,
140+
chunk_size=self.config["split"]["chunk_size"],
141+
chunk_overlap=self.config["split"]["chunk_overlap"],
142+
)
143+
139144
chunks = {
140-
compute_content_hash(dp["content"], prefix="chunk-"): {
141-
**dp,
145+
compute_content_hash(txt, prefix="chunk-"): {
146+
"content": txt,
142147
"full_doc_id": doc_key,
148+
"length": len(self.tokenizer_instance.encode_string(txt)),
149+
"language": "en",
143150
}
144-
for dp in self.tokenizer_instance.chunk_by_token_size(
145-
doc["content"], self.chunk_overlap_size, self.chunk_size
146-
)
151+
for txt in text_chunks
147152
}
148153
inserting_chunks.update(chunks)
149154

@@ -171,7 +176,7 @@ async def async_insert(self):
171176
insert chunks into the graph
172177
"""
173178

174-
input_file = self.config["input_file"]
179+
input_file = self.config["read"]["input_file"]
175180
data = read_file(input_file)
176181
inserting_chunks = await self.async_split_chunks(data)
177182

graphgen/models/__init__.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,33 +11,7 @@
1111
from .search.kg.wiki_search import WikiSearch
1212
from .search.web.bing_search import BingSearch
1313
from .search.web.google_search import GoogleSearch
14+
from .splitter import split_chunks
1415
from .storage.json_storage import JsonKVStorage, JsonListStorage
1516
from .storage.networkx_storage import NetworkXStorage
1617
from .strategy.travserse_strategy import TraverseStrategy
17-
18-
__all__ = [
19-
# llm models
20-
"OpenAIModel",
21-
"TopkTokenModel",
22-
"Token",
23-
"Tokenizer",
24-
# storage models
25-
"NetworkXStorage",
26-
"JsonKVStorage",
27-
"JsonListStorage",
28-
# search models
29-
"WikiSearch",
30-
"GoogleSearch",
31-
"BingSearch",
32-
"UniProtSearch",
33-
# evaluate models
34-
"LengthEvaluator",
35-
"MTLDEvaluator",
36-
"RewardEvaluator",
37-
"UniEvaluator",
38-
# strategy models
39-
"TraverseStrategy",
40-
# community models
41-
"CommunityDetector",
42-
"read_file",
43-
]
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
from functools import lru_cache
from typing import Union

from .recursive_character_splitter import (
    ChineseRecursiveTextSplitter,
    RecursiveCharacterSplitter,
)

# Registry mapping a language code to the splitter class used for that language.
_MAPPING = {
    "en": RecursiveCharacterSplitter,
    "zh": ChineseRecursiveTextSplitter,
}

SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter]


@lru_cache(maxsize=None)
def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT:
    """Return a splitter for *language*, reusing a cached instance per kwargs.

    *frozen_kwargs* is a hashable snapshot of the constructor kwargs so that
    identical (language, kwargs) pairs share one splitter instance.
    """
    return _MAPPING[language](**dict(frozen_kwargs))


def split_chunks(text: str, language: str = "en", **kwargs) -> list:
    """Split *text* into chunks with the splitter registered for *language*.

    Extra keyword arguments (e.g. chunk_size, chunk_overlap) are forwarded to
    the splitter's constructor. Raises ValueError for an unknown language.
    """
    if language not in _MAPPING:
        raise ValueError(
            f"Unsupported language: {language}. "
            f"Supported languages are: {list(_MAPPING.keys())}"
        )
    return _get_splitter(language, frozenset(kwargs.items())).split_text(text)

webui/app.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from graphgen.models import OpenAIModel, Tokenizer
1313
from graphgen.models.llm.limitter import RPM, TPM
1414
from graphgen.utils import set_logger
15-
from webui.base import GraphGenParams
15+
from webui.base import WebuiParams
1616
from webui.cache_utils import cleanup_workspace, setup_workspace
1717
from webui.count_tokens import count_tokens
1818
from webui.i18n import Translate
@@ -66,13 +66,19 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
6666

6767

6868
# pylint: disable=too-many-statements
69-
def run_graphgen(params, progress=gr.Progress()):
69+
def run_graphgen(params: WebuiParams, progress=gr.Progress()):
7070
def sum_tokens(client):
7171
return sum(u["total_tokens"] for u in client.token_usage)
7272

7373
config = {
7474
"if_trainee_model": params.if_trainee_model,
75-
"input_file": params.input_file,
75+
"read": {
76+
"input_file": params.input_file,
77+
},
78+
"split": {
79+
"chunk_size": params.chunk_size,
80+
"chunk_overlap": params.chunk_overlap,
81+
},
7682
"output_data_type": params.output_data_type,
7783
"output_data_format": params.output_data_format,
7884
"tokenizer": params.tokenizer,
@@ -91,7 +97,6 @@ def sum_tokens(client):
9197
"isolated_node_strategy": params.isolated_node_strategy,
9298
"loss_strategy": params.loss_strategy,
9399
},
94-
"chunk_size": params.chunk_size,
95100
}
96101

97102
env = {
@@ -284,10 +289,18 @@ def sum_tokens(client):
284289
label="Chunk Size",
285290
minimum=256,
286291
maximum=4096,
287-
value=512,
292+
value=1024,
288293
step=256,
289294
interactive=True,
290295
)
296+
chunk_overlap = gr.Slider(
297+
label="Chunk Overlap",
298+
minimum=0,
299+
maximum=500,
300+
value=100,
301+
step=100,
302+
interactive=True,
303+
)
291304
tokenizer = gr.Textbox(
292305
label="Tokenizer", value="cl100k_base", interactive=True
293306
)
@@ -499,7 +512,7 @@ def sum_tokens(client):
499512

500513
submit_btn.click(
501514
lambda *args: run_graphgen(
502-
GraphGenParams(
515+
WebuiParams(
503516
if_trainee_model=args[0],
504517
input_file=args[1],
505518
tokenizer=args[2],
@@ -518,12 +531,13 @@ def sum_tokens(client):
518531
trainee_model=args[15],
519532
api_key=args[16],
520533
chunk_size=args[17],
521-
rpm=args[18],
522-
tpm=args[19],
523-
quiz_samples=args[20],
524-
trainee_url=args[21],
525-
trainee_api_key=args[22],
526-
token_counter=args[23],
534+
chunk_overlap=args[18],
535+
rpm=args[19],
536+
tpm=args[20],
537+
quiz_samples=args[21],
538+
trainee_url=args[22],
539+
trainee_api_key=args[23],
540+
token_counter=args[24],
527541
)
528542
),
529543
inputs=[
@@ -545,6 +559,7 @@ def sum_tokens(client):
545559
trainee_model,
546560
api_key,
547561
chunk_size,
562+
chunk_overlap,
548563
rpm,
549564
tpm,
550565
quiz_samples,

webui/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
@dataclass
6-
class GraphGenParams:
6+
class WebuiParams:
77
"""
88
GraphGen parameters
99
"""
@@ -26,6 +26,7 @@ class GraphGenParams:
2626
trainee_model: str
2727
api_key: str
2828
chunk_size: int
29+
chunk_overlap: int
2930
rpm: int
3031
tpm: int
3132
quiz_samples: int

0 commit comments

Comments
 (0)