"""
Perplexity Metric:
-------------------------------------------------------
Class for calculating perplexity from AttackResults
"""
import torch
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics import Metric
import textattack.shared.utils
class Perplexity(Metric):
    """Calculates average perplexity of the original and perturbed texts in a
    list of ``AttackResult`` objects, using a pre-trained language model
    (GPT-2 by default)."""
def __init__(self, model_name="gpt2"):
self.all_metrics = {}
self.original_candidates = []
self.successful_candidates = []
if model_name == "gpt2":
from transformers import GPT2LMHeadModel, GPT2Tokenizer
self.ppl_model = GPT2LMHeadModel.from_pretrained("gpt2")
self.ppl_model.to(textattack.shared.utils.device)
self.ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
self.ppl_model.eval()
self.max_length = self.ppl_model.config.n_positions
else:
from transformers import AutoModelForMaskedLM, AutoTokenizer
self.ppl_model = AutoModelForMaskedLM.from_pretrained(model_name)
self.ppl_tokenizer = AutoTokenizer.from_pretrained(model_name)
self.ppl_model.to(textattack.shared.utils.device)
self.ppl_model.eval()
self.max_length = self.ppl_model.config.max_position_embeddings
self.stride = 512
def calculate(self, results):
"""Calculates average Perplexity on all successfull attacks using a
pre-trained small GPT-2 model.
Args:
results (``AttackResult`` objects):
Attack results for each instance in dataset
Example::
>> import textattack
>> import transformers
>> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
>> attack_args = textattack.AttackArgs(
num_examples=1,
log_to_csv="log.csv",
checkpoint_interval=5,
checkpoint_dir="checkpoints",
disable_stdout=True
)
>> attacker = textattack.Attacker(attack, dataset, attack_args)
>> results = attacker.attack_dataset()
>> ppl = textattack.metrics.quality_metrics.Perplexity().calculate(results)
"""
self.results = results
self.original_candidates_ppl = []
self.successful_candidates_ppl = []
for i, result in enumerate(self.results):
if isinstance(result, FailedAttackResult):
continue
elif isinstance(result, SkippedAttackResult):
continue
else:
self.original_candidates.append(
result.original_result.attacked_text.text.lower()
)
self.successful_candidates.append(
result.perturbed_result.attacked_text.text.lower()
)
ppl_orig = self.calc_ppl(self.original_candidates)
ppl_attack = self.calc_ppl(self.successful_candidates)
self.all_metrics["avg_original_perplexity"] = round(ppl_orig, 2)
self.all_metrics["avg_attack_perplexity"] = round(ppl_attack, 2)
return self.all_metrics
    def calc_ppl(self, texts):
        """Computes perplexity of ``texts`` by joining them into a single string
        and scoring it with the language model over a sliding token window."""
with torch.no_grad():
text = " ".join(texts)
eval_loss = []
input_ids = torch.tensor(
self.ppl_tokenizer.encode(text, add_special_tokens=True)
).unsqueeze(0)
if not (input_ids_size := input_ids.size(1)):
raise ValueError("No tokens recognized for input text")
# Strided perplexity calculation from huggingface.co/transformers/perplexity.html
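            # Each window spans tokens [begin_loc, end_loc); only its last trg_len
            # tokens are scored (the overlapping context gets label -100), so the
            # model's mean loss is re-weighted by trg_len before the final
            # exp(total_loss / total_tokens) below.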
for i in range(0, input_ids_size, self.stride):
begin_loc = max(i + self.stride - self.max_length, 0)
end_loc = min(i + self.stride, input_ids.size(1))
trg_len = end_loc - i
input_ids_t = input_ids[:, begin_loc:end_loc].to(
textattack.shared.utils.device
)
target_ids = input_ids_t.clone()
target_ids[:, :-trg_len] = -100
outputs = self.ppl_model(input_ids_t, labels=target_ids)
log_likelihood = outputs[0] * trg_len
eval_loss.append(log_likelihood)
return torch.exp(torch.stack(eval_loss).sum() / end_loc).item()
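
# A minimal usage sketch (hypothetical variable names; assumes `results` is a
# list of AttackResult objects produced by a textattack.Attacker run, as in the
# docstring example above):
#
#     ppl_metric = Perplexity()  # or e.g. Perplexity(model_name="distilroberta-base")
#     metrics = ppl_metric.calculate(results)
#     # -> {"avg_original_perplexity": ..., "avg_attack_perplexity": ...}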