Skip to content

Commit 1664f4e

Browse files
authored
Merge pull request #210 from send/refactor/feature-weight-model
refactor: feature-weight model with lextool tune
2 parents 16061bd + 57670d2 commit 1664f4e

8 files changed

Lines changed: 1040 additions & 195 deletions

File tree

engine/crates/lex-cli/src/bin/lextool.rs

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::process;
77
use clap::{Parser, Subcommand};
88
use serde::{Deserialize, Serialize};
99

10+
use lex_core::converter::tune;
1011
use lex_core::converter::{convert_nbest, convert_nbest_with_history};
1112
use lex_core::dict::connection::ConnectionMatrix;
1213
use lex_core::dict::TrieDictionary;
@@ -90,6 +91,28 @@ enum Command {
9091
history: Option<String>,
9192
},
9293

94+
/// Grid-search FeatureWeights to optimise conversion accuracy
95+
Tune {
96+
/// Path to the compiled dictionary file
97+
dict_file: String,
98+
/// Path to the compiled connection matrix file
99+
conn_file: String,
100+
/// Path to the accuracy corpus TOML file
101+
corpus_file: String,
102+
/// Filter by tag (only run cases with this tag)
103+
#[arg(long)]
104+
tag: Option<String>,
105+
/// Filter by category (only run cases in this category)
106+
#[arg(long)]
107+
category: Option<String>,
108+
/// Output as JSON instead of text
109+
#[arg(long)]
110+
json: bool,
111+
/// Number of top weight combinations to show
112+
#[arg(long, default_value = "10")]
113+
top_n: usize,
114+
},
115+
93116
/// Compare current output against a saved snapshot
94117
DiffSnapshot {
95118
/// Path to the compiled dictionary file
@@ -595,6 +618,78 @@ fn main() {
595618
);
596619
}
597620

