1212from graphgen .models import OpenAIModel , Tokenizer
1313from graphgen .models .llm .limitter import RPM , TPM
1414from graphgen .utils import set_logger
15- from webui .base import GraphGenParams
15+ from webui .base import WebuiParams
1616from webui .cache_utils import cleanup_workspace , setup_workspace
1717from webui .count_tokens import count_tokens
1818from webui .i18n import Translate
@@ -66,13 +66,19 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
6666
6767
6868# pylint: disable=too-many-statements
69- def run_graphgen (params , progress = gr .Progress ()):
69+ def run_graphgen (params : WebuiParams , progress = gr .Progress ()):
7070 def sum_tokens (client ):
7171 return sum (u ["total_tokens" ] for u in client .token_usage )
7272
7373 config = {
7474 "if_trainee_model" : params .if_trainee_model ,
75- "input_file" : params .input_file ,
75+ "read" : {
76+ "input_file" : params .input_file ,
77+ },
78+ "split" : {
79+ "chunk_size" : params .chunk_size ,
80+ "chunk_overlap" : params .chunk_overlap ,
81+ },
7682 "output_data_type" : params .output_data_type ,
7783 "output_data_format" : params .output_data_format ,
7884 "tokenizer" : params .tokenizer ,
@@ -91,7 +97,6 @@ def sum_tokens(client):
9197 "isolated_node_strategy" : params .isolated_node_strategy ,
9298 "loss_strategy" : params .loss_strategy ,
9399 },
94- "chunk_size" : params .chunk_size ,
95100 }
96101
97102 env = {
@@ -284,10 +289,18 @@ def sum_tokens(client):
284289 label = "Chunk Size" ,
285290 minimum = 256 ,
286291 maximum = 4096 ,
287- value = 512 ,
292+ value = 1024 ,
288293 step = 256 ,
289294 interactive = True ,
290295 )
296+ chunk_overlap = gr .Slider (
297+ label = "Chunk Overlap" ,
298+ minimum = 0 ,
299+ maximum = 500 ,
300+ value = 100 ,
301+ step = 100 ,
302+ interactive = True ,
303+ )
291304 tokenizer = gr .Textbox (
292305 label = "Tokenizer" , value = "cl100k_base" , interactive = True
293306 )
@@ -499,7 +512,7 @@ def sum_tokens(client):
499512
500513 submit_btn .click (
501514 lambda * args : run_graphgen (
502- GraphGenParams (
515+ WebuiParams (
503516 if_trainee_model = args [0 ],
504517 input_file = args [1 ],
505518 tokenizer = args [2 ],
@@ -518,12 +531,13 @@ def sum_tokens(client):
518531 trainee_model = args [15 ],
519532 api_key = args [16 ],
520533 chunk_size = args [17 ],
521- rpm = args [18 ],
522- tpm = args [19 ],
523- quiz_samples = args [20 ],
524- trainee_url = args [21 ],
525- trainee_api_key = args [22 ],
526- token_counter = args [23 ],
534+ chunk_overlap = args [18 ],
535+ rpm = args [19 ],
536+ tpm = args [20 ],
537+ quiz_samples = args [21 ],
538+ trainee_url = args [22 ],
539+ trainee_api_key = args [23 ],
540+ token_counter = args [24 ],
527541 )
528542 ),
529543 inputs = [
@@ -545,6 +559,7 @@ def sum_tokens(client):
545559 trainee_model ,
546560 api_key ,
547561 chunk_size ,
562+ chunk_overlap ,
548563 rpm ,
549564 tpm ,
550565 quiz_samples ,
0 commit comments