Skip to content

Commit 16061bd

Browse files
authored
Merge pull request #209 from send/refactor/absorb-reranker-costs
refactor: absorb 3 reranker heuristics into compile-time dictionary costs
2 parents 1b0bcde + 164b4e6 commit 16061bd

8 files changed

Lines changed: 66 additions & 214 deletions

File tree

engine/crates/lex-cli/src/bin/dictool.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ enum Command {
3030
input_dir: String,
3131
/// Output file
3232
output_file: String,
33+
/// Mozc id.def for compile-time cost adjustments (person name, pronoun, non-independent kanji)
34+
#[arg(long)]
35+
id_def: Option<String>,
3336
},
3437
/// Compile connection matrix
3538
CompileConn {
@@ -221,7 +224,8 @@ fn main() {
221224
source,
222225
input_dir,
223226
output_file,
224-
} => dict_ops::compile(&source, &input_dir, &output_file),
227+
id_def,
228+
} => dict_ops::compile(&source, &input_dir, &output_file, id_def.as_deref()),
225229
Command::CompileConn {
226230
input_txt,
227231
output_file,

engine/crates/lex-cli/src/commands/dict_ops.rs

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::collections::{HashMap, HashSet};
22
use std::fs;
3-
use std::path::Path;
3+
use std::path::{Path, PathBuf};
44
use std::process;
55

66
use crate::dict_source;
@@ -29,7 +29,12 @@ pub fn fetch(source_name: &str, output_dir: &str) {
2929
);
3030
}
3131

