Skip to content

Commit 6502488

Browse files
committed
feat(db): add postgres support
1 parent 3142054 commit 6502488

1 file changed

Lines changed: 88 additions & 37 deletions

File tree

db.py

Lines changed: 88 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,64 +8,115 @@
88
class DB:
99
connector: Union[sqlite3.Connection]
1010
cursor: Union[sqlite3.Cursor, psycopg.Cursor]
11-
db_config: DatabaseConfig = field(init=False)
11+
db_config: DatabaseConfig
1212

1313
@classmethod
1414
def from_config(cls, db_config: DatabaseConfig):
15-
cls.db_config = db_config
16-
match cls.db_config.db_type:
15+
match db_config.db_type:
1716
case "sqlite":
1817
connector = sqlite3.connect(db_config.sqlite_config.db_path)
1918
case "postgres":
20-
connector = psycopg.connect("host=%s port=%d dbname=%s user=%s password=%s".format(cls.db_config.postgres_config.host,
21-
cls.db_config.postgres_config.port,
22-
cls.db_config.postgres_config.db,
23-
cls.db_config.postgres_config.user,
24-
cls.db_config.postgres_config.password))
25-
cursor = connector.cursor()
26-
27-
if "translations" not in cls.get_table_list(cursor):
28-
cls.init_table(cursor)
19+
connector = psycopg.connect("""host={}
20+
port={}
21+
dbname={}
22+
user={}
23+
password={}
24+
""".format(db_config.postgres_config.host,
25+
db_config.postgres_config.port,
26+
db_config.postgres_config.db,
27+
db_config.postgres_config.user,
28+
db_config.postgres_config.password)
29+
)
30+
cls.db_config = db_config
31+
cls.connector = connector
32+
cls.cursor = connector.cursor()
33+
if "translations" not in cls.get_table_list():
34+
cls.init_table()
2935

30-
return cls(connector, cursor)
36+
return cls(cls.connector, cls.cursor, cls.db_config)
3137

3238
@classmethod
33-
def get_table_list(cls, cursor: Union[sqlite3.Cursor, psycopg.Cursor]) -> list:
39+
def get_table_list(cls) -> list:
3440
match cls.db_config.db_type:
3541
case "sqlite":
36-
query = "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%'"
42+
query = """SELECT name FROM sqlite_master
43+
WHERE type = 'table'
44+
AND name NOT LIKE 'sqlite_%'
45+
"""
3746
case "postgres":
38-
query = "SELECT * FROM information_schema.tables"
47+
query = """SELECT tablename FROM pg_catalog.pg_tables
48+
WHERE schemaname
49+
NOT IN ('pg_catalog', 'information_schema')
50+
"""
3951

40-
cursor.execute(query)
41-
return [t[0] for t in cursor.fetchall()]
52+
cls.cursor.execute(query)
53+
return [t[0] for t in cls.cursor.fetchall()]
4254

4355
@classmethod
44-
def init_table(cls, cursor: sqlite3.Cursor):
45-
cursor.execute("""CREATE TABLE translations (
46-
src_lang TEXT,
47-
tgt_lang TEXT,
48-
src_text TEXT,
49-
tgt_text TEXT)""")
50-
cursor.connection.commit()
56+
def init_table(cls) -> None:
57+
cls.cursor.execute("""CREATE TABLE translations (
58+
src_lang TEXT,
59+
tgt_lang TEXT,
60+
src_text TEXT,
61+
tgt_text TEXT)
62+
""")
63+
cls.cursor.connection.commit()
5164

52-
def save_translation(self, src_lang:str, tgt_lang:str, src_text:str, tgt_text:str):
53-
self.cursor.execute("INSERT INTO translations VALUES (?,?,?,?)", (src_lang, tgt_lang, src_text, tgt_text))
54-
self.connector.commit()
65+
def save_translation(cls, src_lang:str, tgt_lang:str, src_text:str, tgt_text:str) -> None:
66+
query = """
67+
INSERT INTO translations
68+
VALUES ({placeholder}, {placeholder}, {placeholder}, {placeholder})
69+
"""
70+
cls.cursor.execute(cls._fill_placeholder(query), (src_lang, tgt_lang, src_text, tgt_text))
71+
cls.connector.commit()
72+
73+
def fetch_translation(cls, src_lang:str, tgt_lang:str, src_text:str) -> str:
74+
query = """
75+
SELECT tgt_text FROM translations
76+
WHERE src_lang = {placeholder}
77+
AND tgt_lang = {placeholder}
78+
AND src_text = {placeholder}
79+
"""
5580

56-
def fetch_translation(self, src_lang:str , tgt_lang:str, src_text:str):
57-
self.cursor.execute("SELECT tgt_text FROM translations WHERE src_lang=? AND tgt_lang=? AND src_text=?", (src_lang, tgt_lang, src_text))
58-
result = self.cursor.fetchone()
81+
cls.cursor.execute(cls._fill_placeholder(query), (src_lang, tgt_lang, src_text))
82+
result = cls.cursor.fetchone()
5983
return result[0] if result else None
6084

61-
def delete_translation(self, src_lang:str , tgt_lang:str, src_text:str):
62-
self.cursor.execute("DELETE FROM translations WHERE src_lang=? AND tgt_lang=? AND src_text=?", (src_lang, tgt_lang, src_text))
63-
self.connector.commit()
85+
def delete_translation(cls, src_lang:str , tgt_lang:str, src_text:str) -> None:
86+
query = """
87+
DELETE FROM translations
88+
WHERE src_lang={placeholder}
89+
AND tgt_lang={placeholder}
90+
AND src_text={placeholder}
91+
"""
92+
93+
cls.cursor.execute(cls._fill_placeholder(query), (src_lang, tgt_lang, src_text))
94+
cls.connector.commit()
6495

65-
def get_latest_translations(self, src_lang: str, tgt_lang: str, index: int):
66-
self.cursor.execute(f"SELECT * FROM translations WHERE src_lang=? AND tgt_lang=? ORDER BY rowid desc LIMIT {index}", (src_lang, tgt_lang))
67-
records = self.cursor.fetchall()
96+
def get_latest_translations(cls, src_lang: str, tgt_lang: str, index: int):
97+
query = """
98+
SELECT * FROM translations
99+
WHERE src_lang = {placeholder} AND tgt_lang = {placeholder}
100+
ORDER BY {order_by} DESC
101+
LIMIT {placeholder}
102+
"""
103+
104+
cls.cursor.execute(cls._fill_placeholder(query), (src_lang, tgt_lang, index))
105+
records = cls.cursor.fetchall()
106+
68107
return [TranslationRecord(record[0], record[1], record[2], record[3]) for record in records]
108+
109+
@classmethod
110+
def _fill_placeholder(cls, target_str: str):
111+
match cls.db_config.db_type:
112+
case "sqlite":
113+
placeholder = "?"
114+
order_by = "rowid"
115+
case "postgres":
116+
placeholder = "%s"
117+
order_by = "ctid"
118+
119+
return target_str.format(placeholder=placeholder, order_by=order_by)
69120

70121
@dataclass
71122
class TranslationRecord:

0 commit comments

Comments
 (0)