Skip to content

Commit 02d9e53

Browse files
authored
Create main.py
1 parent b0b25d7 commit 02d9e53

1 file changed

Lines changed: 92 additions & 0 deletions

File tree

text-to-symbol/main.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import mxnet as mx
2+
from sockeye import inference, model
3+
import sentencepiece as spm
4+
import warnings
5+
warnings.filterwarnings("ignore")
6+
7+
8+
device = mx.cpu()
9+
model_folder = "spoken2symbol"
10+
spm_path = model_folder + "/spm.model"
11+
12+
13+
sockeye_models, sockeye_source_vocabs, sockeye_target_vocabs = model.load_models(
14+
context=device, dtype=None, model_folders=[model_folder], inference_only=True
15+
)
16+
sp = spm.SentencePieceProcessor(model_file=spm_path)
17+
18+
19+
language_code = "en"
20+
country_code = "ase" # "us"
21+
translation_type = "sent"
22+
n_best = 3
23+
beam_size = n_best
24+
25+
26+
def translate(text):
27+
28+
tag_str = f"<2{language_code}> <4{country_code}> <{translation_type}>"
29+
formatted = f"{tag_str} {text}"
30+
encoded = " ".join(sp.encode(formatted, out_type=str))
31+
print(encoded)
32+
print()
33+
34+
translator = inference.Translator(
35+
context=device,
36+
ensemble_mode="linear",
37+
scorer=inference.CandidateScorer(),
38+
output_scores=True,
39+
batch_size=1,
40+
beam_size=beam_size,
41+
beam_search_stop="all",
42+
nbest_size=n_best,
43+
models=sockeye_models,
44+
source_vocabs=sockeye_source_vocabs,
45+
target_vocabs=sockeye_target_vocabs,
46+
)
47+
48+
encoded = inference.make_input_from_plain_string(0, encoded)
49+
output = translator.translate([encoded])[0]
50+
print(output)
51+
print()
52+
53+
translations = []
54+
symbols_candidates = output.nbest_translations
55+
factors_candidates = output.nbest_factor_translations
56+
for symbols, factors in zip(symbols_candidates, factors_candidates):
57+
symbols = symbols.split(" ")
58+
xs = factors[0].split(" ")
59+
ys = factors[1].split(" ")
60+
fsw = ""
61+
62+
for i, (symbol, x, y) in enumerate(zip(symbols, xs, ys)):
63+
if symbol != "P":
64+
if i != 0:
65+
if (
66+
not symbol.startswith("S")
67+
or symbol.startswith("S387")
68+
or symbol.startswith("S388")
69+
or symbol.startswith("S389")
70+
or symbol.startswith("S38a")
71+
or symbol.startswith("S38b")
72+
):
73+
fsw += " "
74+
fsw += symbol
75+
fsw += x
76+
fsw += "x"
77+
fsw += y
78+
79+
translations.append(fsw)
80+
81+
return translations
82+
83+
if __name__ == "__main__":
84+
spoken = "hi"
85+
symbol = translate(spoken)
86+
print(symbol)
87+
88+
from signwriting.visualizer.visualize import signwriting_to_image
89+
90+
fsw = symbol[0]
91+
img = signwriting_to_image(fsw)
92+
img.save("sign.png")

0 commit comments

Comments
 (0)