Skip to content

Commit ee2b894

Browse files
committed
Fix hash copy semantics and add tests
1 parent 17f3332 commit ee2b894

7 files changed

Lines changed: 242 additions & 44 deletions

File tree

scripts/build_ffi.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,7 @@ def build_ffi(local_wolfssl, features):
575575
int wc_ShaUpdate(wc_Sha*, const byte*, word32);
576576
int wc_ShaFinal(wc_Sha*, byte*);
577577
void wc_ShaFree(wc_Sha*);
578+
int wc_ShaCopy(wc_Sha*, wc_Sha*);
578579
"""
579580

580581
if features["SHA256"]:
@@ -584,6 +585,7 @@ def build_ffi(local_wolfssl, features):
584585
int wc_Sha256Update(wc_Sha256*, const byte*, word32);
585586
int wc_Sha256Final(wc_Sha256*, byte*);
586587
void wc_Sha256Free(wc_Sha256*);
588+
int wc_Sha256Copy(wc_Sha256*, wc_Sha256*);
587589
"""
588590

589591
if features["SHA384"]:
@@ -593,6 +595,7 @@ def build_ffi(local_wolfssl, features):
593595
int wc_Sha384Update(wc_Sha384*, const byte*, word32);
594596
int wc_Sha384Final(wc_Sha384*, byte*);
595597
void wc_Sha384Free(wc_Sha384*);
598+
int wc_Sha384Copy(wc_Sha384*, wc_Sha384*);
596599
"""
597600

598601
if features["SHA512"]:
@@ -603,6 +606,7 @@ def build_ffi(local_wolfssl, features):
603606
int wc_Sha512Update(wc_Sha512*, const byte*, word32);
604607
int wc_Sha512Final(wc_Sha512*, byte*);
605608
void wc_Sha512Free(wc_Sha512*);
609+
int wc_Sha512Copy(wc_Sha512*, wc_Sha512*);
606610
"""
607611
if features["SHA3"]:
608612
cdef += """
@@ -623,6 +627,10 @@ def build_ffi(local_wolfssl, features):
623627
void wc_Sha3_256_Free(wc_Sha3*);
624628
void wc_Sha3_384_Free(wc_Sha3*);
625629
void wc_Sha3_512_Free(wc_Sha3*);
630+
int wc_Sha3_224_Copy(wc_Sha3*, wc_Sha3*);
631+
int wc_Sha3_256_Copy(wc_Sha3*, wc_Sha3*);
632+
int wc_Sha3_384_Copy(wc_Sha3*, wc_Sha3*);
633+
int wc_Sha3_512_Copy(wc_Sha3*, wc_Sha3*);
626634
"""
627635

628636
if features["DES3"]:

tests/test_aesgcmstream.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,31 @@ def test_encrypt_aad_bad():
126126
def test_invalid_tag_bytes():
127127
key = "fedcba9876543210"
128128
iv = "0123456789abcdef"
129-
with pytest.raises(ValueError, match="tag_bytes must be between 4 and 16"):
129+
# Out of range
130+
with pytest.raises(ValueError, match="tag_bytes must be one of"):
130131
AesGcmStream(key, iv, tag_bytes=0)
131-
with pytest.raises(ValueError, match="tag_bytes must be between 4 and 16"):
132+
with pytest.raises(ValueError, match="tag_bytes must be one of"):
132133
AesGcmStream(key, iv, tag_bytes=3)
133-
with pytest.raises(ValueError, match="tag_bytes must be between 4 and 16"):
134+
with pytest.raises(ValueError, match="tag_bytes must be one of"):
134135
AesGcmStream(key, iv, tag_bytes=17)
135-
# valid edge cases
136-
AesGcmStream(key, iv, tag_bytes=4)
137-
AesGcmStream(key, iv, tag_bytes=16)
136+
# Non-NIST sizes within 4-16 range
137+
for bad in (5, 6, 7, 9, 10, 11):
138+
with pytest.raises(ValueError, match="tag_bytes must be one of"):
139+
AesGcmStream(key, iv, tag_bytes=bad)
140+
# Valid NIST sizes: verify the resulting tag has the requested length.
141+
for good in (4, 8, 12, 13, 14, 15, 16):
142+
gcm = AesGcmStream(key, iv, tag_bytes=good)
143+
gcm.encrypt("hello world")
144+
tag = gcm.final()
145+
assert len(tag) == good
146+
147+
def test_repeated_construction_destruction():
148+
import gc
149+
key = "fedcba9876543210"
150+
iv = "0123456789abcdef"
151+
for _ in range(1000):
152+
gcm = AesGcmStream(key, iv)
153+
gcm.encrypt("hello world")
154+
gcm.final()
155+
del gcm
156+
gc.collect()