32-
pub fn compile(source_name: &str, input_dir: &str, output_file: &str) {
32+
/// Cost offsets applied at dictionary compile time to eliminate reranker heuristics.
33+
const PERSON_NAME_COST_OFFSET: i16 = 2000;
34+
const PRONOUN_COST_OFFSET: i16 = -3500;
35+
const NON_INDEPENDENT_KANJI_COST_OFFSET: i16 = 1500;
36+
37+
pub fn compile(source_name: &str, input_dir: &str, output_file: &str, id_def: Option<&str>) {
3338
let dict_source = dict_source::from_name(source_name).unwrap_or_else(|| {
3439
eprintln!("Error: unknown source '{source_name}' (available: mozc)");
3540
process::exit(1);
@@ -42,11 +47,58 @@ pub fn compile(source_name: &str, input_dir: &str, output_file: &str) {
4247
}
4348

4449
eprintln!("Source: {source_name}");
45-
let entries = die!(
50+
let mut entries = die!(
4651
dict_source.parse_dir(input_path),
4752
"Error parsing dictionary: {}"
4853
);
4954

55+
// Apply compile-time cost offsets based on morpheme roles.
56+
// Auto-detect id.def in input_dir if --id-def is not specified.
57+
let id_def_path = id_def.map(PathBuf::from).or_else(|| {
58+
let auto = input_path.join("id.def");
59+
if auto.is_file() {
60+
eprintln!("Auto-detected id.def at {}", auto.display());
61+
Some(auto)
62+
} else {
63+
None
64+
}
65+
});
66+
if let Some(id_def_path) = &id_def_path {
67+
let roles = die!(
68+
pos_map::morpheme_roles(id_def_path),
69+
"Error loading morpheme roles: {}"
70+
);
71+
let mut adjusted = 0usize;
72+
for entry_list in entries.values_mut() {
73+
for entry in entry_list.iter_mut() {
74+
let id = entry.left_id as usize;
75+
if id >= roles.len() {
76+
eprintln!(
77+
"Warning: left_id {} out of roles table range ({}), skipping entry '{}'",
78+
id,
79+
roles.len(),
80+
entry.surface
81+
);
82+
continue;
83+
}
84+
let role = roles[id];
85+
let offset = match role {
86+
pos_map::ROLE_PERSON_NAME => PERSON_NAME_COST_OFFSET,
87+
pos_map::ROLE_PRONOUN => PRONOUN_COST_OFFSET,
88+
pos_map::ROLE_NON_INDEPENDENT
89+
if entry.surface.chars().any(lex_core::unicode::is_kanji) =>
90+
{
91+
NON_INDEPENDENT_KANJI_COST_OFFSET
92+
}
93+
_ => continue,
94+
};
95+
entry.cost = entry.cost.saturating_add(offset);
96+
adjusted += 1;
97+
}
98+
}
99+
eprintln!("Adjusted {adjusted} entries (person_name: +{PERSON_NAME_COST_OFFSET}, pronoun: {PRONOUN_COST_OFFSET}, non_independent_kanji: +{NON_INDEPENDENT_KANJI_COST_OFFSET})");
100+
}
101+
50102
let reading_count = entries.len();
51103
let entry_count: usize = entries.values().map(|v| v.len()).sum();
52104

engine/crates/lex-core/src/converter/explain.rs

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,6 @@ pub struct ExplainSegment {
7575
pub script_cost: i64,
7676
/// Connection cost from BOS or previous segment.
7777
pub connection_cost: i64,
78-
/// Non-independent kanji penalty applied.
79-
pub non_independent_kanji_penalty: i64,
80-
/// Pronoun cost bonus applied (positive value, subtracted from cost).
81-
pub pronoun_bonus: i64,
8278
/// Te-form kanji penalty applied.
8379
pub te_form_kanji_penalty: i64,
8480
/// Single-char kanji content-word penalty applied.
@@ -120,15 +116,13 @@ fn explain_segments(
120116
} else {
121117
None
122118
};
123-
let (ni_penalty, p_bonus, te_penalty, sc_penalty) = if let Some(c) = conn {
119+
let (te_penalty, sc_penalty) = if let Some(c) = conn {
124120
(
125-
reranker::non_independent_kanji_penalty(seg, c),
126-
reranker::pronoun_bonus(seg, c),
127121
reranker::te_form_kanji_penalty(prev_seg, seg, c),
128122
reranker::single_char_kanji_penalty(seg, i, &scored.segments, c, Some(dict)),
129123
)
130124
} else {
131-
(0, 0, 0, 0)
125+
(0, 0)
132126
};
133127
ExplainSegment {
134128
reading: seg.reading.clone(),
@@ -137,8 +131,6 @@ fn explain_segments(
137131
segment_penalty: settings().cost.segment_penalty,
138132
script_cost: script_cost(&seg.surface, seg.reading.chars().count()),
139133
connection_cost: connection,
140-
non_independent_kanji_penalty: ni_penalty,
141-
pronoun_bonus: p_bonus,
142134
te_form_kanji_penalty: te_penalty,
143135
single_char_kanji_penalty: sc_penalty,
144136
left_id: seg.left_id,
@@ -289,16 +281,6 @@ pub fn format_text(result: &ExplainResult) -> String {
289281
seg_label
290282
};
291283
let conn_label = if j == 0 { "BOS->" } else { "conn=" };
292-
let ni_str = if seg.non_independent_kanji_penalty > 0 {
293-
format!(" ni_kanji={:<+6}", seg.non_independent_kanji_penalty)
294-
} else {
295-
String::new()
296-
};
297-
let pronoun_str = if seg.pronoun_bonus > 0 {
298-
format!(" pronoun={:<+6}", -(seg.pronoun_bonus))
299-
} else {
300-
String::new()
301-
};
302284
let te_str = if seg.te_form_kanji_penalty > 0 {
303285
format!(" teK={:<+6}", seg.te_form_kanji_penalty)
304286
} else {
@@ -310,16 +292,14 @@ pub fn format_text(result: &ExplainResult) -> String {
310292
String::new()
311293
};
312294
out.push_str(&format!(
313-
" seg[{}]: {} word={:<6} penalty={:<5} script={:<6} {}{}{}{}{}{}\n",
295+
" seg[{}]: {} word={:<6} penalty={:<5} script={:<6} {}{}{}{}\n",
314296
j,
315297
padded,
316298
seg.word_cost,
317299
seg.segment_penalty,
318300
seg.script_cost,
319301
conn_label,
320302
seg.connection_cost,
321-
ni_str,
322-
pronoun_str,
323303
te_str,
324304
single_char_str,
325305
));

engine/crates/lex-core/src/converter/reranker.rs

Lines changed: 1 addition & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,6 @@ use crate::user_history::UserHistory;
99
use super::cost::{conn_cost, script_cost};
1010
use super::viterbi::{RichSegment, ScoredPath};
1111

12-
/// Non-independent kanji penalty for a segment.
13-
/// Returns penalty (> 0) if the segment is non-independent (形式名詞/補助動詞) with kanji surface.
14-
pub(super) fn non_independent_kanji_penalty(seg: &RichSegment, conn: &ConnectionMatrix) -> i64 {
15-
if conn.is_non_independent(seg.left_id) && seg.surface.chars().any(is_kanji) {
16-
settings().reranker.non_independent_kanji_penalty
17-
} else {
18-
0
19-
}
20-
}
21-
22-
/// Pronoun cost bonus for a segment (positive value = cost reduction).
23-
pub(super) fn pronoun_bonus(seg: &RichSegment, conn: &ConnectionMatrix) -> i64 {
24-
if conn.is_pronoun(seg.left_id) {
25-
settings().reranker.pronoun_cost_bonus
26-
} else {
27-
0
28-
}
29-
}
30-
3112
/// Te-form kanji penalty for a segment that follows て/で.
3213
/// `prev` is the preceding segment (None for the first segment).
3314
pub(super) fn te_form_kanji_penalty(
@@ -46,16 +27,6 @@ pub(super) fn te_form_kanji_penalty(
4627
0
4728
}
4829

49-
/// Person name penalty for a segment.
50-
/// Returns penalty (> 0) if the segment is a person name (人名: 一般/姓/名; role == 6).
51-
pub(super) fn person_name_penalty(seg: &RichSegment, conn: &ConnectionMatrix) -> i64 {
52-
if conn.is_person_name(seg.left_id) {
53-
settings().reranker.person_name_penalty
54-
} else {
55-
0
56-
}
57-
}
58-
5930
/// Single-char kanji content-word penalty with dictionary compound exemption.
6031
pub(super) fn single_char_kanji_penalty(
6132
seg: &RichSegment,
@@ -222,20 +193,16 @@ pub fn rerank(
222193
.sum();
223194
path.viterbi_cost += total_script;
224195

225-
// Per-segment penalties: non-independent kanji, pronoun bonus,
226-
// te-form kanji, single-char kanji content-word.
196+
// Per-segment penalties: te-form kanji, single-char kanji content-word.
227197
if let Some(conn) = conn {
228198
for (i, seg) in path.segments.iter().enumerate() {
229199
let prev = if i > 0 {
230200
Some(&path.segments[i - 1])
231201
} else {
232202
None
233203
};
234-
path.viterbi_cost += non_independent_kanji_penalty(seg, conn);
235-
path.viterbi_cost -= pronoun_bonus(seg, conn);
236204
path.viterbi_cost += te_form_kanji_penalty(prev, seg, conn);
237205
path.viterbi_cost += single_char_kanji_penalty(seg, i, &path.segments, conn, dict);
238-
path.viterbi_cost += person_name_penalty(seg, conn);
239206
}
240207
}
241208
}
@@ -315,27 +282,6 @@ mod tests {
315282
}
316283
}
317284

318-
#[test]
319-
fn non_independent_kanji_penalty_applied() {
320-
// ID 2 = non-independent (role 4), ID 1 = content word (role 0)
321-
let roles = vec![0u8, 0, 4];
322-
let conn = conn_with_roles(roles);
323-
324-
// Path A: こと (hiragana, non-independent) — no penalty
325-
// Path B: 事 (kanji, non-independent) — penalty applied
326-
let mut paths = vec![
327-
path(vec![seg("こと", "事", 2)], 100),
328-
path(vec![seg("こと", "こと", 2)], 100),
329-
];
330-
331-
rerank(&mut paths, Some(&conn), None);
332-
333-
// The hiragana path should rank higher (lower cost)
334-
assert_eq!(paths[0].segments[0].surface, "こと");
335-
assert_eq!(paths[1].segments[0].surface, "事");
336-
assert!(paths[0].viterbi_cost < paths[1].viterbi_cost);
337-
}
338-
339285
/// Build a minimal ConnectionMatrix with the given roles vector and
340286
/// function-word ID range.
341287
fn conn_with_roles_and_fw(roles: Vec<u8>, fw_min: u16, fw_max: u16) -> ConnectionMatrix {
@@ -344,30 +290,6 @@ mod tests {
344290
ConnectionMatrix::new_owned(num_ids, fw_min, fw_max, roles, costs)
345291
}
346292

347-
#[test]
348-
fn non_independent_kanji_penalty_not_applied_to_content_words() {
349-
// ID 1 = content word (role 0)
350-
let roles = vec![0u8, 0];
351-
let conn = conn_with_roles(roles);
352-
353-
// Both paths use content word IDs — no non-independent penalty
354-
let mut paths = vec![
355-
path(vec![seg("こと", "事", 1)], 100),
356-
path(vec![seg("こと", "こと", 1)], 100),
357-
];
358-
359-
rerank(&mut paths, Some(&conn), None);
360-
361-
// Costs should differ only by script cost, not by non-independent penalty
362-
let penalty = settings().reranker.non_independent_kanji_penalty;
363-
let cost_diff = (paths[1].viterbi_cost - paths[0].viterbi_cost).abs();
364-
assert!(
365-
cost_diff < penalty,
366-
"no non-independent penalty should be applied: diff = {}",
367-
cost_diff
368-
);
369-
}
370-
371293
#[test]
372294
fn te_form_kanji_penalty_applied() {
373295
// ID 2 = function word (fw_min=2, fw_max=2), ID 1 = content word
@@ -472,35 +394,6 @@ mod tests {
472394
assert!(paths[0].viterbi_cost < paths[1].viterbi_cost);
473395
}
474396

475-
#[test]
476-
fn pronoun_bonus_applied() {
477-
// ID 2 = pronoun (role 5), ID 1 = content word (role 0)
478-
let roles = vec![0u8, 0, 5];
479-
let conn = conn_with_roles(roles);
480-
481-
// Both paths have the same surface (hiragana) to isolate pronoun bonus.
482-
// Path A: pronoun POS (id=2) — bonus applied
483-
// Path B: content word POS (id=1) — no bonus
484-
let mut paths = vec![
485-
path(vec![seg("どれ", "どれ", 2)], 1000),
486-
path(vec![seg("どれ", "どれ", 1)], 1000),
487-
];
488-
489-
rerank(&mut paths, Some(&conn), None);
490-
491-
// The pronoun path should rank higher (lower cost) after bonus
492-
assert_eq!(
493-
paths[0].segments[0].left_id, 2,
494-
"pronoun path should rank first"
495-
);
496-
let bonus = settings().reranker.pronoun_cost_bonus;
497-
let diff = paths[1].viterbi_cost - paths[0].viterbi_cost;
498-
assert_eq!(
499-
diff, bonus,
500-
"cost difference should equal pronoun_cost_bonus"
501-
);
502-
}
503-
504397
/// A minimal dictionary for testing compound exemption.
505398
struct MockDict {
506399
entries: Vec<(String, Vec<DictEntry>)>,
@@ -720,35 +613,6 @@ mod tests {
720613
);
721614
}
722615

723-
#[test]
724-
fn person_name_penalty_applied() {
725-
// ID 2 = person name (role 6), ID 1 = content word (role 0)
726-
let roles = vec![0u8, 0, 6];
727-
let conn = conn_with_roles(roles);
728-
729-
// Both paths have the same hiragana surface to isolate person name penalty.
730-
// Path A: person name POS (id=2) — penalty applied
731-
// Path B: content word POS (id=1) — no penalty
732-
let mut paths = vec![
733-
path(vec![seg("にしま", "にしま", 2)], 1000),
734-
path(vec![seg("にしま", "にしま", 1)], 1000),
735-
];
736-
737-
rerank(&mut paths, Some(&conn), None);
738-
739-
// The content word path should rank higher (lower cost)
740-
assert_eq!(
741-
paths[0].segments[0].left_id, 1,
742-
"content word path should rank first"
743-
);
744-
let penalty = settings().reranker.person_name_penalty;
745-
let diff = paths[1].viterbi_cost - paths[0].viterbi_cost;
746-
assert_eq!(
747-
diff, penalty,
748-
"cost difference should equal person_name_penalty"
749-
);
750-
}
751-
752616
#[test]
753617
fn te_form_kanji_penalty_not_applied_to_non_te_function_word() {
754618
// ID 2 = function word (fw_min=2, fw_max=2), ID 1 = content word

engine/crates/lex-core/src/default_settings.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,8 @@ unknown_word_cost = 10000
99
[reranker]
1010
length_variance_weight = 2000
1111
structure_cost_filter = 6000
12-
non_independent_kanji_penalty = 1500
1312
te_form_kanji_penalty = 3500
14-
pronoun_cost_bonus = 3500
1513
single_char_kanji_penalty = 4000
16-
person_name_penalty = 2000
1714
structure_cost_transition_cap = 5000
1815

1916
[history]

0 commit comments

Comments (0)