import os
import json
import argparse

import numpy as np

from metrics import (
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
)

# Maps each LongBench dataset name to the metric function used to score a
# prediction against one ground-truth answer. Grouped by metric so the
# pairing is easy to audit; the comprehension below flattens it into the
# same name -> function mapping the rest of the file looks up.
_metric_groups = (
    (qa_f1_score, ("narrativeqa", "qasper", "multifieldqa_en", "hotpotqa",
                   "2wikimqa", "musique", "triviaqa")),
    (qa_f1_zh_score, ("multifieldqa_zh",)),
    (rouge_zh_score, ("dureader", "vcsum")),
    (rouge_score, ("gov_report", "qmsum", "multi_news", "samsum")),
    (classification_score, ("trec", "lsht")),
    (retrieval_score, ("passage_retrieval_en",)),
    (retrieval_zh_score, ("passage_retrieval_zh",)),
    (count_score, ("passage_count",)),
    (code_sim_score, ("lcc", "repobench-p")),
)
dataset2metric = {
    name: metric for metric, names in _metric_groups for name in names
}
def parse_args(args=None):
    """Parse command-line options for the evaluation script.

    Args:
        args: Optional list of argument strings; when None, argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model`` (str or None) and ``e`` (bool,
        True when scoring the LongBench-E split).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--model', type=str, default=None)
    ap.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
    return ap.parse_args(args)
def scorer_e(dataset, predictions, answers, lengths, all_classes):
    """Score predictions bucketed by context length (LongBench-E protocol).

    Args:
        dataset: Dataset name; selects the metric via ``dataset2metric``.
        predictions: Model output strings, one per example.
        answers: Ground-truth answers, one list per example; each example
            is scored as the max metric over its references.
        lengths: Context length (int) per example, used only for bucketing.
        all_classes: Label set forwarded to classification-style metrics.

    Returns:
        Dict mapping length bucket ("0-4k", "4-8k", "8k+") to the mean
        metric in that bucket as a percentage rounded to 2 decimals.
        An empty bucket yields 0.0 (previously ``np.mean([])`` produced
        NaN and a RuntimeWarning).
    """
    scores = {"0-4k": [], "4-8k": [], "8k+": []}
    for prediction, ground_truths, length in zip(predictions, answers, lengths):
        score = 0.
        # These datasets put the answer on the first output line; drop the rest.
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth,
                                                       all_classes=all_classes))
        if length < 4000:
            scores["0-4k"].append(score)
        elif length < 8000:
            scores["4-8k"].append(score)
        else:
            scores["8k+"].append(score)
    for key in scores:
        # Guard empty buckets: np.mean of an empty list is NaN and warns.
        scores[key] = round(100 * np.mean(scores[key]), 2) if scores[key] else 0.0
    return scores
def scorer(dataset, predictions, answers, all_classes):
    """Compute the mean metric over all examples as a percentage.

    Args:
        dataset: Dataset name; selects the metric via ``dataset2metric``.
        predictions: Model output strings, one per example.
        answers: Ground-truth answers, one list per example; each example
            is scored as the max metric over its references.
        all_classes: Label set forwarded to classification-style metrics.

    Returns:
        Mean metric across examples as a percentage rounded to 2 decimals;
        0.0 when ``predictions`` is empty (previously raised
        ZeroDivisionError).
    """
    if not predictions:
        return 0.0
    total_score = 0.
    for prediction, ground_truths in zip(predictions, answers):
        score = 0.
        # These datasets put the answer on the first output line; drop the rest.
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth,
                                                       all_classes=all_classes))
        total_score += score
    return round(100 * total_score / len(predictions), 2)
if __name__ == '__main__':
    args = parse_args()
    scores = dict()
    # LongBench-E predictions live under pred_e/, standard LongBench under
    # pred/. (The original assigned pred_e/ in both branches, so --e never
    # changed the input directory — presumably a copy-paste slip.)
    if args.e:
        path = f"pred_e/{args.model}/"
    else:
        path = f"pred/{args.model}/"
    all_files = os.listdir(path)
    print("Evaluating on:", all_files)
    for filename in all_files:
        if not filename.endswith("jsonl"):
            continue
        predictions, answers, lengths = [], [], []
        # Dataset name is the filename stem, e.g. "qasper.jsonl" -> "qasper".
        dataset = filename.split('.')[0]
        all_classes = None  # stays None if the prediction file is empty
        with open(f"{path}{filename}", "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                predictions.append(data["pred"])
                answers.append(data["answers"])
                # Same label set on every line; keep the last one seen.
                all_classes = data["all_classes"]
                if "length" in data:
                    lengths.append(data["length"])
        if args.e:
            score = scorer_e(dataset, predictions, answers, lengths, all_classes)
        else:
            score = scorer(dataset, predictions, answers, all_classes)
        scores[dataset] = score
    # Both --e branches wrote the same file, so a single assignment keeps
    # that behavior.
    out_path = f"H2O/results/{args.model}/result.json"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(scores, f, ensure_ascii=False, indent=4)