tests/test_ciphers.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -879,7 +879,13 @@ def test_des3_rejects_mode_ctr():
879879
key = b"\x01\x23\x45\x67\x89\xab\xcd\xef" * 3
880880
iv = b"\xfe\xdc\xba\x98\x76\x54\x32\x10"
881881
with pytest.raises(ValueError, match="Des3 only supports MODE_CBC"):
882-
Des3(key, MODE_CTR, iv)
882+
Des3.new(key, MODE_CTR, iv)
883+
884+
def test_des3_rejects_mode_ecb():
885+
key = b"\x01\x23\x45\x67\x89\xab\xcd\xef" * 3
886+
iv = b"\xfe\xdc\xba\x98\x76\x54\x32\x10"
887+
with pytest.raises(ValueError, match="Des3 only supports MODE_CBC"):
888+
Des3.new(key, MODE_ECB, iv)
883889

884890

885891
if _lib.CHACHA_ENABLED:
@@ -898,3 +904,15 @@ def test_chacha_non_block_aligned():
898904
def test_chacha_invalid_key_length():
899905
with pytest.raises(ValueError, match="key must be"):
900906
ChaCha(b"\x00" * 20)
907+
908+
909+
if _lib.RSA_ENABLED:
910+
def test_encrypt_oaep_requires_hash_type(vectors):
911+
rsa = RsaPublic(vectors[RsaPublic].key)
912+
with pytest.raises(WolfCryptError, match="Hash type not set"):
913+
rsa.encrypt_oaep(b"plaintext")
914+
915+
def test_decrypt_oaep_requires_hash_type(vectors):
916+
rsa = RsaPrivate(vectors[RsaPrivate].key)
917+
with pytest.raises(WolfCryptError, match="Hash type not set"):
918+
rsa.decrypt_oaep(b"\x00" * rsa.output_size)

tests/test_hashes.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,27 @@ def test_hash(hash_cls, vectors):
184184
copy.update("wolfcrypt")
185185

186186
assert hash_obj.hexdigest() == copy.hexdigest() == digest
187+
188+
189+
def test_hash_repeated_construction_destruction(hash_cls, vectors):
190+
import gc
191+
digest = vectors[hash_cls].digest
192+
for _ in range(1000):
193+
h = hash_new(hash_cls, "wolfcrypt")
194+
assert h.hexdigest() == digest
195+
del h
196+
gc.collect()
197+
198+
199+
def test_hash_copy_destroy_lifecycle(hash_cls, vectors):
200+
import gc
201+
digest = vectors[hash_cls].digest
202+
for _ in range(100):
203+
h = hash_new(hash_cls, "wolfcrypt")
204+
c = h.copy()
205+
# Destroy original first, then verify copy still produces correct digest.
206+
del h
207+
gc.collect()
208+
assert c.hexdigest() == digest
209+
del c
210+
gc.collect()

wolfcrypt/asn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ def pem_to_der(pem, pem_type):
4444
err = "Error converting from PEM to DER. ({})".format(ret)
4545
raise WolfCryptError(err)
4646

47-
result = _ffi.buffer(der[0][0].buffer, der[0][0].length)[:]
48-
_lib.wc_FreeDer(der)
47+
try:
48+
result = _ffi.buffer(der[0][0].buffer, der[0][0].length)[:]
49+
finally:
50+
_lib.wc_FreeDer(der)
4951
return result
5052

5153
def der_to_pem(der, pem_type):

