Skip to content

Commit fc6748d

Browse files
✨ Add encoding parameter to write_file() (#511)
Co-authored-by: claell <26320273+claell@users.noreply.github.com>
1 parent 59a3c34 commit fc6748d

2 files changed

Lines changed: 86 additions & 3 deletions

File tree

bibtexparser/entrypoint.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def write_file(
139139
unparse_stack: Optional[Iterable[Middleware]] = None,
140140
prepend_middleware: Optional[Iterable[Middleware]] = None,
141141
bibtex_format: Optional[BibtexFormat] = None,
142+
encoding: str = "UTF-8",
142143
) -> None:
143144
"""Write a BibTeX database to a file.
144145
@@ -148,15 +149,16 @@ def write_file(
148149
If None, a default stack will be used.
149150
:param prepend_middleware: List of middleware to prepend to the default stack.
150151
Only applicable if `unparse_stack` is None.
151-
:param bibtex_format: Customized BibTeX format to use (optional)."""
152+
:param bibtex_format: Customized BibTeX format to use (optional).
153+
:param encoding: Encoding of the .bib file. Default encoding is ``"UTF-8"``."""
152154
bibtex_str = write_string(
153155
library=library,
154156
unparse_stack=unparse_stack,
155157
prepend_middleware=prepend_middleware,
156158
bibtex_format=bibtex_format,
157159
)
158160
if isinstance(file, str):
159-
with open(file, "w") as f:
161+
with open(file, "w", encoding=encoding) as f:
160162
f.write(bibtex_str)
161163
else:
162164
file.write(bibtex_str)

tests/test_entrypoint.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1-
"""Testing the parse_file function."""
1+
"""Testing the parse_file and write_file functions."""
2+
3+
import os
4+
import tempfile
25

36
from bibtexparser import parse_file
7+
from bibtexparser import write_file
8+
from bibtexparser.library import Library
9+
from bibtexparser.model import Entry
10+
from bibtexparser.model import Field
411

512

613
def test_gbk():
@@ -9,3 +16,77 @@ def test_gbk():
916
assert library.entries[0]["title"] == "Test Title"
1017
assert library.entries[0]["year"] == "2013"
1118
assert library.entries[0]["journal"] == "测试期刊"
19+
20+
21+
def test_write_file_default_encoding():
22+
"""Test write_file uses UTF-8 by default."""
23+
entry = Entry(
24+
entry_type="article",
25+
key="test2024",
26+
fields=[
27+
Field(key="author", value="Müller"),
28+
Field(key="title", value="Ångström measurements"),
29+
],
30+
)
31+
library = Library([entry])
32+
33+
with tempfile.NamedTemporaryFile(mode="w", suffix=".bib", delete=False) as f:
34+
temp_path = f.name
35+
36+
try:
37+
write_file(temp_path, library)
38+
# Read back and verify
39+
with open(temp_path, encoding="UTF-8") as f:
40+
content = f.read()
41+
assert "Müller" in content
42+
assert "Ångström" in content
43+
finally:
44+
os.unlink(temp_path)
45+
46+
47+
def test_write_file_gbk_encoding():
48+
"""Test write_file with GBK encoding for Chinese characters."""
49+
entry = Entry(
50+
entry_type="article",
51+
key="test2024",
52+
fields=[
53+
Field(key="author", value="凯撒"),
54+
Field(key="title", value="Test Title"),
55+
Field(key="journal", value="测试期刊"),
56+
],
57+
)
58+
library = Library([entry])
59+
60+
with tempfile.NamedTemporaryFile(mode="w", suffix=".bib", delete=False) as f:
61+
temp_path = f.name
62+
63+
try:
64+
write_file(temp_path, library, encoding="gbk")
65+
# Read back with GBK and verify
66+
with open(temp_path, encoding="gbk") as f:
67+
content = f.read()
68+
assert "凯撒" in content
69+
assert "测试期刊" in content
70+
finally:
71+
os.unlink(temp_path)
72+
73+
74+
def test_write_file_roundtrip_gbk():
75+
"""Test round-trip: parse GBK file, write with GBK, parse again."""
76+
# Parse original GBK file
77+
library = parse_file("tests/resources/gbk_test.bib", encoding="gbk")
78+
original_author = library.entries[0]["author"]
79+
original_journal = library.entries[0]["journal"]
80+
81+
with tempfile.NamedTemporaryFile(mode="w", suffix=".bib", delete=False) as f:
82+
temp_path = f.name
83+
84+
try:
85+
# Write with GBK encoding
86+
write_file(temp_path, library, encoding="gbk")
87+
# Parse back
88+
library2 = parse_file(temp_path, encoding="gbk")
89+
assert library2.entries[0]["author"] == original_author
90+
assert library2.entries[0]["journal"] == original_journal
91+
finally:
92+
os.unlink(temp_path)

0 commit comments

Comments
 (0)