621+
Command::Tune {
622+
dict_file,
623+
conn_file,
624+
corpus_file,
625+
tag,
626+
category,
627+
json,
628+
top_n,
629+
} => {
630+
let (dict, conn, _) = open_resources(&dict_file, Some(&conn_file), &None);
631+
let conn = conn.expect("connection matrix is required for tune");
632+
633+
// Load and parse corpus (same as Accuracy)
634+
let corpus_content = fs::read_to_string(&corpus_file).unwrap_or_else(|e| {
635+
eprintln!("Failed to read corpus file {}: {}", corpus_file, e);
636+
process::exit(1);
637+
});
638+
let corpus: AccuracyCorpus = toml::from_str(&corpus_content).unwrap_or_else(|e| {
639+
eprintln!("Failed to parse corpus TOML: {}", e);
640+
process::exit(1);
641+
});
642+
643+
// Filter and collect non-skip cases
644+
let cases: Vec<(String, String)> = corpus
645+
.cases
646+
.iter()
647+
.filter(|c| {
648+
if c.skip {
649+
return false;
650+
}
651+
if let Some(ref t) = tag {
652+
if !c.tags.contains(t) {
653+
return false;
654+
}
655+
}
656+
if let Some(ref cat) = category {
657+
if c.category != *cat {
658+
return false;
659+
}
660+
}
661+
true
662+
})
663+
.map(|c| (c.reading.clone(), c.expected.clone()))
664+
.collect();
665+
666+
if cases.is_empty() {
667+
eprintln!("No cases match the given filters");
668+
process::exit(1);
669+
}
670+
671+
let grid = tune::WeightGrid::default();
672+
let combos = grid.total_combinations();
673+
674+
eprint!("Pre-computing candidates for {} cases... ", cases.len());
675+
let tune_cases = tune::precompute_cases(&dict, &conn, &cases);
676+
eprintln!("done");
677+
678+
eprint!(
679+
"Grid search: {} combinations x {} cases... ",
680+
combos,
681+
cases.len()
682+
);
683+
let result = tune::grid_search(&tune_cases, &grid, top_n);
684+
eprintln!("done");
685+
686+
if json {
687+
print_tune_json(&result);
688+
} else {
689+
print_tune_text(&result);
690+
}
691+
}
692+
598693
Command::DiffSnapshot {
599694
dict_file,
600695
conn_file,
@@ -697,3 +792,171 @@ fn main() {
697792
}
698793
}
699794
}
795+
796+
fn print_tune_text(result: &tune::TuneResult) {
797+
let fmt_weights = |w: &tune::FeatureWeights| {
798+
format!(
799+
"lv={} te={} sk={}",
800+
w.length_variance, w.te_kanji, w.single_kanji
801+
)
802+
};
803+
804+
let fmt_rate = |e: &tune::TuneEval| {
805+
let rate = if e.total > 0 {
806+
e.pass_count as f64 / e.total as f64 * 100.0
807+
} else {
808+
0.0
809+
};
810+
format!("{:.1}% ({}/{})", rate, e.pass_count, e.total)
811+
};
812+
813+
println!();
814+
println!("=== Best Weights ===");
815+
println!(" length_variance: {}", result.best.weights.length_variance);
816+
println!(" te_kanji: {}", result.best.weights.te_kanji);
817+
println!(" single_kanji: {}", result.best.weights.single_kanji);
818+
println!(" Pass rate: {}", fmt_rate(&result.best));
819+
820+
println!();
821+
println!("=== Default Weights ===");
822+
println!(
823+
" length_variance: {}",
824+
result.default_eval.weights.length_variance
825+
);
826+
println!(
827+
" te_kanji: {}",
828+
result.default_eval.weights.te_kanji
829+
);
830+
println!(
831+
" single_kanji: {}",
832+
result.default_eval.weights.single_kanji
833+
);
834+
println!(" Pass rate: {}", fmt_rate(&result.default_eval));
835+
836+
if !result.diffs.is_empty() {
837+
let improvements: Vec<_> = result
838+
.diffs
839+
.iter()
840+
.filter(|d| d.best_pass && !d.default_pass)
841+
.collect();
842+
let regressions: Vec<_> = result
843+
.diffs
844+
.iter()
845+
.filter(|d| !d.best_pass && d.default_pass)
846+
.collect();
847+
let other: Vec<_> = result
848+
.diffs
849+
.iter()
850+
.filter(|d| d.best_pass == d.default_pass)
851+
.collect();
852+
853+
if !improvements.is_empty() {
854+
println!();
855+
println!("=== Improvements (default -> best) ===");
856+
for d in &improvements {
857+
println!(
858+
" + {}: {} (was: {})",
859+
d.reading, d.expected, d.default_top1
860+
);
861+
}
862+
}
863+
864+
if !regressions.is_empty() {
865+
println!();
866+
println!("=== Regressions (default -> best) ===");
867+
for d in &regressions {
868+
println!(" - {}: {} -> {}", d.reading, d.expected, d.best_top1);
869+
}
870+
}
871+
872+
if !other.is_empty() {
873+
println!();
874+
println!("=== Other changes ===");
875+
for d in &other {
876+
println!(
877+
" ~ {}: {} -> {} (expected: {})",
878+
d.reading, d.default_top1, d.best_top1, d.expected
879+
);
880+
}
881+
}
882+
}
883+
884+
if !result.best_failures.is_empty() {
885+
println!();
886+
println!("=== Failures (best weights) ===");
887+
for f in &result.best_failures {
888+
println!(
889+
" \u{2717} {} \u{2192} {} (got: {})",
890+
f.reading, f.expected, f.actual
891+
);
892+
}
893+
}
894+
895+
if result.top_n.len() > 1 {
896+
println!();
897+
println!("=== Top {} Weight Combinations ===", result.top_n.len());
898+
for (i, e) in result.top_n.iter().enumerate() {
899+
println!(
900+
" #{:<2} {} {}",
901+
i + 1,
902+
fmt_rate(e),
903+
fmt_weights(&e.weights)
904+
);
905+
}
906+
}
907+
}
908+
909+
/// Render a tuning result as pretty-printed JSON on stdout.
///
/// The report carries the best and default evaluations, every per-case
/// diff between them, and the top-N weight combinations.
fn print_tune_json(result: &tune::TuneResult) {
    // JSON object for one weight set (all five tunable features).
    fn weights_to_json(w: &tune::FeatureWeights) -> serde_json::Value {
        serde_json::json!({
            "structure": w.structure,
            "length_variance": w.length_variance,
            "te_kanji": w.te_kanji,
            "single_kanji": w.single_kanji,
            "script": w.script,
        })
    }

    // JSON object for one evaluation: weights plus pass statistics.
    // An empty run reports a 0.0% pass rate.
    fn eval_to_json(e: &tune::TuneEval) -> serde_json::Value {
        let pct = if e.total > 0 {
            e.pass_count as f64 / e.total as f64 * 100.0
        } else {
            0.0
        };
        serde_json::json!({
            "weights": weights_to_json(&e.weights),
            "pass_count": e.pass_count,
            "total": e.total,
            "pass_rate": format!("{:.1}%", pct),
        })
    }

    let mut diffs = Vec::with_capacity(result.diffs.len());
    for d in &result.diffs {
        diffs.push(serde_json::json!({
            "reading": d.reading,
            "expected": d.expected,
            "default_top1": d.default_top1,
            "best_top1": d.best_top1,
            "default_pass": d.default_pass,
            "best_pass": d.best_pass,
        }));
    }

    let ranked: Vec<serde_json::Value> = result.top_n.iter().map(eval_to_json).collect();

    let report = serde_json::json!({
        "best": eval_to_json(&result.best),
        "default": eval_to_json(&result.default_eval),
        "diffs": diffs,
        "top_n": ranked,
    });

    println!(
        "{}",
        serde_json::to_string_pretty(&report).expect("JSON serialization failed")
    );
}

engine/crates/lex-core/src/converter/explain.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::user_history::UserHistory;
77
use crate::settings::settings;
88

99
use super::cost::{conn_cost, script_cost, DefaultCostFunction};
10+
use super::features::{is_single_char_kanji_penalised, is_te_form_kanji_penalised};
1011
use super::lattice::{build_lattice, LatticeNode};
1112
use super::reranker;
1213
use super::viterbi::{viterbi_nbest, ScoredPath};
@@ -116,13 +117,23 @@ fn explain_segments(
116117
} else {
117118
None
118119
};
119-
let (te_penalty, sc_penalty) = if let Some(c) = conn {
120-
(
121-
reranker::te_form_kanji_penalty(prev_seg, seg, c),
122-
reranker::single_char_kanji_penalty(seg, i, &scored.segments, c, Some(dict)),
123-
)
120+
let te_penalty = if let Some(c) = conn {
121+
if is_te_form_kanji_penalised(seg, prev_seg, c) {
122+
settings().reranker.te_form_kanji_penalty
123+
} else {
124+
0
125+
}
124126
} else {
125-
(0, 0)
127+
0
128+
};
129+
let sc_penalty = if let Some(c) = conn {
130+
if is_single_char_kanji_penalised(seg, i, &scored.segments, c, Some(dict)) {
131+
settings().reranker.single_char_kanji_penalty
132+
} else {
133+
0
134+
}
135+
} else {
136+
0
126137
};
127138
ExplainSegment {
128139
reading: seg.reading.clone(),

0 commit comments

Comments
 (0)