wolfcrypt/ciphers.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -396,33 +396,40 @@ class AesGcmStream(object):
396396
block_size = 16
397397
_key_sizes = [16, 24, 32]
398398
_native_type = "Aes *"
399-
_aad = bytes()
400-
_tag_bytes = 16
401-
_mode = None
399+
# making sure _lib.wc_AesFree outlives Aes instances
400+
_delete = _lib.wc_AesFree
402401

403402
def __init__(self, key, IV, tag_bytes=16):
404403
"""
405404
tag_bytes is the number of bytes to use for the authentication tag during encryption
406405
"""
407406
key = t2b(key)
408407
IV = t2b(IV)
409-
if tag_bytes < 4 or tag_bytes > 16:
410-
raise ValueError("tag_bytes must be between 4 and 16")
408+
# NIST SP 800-38D valid GCM tag lengths: 16, 15, 14, 13, 12, 8, 4 bytes.
409+
if tag_bytes not in (4, 8, 12, 13, 14, 15, 16):
410+
raise ValueError(
411+
"tag_bytes must be one of 4, 8, 12, 13, 14, 15, or 16")
412+
# Per-instance state: AAD, tag length, and current mode (enc/dec).
413+
self._aad = bytes()
411414
self._tag_bytes = tag_bytes
415+
self._mode = None
412416
if len(key) not in self._key_sizes:
413417
raise ValueError("key must be %s in length, not %d" %
414418
(self._key_sizes, len(key)))
419+
self._init_done = False
415420
self._native_object = _ffi.new(self._native_type)
416421
ret = _lib.wc_AesInit(self._native_object, _ffi.NULL, -2)
417422
if ret < 0:
418423
raise WolfCryptError("AES init error (%d)" % ret)
424+
self._init_done = True
419425
ret = _lib.wc_AesGcmInit(self._native_object, key, len(key), IV, len(IV))
420426
if ret < 0:
421427
raise WolfCryptError("Init error (%d)" % ret)
422428

423429
def __del__(self):
424-
if hasattr(self, '_native_object'):
425-
_lib.wc_AesFree(self._native_object)
430+
if getattr(self, '_init_done', False):
431+
self._delete(self._native_object)
432+
self._init_done = False
426433

427434
def set_aad(self, data):
428435
"""
@@ -446,11 +453,11 @@ def encrypt(self, data):
446453
aad = self._aad
447454
elif self._mode == _DECRYPTION:
448455
raise WolfCryptError("Class instance already in use for decryption")
449-
self._buf = _ffi.new("byte[%d]" % (len(data)))
450-
ret = _lib.wc_AesGcmEncryptUpdate(self._native_object, self._buf, data, len(data), aad, len(aad))
456+
buf = _ffi.new("byte[%d]" % (len(data)))
457+
ret = _lib.wc_AesGcmEncryptUpdate(self._native_object, buf, data, len(data), aad, len(aad))
451458
if ret < 0:
452459
raise WolfCryptError("Encryption error (%d)" % ret)
453-
return bytes(self._buf)
460+
return bytes(buf)
454461

455462
def decrypt(self, data):
456463
"""
@@ -463,11 +470,11 @@ def decrypt(self, data):
463470
aad = self._aad
464471
elif self._mode == _ENCRYPTION:
465472
raise WolfCryptError("Class instance already in use for encryption")
466-
self._buf = _ffi.new("byte[%d]" % (len(data)))
467-
ret = _lib.wc_AesGcmDecryptUpdate(self._native_object, self._buf, data, len(data), aad, len(aad))
473+
buf = _ffi.new("byte[%d]" % (len(data)))
474+
ret = _lib.wc_AesGcmDecryptUpdate(self._native_object, buf, data, len(data), aad, len(aad))
468475
if ret < 0:
469476
raise WolfCryptError("Decryption error (%d)" % ret)
470-
return bytes(self._buf)
477+
return bytes(buf)
471478

