Skip to content

Commit d9d7cb2

Browse files
committed
Make sure ACES works with the latest dependencies
1 parent 57f157c commit d9d7cb2

5 files changed

Lines changed: 19 additions & 9 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ description = "ACES metric for evaluating automated audio captioning models base
88
readme = "README.md"
99
authors = [{name = "Gijs Wijngaard", email = "hi@gijs.me"}]
1010
license = {file = "LICENSE"}
11-
requires-python = ">=3.9,<3.11"
11+
requires-python = ">=3.9"
1212
dynamic = ["version", "dependencies"]
1313

1414
[tool.setuptools.dynamic]

requirements.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
transformers<4.31.0
2-
numpy~=1.24.1
3-
torch==1.13.1
4-
sentence_transformers~=2.2.2
5-
tqdm~=4.49.0
1+
transformers
2+
numpy
3+
torch
4+
sentence_transformers
5+
tqdm

src/aces/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .main import ACES, get_aces_score
22

3-
__version__ = "0.0.4"
3+
__version__ = "0.0.5"

src/aces/fense/evaluator.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,17 @@ def load_pretrain_echecker(echecker_model, device='cuda', use_proxy=False, proxi
2828
file_path = check_download_resource(remote, use_proxy, proxies)
2929
model_states = torch.load(file_path)
3030
clf = BERTFlatClassifier(model_type=model_states['model_type'], num_classes=model_states['num_classes'])
31-
clf.load_state_dict(model_states['state_dict'])
31+
32+
# Handle potential incompatible keys due to transformers library version differences
33+
state_dict = model_states['state_dict']
34+
# Remove problematic keys if they exist
35+
keys_to_remove = ['encoder.embeddings.position_ids']
36+
for key in keys_to_remove:
37+
if key in state_dict:
38+
print(f"Removing incompatible key from state_dict: {key}")
39+
del state_dict[key]
40+
41+
clf.load_state_dict(state_dict, strict=False)
3242
clf.eval()
3343
clf.to(device)
3444
return clf

src/aces/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def postprocess(self, model_outputs: Dict,
109109
# Process the model outputs
110110
if isinstance(model_outputs, list):
111111
model_outputs = model_outputs[0]
112-
logits = model_outputs["logits"][0].numpy()
112+
logits = model_outputs["logits"][0].float().numpy()
113113

114114
# Calculate scores from logits
115115
maxes = np.max(logits, axis=-1, keepdims=True)

0 commit comments

Comments
 (0)