|
import functools
import warnings

import mxnet as mx
import sentencepiece as spm
from sockeye import inference, model

# Kept after the imports, as in the original: only warnings raised from this
# point on are silenced.
warnings.filterwarnings("ignore")
| 6 | + |
| 7 | + |
# Inference device and on-disk location of the trained model. The
# SentencePiece model file is expected inside the same folder.
device = mx.cpu()
model_folder = "spoken2symbol"
spm_path = model_folder + "/spm.model"

# Loaded once at import time and shared by every call to `translate`.
# NOTE(review): this makes importing the module slow and dependent on the
# model folder being present in the working directory.
sockeye_models, sockeye_source_vocabs, sockeye_target_vocabs = model.load_models(
    context=device, dtype=None, model_folders=[model_folder], inference_only=True
)
sp = spm.SentencePieceProcessor(model_file=spm_path)
| 17 | + |
| 18 | + |
# Values substituted into the "<2..> <4..> <..>" tag prefix that `translate`
# puts in front of every source sentence.
language_code = "en"
country_code = "ase"  # alternative noted by the original author: "us"
translation_type = "sent"

# Number of candidate translations returned per input. The beam is sized to
# match, so the search keeps exactly as many hypotheses as are requested.
n_best = 3
beam_size = n_best
| 24 | + |
| 25 | + |
# Symbol-key prefixes that, besides any non-"S" token, open a new sign in the
# output string. NOTE(review): these look like the FSW punctuation range
# S387-S38b -- confirm against the Formal SignWriting specification.
_NEW_SIGN_SYMBOL_PREFIXES = ("S387", "S388", "S389", "S38a", "S38b")


@functools.lru_cache(maxsize=4)
def _get_translator(beam, nbest):
    """Return a Sockeye translator for the given beam / n-best sizes.

    The translator is built once per ``(beam, nbest)`` pair instead of on
    every ``translate`` call. The cache does not notice later reassignment of
    the module-level ``device`` / ``sockeye_*`` objects.
    """
    return inference.Translator(
        context=device,
        ensemble_mode="linear",
        scorer=inference.CandidateScorer(),
        output_scores=True,
        batch_size=1,
        beam_size=beam,
        beam_search_stop="all",
        nbest_size=nbest,
        models=sockeye_models,
        source_vocabs=sockeye_source_vocabs,
        target_vocabs=sockeye_target_vocabs,
    )


def _starts_new_sign(symbol):
    """Return True when ``symbol`` must be separated from what precedes it."""
    return not symbol.startswith("S") or symbol.startswith(_NEW_SIGN_SYMBOL_PREFIXES)


def _candidate_to_fsw(symbols, xs, ys):
    """Join parallel symbol / x / y token lists into one FSW-style string.

    Each kept token is rendered as ``<symbol><x>x<y>``. Tokens equal to "P"
    are skipped (presumably padding / placeholder -- confirm against the
    model's target format). A single space is inserted before a token that
    starts a new sign, but only once something has already been emitted, so
    a candidate that begins with skipped tokens no longer yields a string
    with a leading space.
    """
    parts = []
    for symbol, x, y in zip(symbols, xs, ys):
        if symbol == "P":
            continue
        if parts and _starts_new_sign(symbol):
            parts.append(" ")
        parts.append(f"{symbol}{x}x{y}")
    return "".join(parts)


def translate(text):
    """Translate spoken-language ``text`` into candidate SignWriting strings.

    The text is prefixed with the language / country / translation-type tags,
    segmented with the SentencePiece model and decoded by the Sockeye model.
    The encoded input and the raw translator output are printed as a
    debugging aid, exactly as before.

    Args:
        text: Source sentence to translate.

    Returns:
        A list with one string per n-best candidate, in the order the
        translator returned them.
    """
    tag_str = f"<2{language_code}> <4{country_code}> <{translation_type}>"
    formatted = f"{tag_str} {text}"
    encoded = " ".join(sp.encode(formatted, out_type=str))
    print(encoded)
    print()

    translator = _get_translator(beam_size, n_best)
    model_input = inference.make_input_from_plain_string(0, encoded)
    output = translator.translate([model_input])[0]
    print(output)
    print()

    # The first target factor carries the x tokens and the second the y
    # tokens, aligned position by position with the symbol tokens.
    return [
        _candidate_to_fsw(
            symbols.split(" "),
            factors[0].split(" "),
            factors[1].split(" "),
        )
        for symbols, factors in zip(
            output.nbest_translations, output.nbest_factor_translations
        )
    ]
| 82 | + |
if __name__ == "__main__":
    # Demo: translate one word and render the best candidate to "sign.png".
    # The visualizer is imported here, not at the top of the file, so that
    # importing this module for `translate` alone does not need that package;
    # importing it before translating makes a missing package fail fast.
    from signwriting.visualizer.visualize import signwriting_to_image

    spoken = "hi"
    candidates = translate(spoken)
    print(candidates)

    if not candidates:
        # Previously this case surfaced as a bare IndexError on `[0]`.
        raise SystemExit(f"No translation produced for {spoken!r}")

    img = signwriting_to_image(candidates[0])
    img.save("sign.png")
0 commit comments