472479
def final(self, authTag=None):
473480
"""
@@ -505,7 +512,9 @@ class ChaCha(_Cipher):
505512
_IV_nonce = b""
506513
_IV_counter = 0
507514

508-
def __init__(self, key="", size=32):
515+
def __init__(self, key="", size=32): # pylint: disable=unused-argument
516+
# size is kept for backwards compatibility; key length is now
517+
# derived from the actual key and validated against _key_sizes.
509518
self._native_object = _ffi.new(self._native_type)
510519
self._enc = None
511520
self._dec = None
@@ -552,7 +561,9 @@ def set_iv(self, nonce, counter = 0):
552561
raise ValueError("nonce must be %d bytes, got %d" %
553562
(self._NONCE_SIZE, len(self._IV_nonce)))
554563
self._IV_counter = counter
555-
self._set_key(0)
564+
ret = self._set_key(0)
565+
if ret < 0:
566+
raise WolfCryptError("ChaCha set_iv error (%d)" % ret)
556567

557568
if _lib.CHACHA20_POLY1305_ENABLED:
558569
class ChaCha20Poly1305(object):
@@ -643,6 +654,9 @@ class Des3(_Cipher):
643654
_native_type = "Des3 *"
644655

645656
def __init__(self, key, mode, IV=None):
657+
# Intentionally stricter than _Cipher.__init__, which accepts both
658+
# CBC and CTR. wolfCrypt has no 3DES-CTR implementation, so reject
659+
# MODE_CTR here with a clearer error before delegating.
646660
if mode != MODE_CBC:
647661
raise ValueError("Des3 only supports MODE_CBC")
648662
super().__init__(key, mode, IV)
@@ -864,6 +878,9 @@ def make_key(cls, size, rng=None, hash_type=None):
864878
if rsa.output_size <= 0: # pragma: no cover
865879
raise WolfCryptError("Invalid key size error (%d)" % ret)
866880

881+
# Retain RNG reference defensively.
882+
rsa._rng = rng
883+
867884
return rsa
868885

869886
def __init__(self, key=None, hash_type=None): # pylint: disable=super-init-not-called
@@ -1231,7 +1248,11 @@ def make_key(cls, size, rng=None):
12311248
ret = _lib.wc_ecc_set_rng(ecc.native_object, rng.native_object)
12321249
if ret < 0:
12331250
raise WolfCryptError("Error setting ECC RNG (%d)" % ret)
1234-
ecc._rng = rng
1251+
1252+
# Retain the RNG so it outlives the ECC key. Even outside the
1253+
# timing-resistance path, wolfSSL internals may retain a pointer
1254+
# to the RNG; keeping the reference avoids any UAF risk.
1255+
ecc._rng = rng
12351256

12361257
return ecc
12371258

@@ -1504,6 +1525,10 @@ def make_key(cls, size, rng=None):
15041525
if ret < 0:
15051526
raise WolfCryptError("Key generation error (%d)" % ret)
15061527

1528+
# Retain RNG reference defensively; wolfSSL may retain a pointer
1529+
# internally on some builds.
1530+
ed25519._rng = rng
1531+
15071532
return ed25519
15081533

15091534
def decode_key(self, key, pub = None):
@@ -1706,6 +1731,10 @@ def make_key(cls, size, rng=None):
17061731
if ret < 0:
17071732
raise WolfCryptError("Key generation error (%d)" % ret)
17081733

1734+
# Retain RNG reference defensively; wolfSSL may retain a pointer
1735+
# internally on some builds.
1736+
ed448._rng = rng
1737+
17091738
return ed448
17101739

17111740
def decode_key(self, key, pub = None):
@@ -1979,6 +2008,9 @@ def make_key(cls, mlkem_type, rng=None):
19792008
if ret < 0: # pragma: no cover
19802009
raise WolfCryptError("wc_KyberKey_MakeKey() error (%d)" % ret)
19812010

2011+
# Retain RNG reference defensively.
2012+
mlkem_priv._rng = rng
2013+
19822014
return mlkem_priv
19832015

19842016
@classmethod
@@ -2226,6 +2258,9 @@ def make_key(cls, mldsa_type, rng=None):
22262258
if ret < 0: # pragma: no cover
22272259
raise WolfCryptError("wc_dilithium_make_key() error (%d)" % ret)
22282260

2261+
# Retain RNG reference defensively.
2262+
mldsa_priv._rng = rng
2263+
22292264
return mldsa_priv
22302265

22312266
@property

0 commit comments

Comments
 (0)