@@ -7,6 +7,7 @@ use std::process;
77use clap:: { Parser , Subcommand } ;
88use serde:: { Deserialize , Serialize } ;
99
10+ use lex_core:: converter:: tune;
1011use lex_core:: converter:: { convert_nbest, convert_nbest_with_history} ;
1112use lex_core:: dict:: connection:: ConnectionMatrix ;
1213use lex_core:: dict:: TrieDictionary ;
@@ -90,6 +91,28 @@ enum Command {
9091 history : Option < String > ,
9192 } ,
9293
94+ /// Grid-search FeatureWeights to optimise conversion accuracy
95+ Tune {
96+ /// Path to the compiled dictionary file
97+ dict_file : String ,
98+ /// Path to the compiled connection matrix file
99+ conn_file : String ,
100+ /// Path to the accuracy corpus TOML file
101+ corpus_file : String ,
102+ /// Filter by tag (only run cases with this tag)
103+ #[ arg( long) ]
104+ tag : Option < String > ,
105+ /// Filter by category (only run cases in this category)
106+ #[ arg( long) ]
107+ category : Option < String > ,
108+ /// Output as JSON instead of text
109+ #[ arg( long) ]
110+ json : bool ,
111+ /// Number of top weight combinations to show
112+ #[ arg( long, default_value = "10" ) ]
113+ top_n : usize ,
114+ } ,
115+
93116 /// Compare current output against a saved snapshot
94117 DiffSnapshot {
95118 /// Path to the compiled dictionary file
@@ -595,6 +618,78 @@ fn main() {
595618 ) ;
596619 }
597620
621+ Command :: Tune {
622+ dict_file,
623+ conn_file,
624+ corpus_file,
625+ tag,
626+ category,
627+ json,
628+ top_n,
629+ } => {
630+ let ( dict, conn, _) = open_resources ( & dict_file, Some ( & conn_file) , & None ) ;
631+ let conn = conn. expect ( "connection matrix is required for tune" ) ;
632+
633+ // Load and parse corpus (same as Accuracy)
634+ let corpus_content = fs:: read_to_string ( & corpus_file) . unwrap_or_else ( |e| {
635+ eprintln ! ( "Failed to read corpus file {}: {}" , corpus_file, e) ;
636+ process:: exit ( 1 ) ;
637+ } ) ;
638+ let corpus: AccuracyCorpus = toml:: from_str ( & corpus_content) . unwrap_or_else ( |e| {
639+ eprintln ! ( "Failed to parse corpus TOML: {}" , e) ;
640+ process:: exit ( 1 ) ;
641+ } ) ;
642+
643+ // Filter and collect non-skip cases
644+ let cases: Vec < ( String , String ) > = corpus
645+ . cases
646+ . iter ( )
647+ . filter ( |c| {
648+ if c. skip {
649+ return false ;
650+ }
651+ if let Some ( ref t) = tag {
652+ if !c. tags . contains ( t) {
653+ return false ;
654+ }
655+ }
656+ if let Some ( ref cat) = category {
657+ if c. category != * cat {
658+ return false ;
659+ }
660+ }
661+ true
662+ } )
663+ . map ( |c| ( c. reading . clone ( ) , c. expected . clone ( ) ) )
664+ . collect ( ) ;
665+
666+ if cases. is_empty ( ) {
667+ eprintln ! ( "No cases match the given filters" ) ;
668+ process:: exit ( 1 ) ;
669+ }
670+
671+ let grid = tune:: WeightGrid :: default ( ) ;
672+ let combos = grid. total_combinations ( ) ;
673+
674+ eprint ! ( "Pre-computing candidates for {} cases... " , cases. len( ) ) ;
675+ let tune_cases = tune:: precompute_cases ( & dict, & conn, & cases) ;
676+ eprintln ! ( "done" ) ;
677+
678+ eprint ! (
679+ "Grid search: {} combinations x {} cases... " ,
680+ combos,
681+ cases. len( )
682+ ) ;
683+ let result = tune:: grid_search ( & tune_cases, & grid, top_n) ;
684+ eprintln ! ( "done" ) ;
685+
686+ if json {
687+ print_tune_json ( & result) ;
688+ } else {
689+ print_tune_text ( & result) ;
690+ }
691+ }
692+
598693 Command :: DiffSnapshot {
599694 dict_file,
600695 conn_file,
@@ -697,3 +792,171 @@ fn main() {
697792 }
698793 }
699794}
795+
796+ fn print_tune_text ( result : & tune:: TuneResult ) {
797+ let fmt_weights = |w : & tune:: FeatureWeights | {
798+ format ! (
799+ "lv={} te={} sk={}" ,
800+ w. length_variance, w. te_kanji, w. single_kanji
801+ )
802+ } ;
803+
804+ let fmt_rate = |e : & tune:: TuneEval | {
805+ let rate = if e. total > 0 {
806+ e. pass_count as f64 / e. total as f64 * 100.0
807+ } else {
808+ 0.0
809+ } ;
810+ format ! ( "{:.1}% ({}/{})" , rate, e. pass_count, e. total)
811+ } ;
812+
813+ println ! ( ) ;
814+ println ! ( "=== Best Weights ===" ) ;
815+ println ! ( " length_variance: {}" , result. best. weights. length_variance) ;
816+ println ! ( " te_kanji: {}" , result. best. weights. te_kanji) ;
817+ println ! ( " single_kanji: {}" , result. best. weights. single_kanji) ;
818+ println ! ( " Pass rate: {}" , fmt_rate( & result. best) ) ;
819+
820+ println ! ( ) ;
821+ println ! ( "=== Default Weights ===" ) ;
822+ println ! (
823+ " length_variance: {}" ,
824+ result. default_eval. weights. length_variance
825+ ) ;
826+ println ! (
827+ " te_kanji: {}" ,
828+ result. default_eval. weights. te_kanji
829+ ) ;
830+ println ! (
831+ " single_kanji: {}" ,
832+ result. default_eval. weights. single_kanji
833+ ) ;
834+ println ! ( " Pass rate: {}" , fmt_rate( & result. default_eval) ) ;
835+
836+ if !result. diffs . is_empty ( ) {
837+ let improvements: Vec < _ > = result
838+ . diffs
839+ . iter ( )
840+ . filter ( |d| d. best_pass && !d. default_pass )
841+ . collect ( ) ;
842+ let regressions: Vec < _ > = result
843+ . diffs
844+ . iter ( )
845+ . filter ( |d| !d. best_pass && d. default_pass )
846+ . collect ( ) ;
847+ let other: Vec < _ > = result
848+ . diffs
849+ . iter ( )
850+ . filter ( |d| d. best_pass == d. default_pass )
851+ . collect ( ) ;
852+
853+ if !improvements. is_empty ( ) {
854+ println ! ( ) ;
855+ println ! ( "=== Improvements (default -> best) ===" ) ;
856+ for d in & improvements {
857+ println ! (
858+ " + {}: {} (was: {})" ,
859+ d. reading, d. expected, d. default_top1
860+ ) ;
861+ }
862+ }
863+
864+ if !regressions. is_empty ( ) {
865+ println ! ( ) ;
866+ println ! ( "=== Regressions (default -> best) ===" ) ;
867+ for d in & regressions {
868+ println ! ( " - {}: {} -> {}" , d. reading, d. expected, d. best_top1) ;
869+ }
870+ }
871+
872+ if !other. is_empty ( ) {
873+ println ! ( ) ;
874+ println ! ( "=== Other changes ===" ) ;
875+ for d in & other {
876+ println ! (
877+ " ~ {}: {} -> {} (expected: {})" ,
878+ d. reading, d. default_top1, d. best_top1, d. expected
879+ ) ;
880+ }
881+ }
882+ }
883+
884+ if !result. best_failures . is_empty ( ) {
885+ println ! ( ) ;
886+ println ! ( "=== Failures (best weights) ===" ) ;
887+ for f in & result. best_failures {
888+ println ! (
889+ " \u{2717} {} \u{2192} {} (got: {})" ,
890+ f. reading, f. expected, f. actual
891+ ) ;
892+ }
893+ }
894+
895+ if result. top_n . len ( ) > 1 {
896+ println ! ( ) ;
897+ println ! ( "=== Top {} Weight Combinations ===" , result. top_n. len( ) ) ;
898+ for ( i, e) in result. top_n . iter ( ) . enumerate ( ) {
899+ println ! (
900+ " #{:<2} {} {}" ,
901+ i + 1 ,
902+ fmt_rate( e) ,
903+ fmt_weights( & e. weights)
904+ ) ;
905+ }
906+ }
907+ }
908+
909+ fn print_tune_json ( result : & tune:: TuneResult ) {
910+ let weight_json = |w : & tune:: FeatureWeights | -> serde_json:: Value {
911+ serde_json:: json!( {
912+ "structure" : w. structure,
913+ "length_variance" : w. length_variance,
914+ "te_kanji" : w. te_kanji,
915+ "single_kanji" : w. single_kanji,
916+ "script" : w. script,
917+ } )
918+ } ;
919+
920+ let eval_json = |e : & tune:: TuneEval | -> serde_json:: Value {
921+ let rate = if e. total > 0 {
922+ e. pass_count as f64 / e. total as f64 * 100.0
923+ } else {
924+ 0.0
925+ } ;
926+ serde_json:: json!( {
927+ "weights" : weight_json( & e. weights) ,
928+ "pass_count" : e. pass_count,
929+ "total" : e. total,
930+ "pass_rate" : format!( "{:.1}%" , rate) ,
931+ } )
932+ } ;
933+
934+ let diffs: Vec < serde_json:: Value > = result
935+ . diffs
936+ . iter ( )
937+ . map ( |d| {
938+ serde_json:: json!( {
939+ "reading" : d. reading,
940+ "expected" : d. expected,
941+ "default_top1" : d. default_top1,
942+ "best_top1" : d. best_top1,
943+ "default_pass" : d. default_pass,
944+ "best_pass" : d. best_pass,
945+ } )
946+ } )
947+ . collect ( ) ;
948+
949+ let top_n: Vec < serde_json:: Value > = result. top_n . iter ( ) . map ( eval_json) . collect ( ) ;
950+
951+ let report = serde_json:: json!( {
952+ "best" : eval_json( & result. best) ,
953+ "default" : eval_json( & result. default_eval) ,
954+ "diffs" : diffs,
955+ "top_n" : top_n,
956+ } ) ;
957+
958+ println ! (
959+ "{}" ,
960+ serde_json:: to_string_pretty( & report) . expect( "JSON serialization failed" )
961+ ) ;
962+ }
0 commit comments