1111from ..helpers .jupyter import display_table
1212import string
1313import random
14- from extract_tables import extract_tables
14+ from axcell . data . extract_tables import extract_tables
1515
1616
1717class Paper :
@@ -75,23 +75,32 @@ def _load_tables(path, annotations, jobs, migrate):
7575 return {f .parent .name : tbls for f , tbls in zip (files , tables )}
7676
7777
def _gql_dump_to_annotations(dump):
    """Index the papers of a GraphQL dump by arXiv id.

    Every paper in ``dump`` is registered under two keys: its ``arxiv_id``
    passed through ``remove_arxiv_version`` and its ``arxiv_id`` as-is.

    ``dump`` is traversed twice. The exact-id pass runs second, so when an
    exact id collides with a key produced by ``remove_arxiv_version``, the
    paper carrying that exact id is the one kept.

    Returns:
        dict mapping arXiv id strings to the paper objects from ``dump``.
    """
    annotations = {}
    # First pass: keys produced by remove_arxiv_version.
    for paper in dump:
        annotations[remove_arxiv_version(paper.arxiv_id)] = paper
    # Second pass: exact ids, overriding any colliding key from the first pass.
    for paper in dump:
        annotations[paper.arxiv_id] = paper
    return annotations
82+
def _load_annotated_papers(data_or_path):
    """Load paper annotations from an in-memory dump or from a dump file.

    Args:
        data_or_path: either already-loaded dump data (a ``dict`` or a
            ``list``), or a filesystem path (``str`` or path-like) to a dump
            file. A path whose suffix is ``.gz`` is loaded with
            ``compressed=True``.

    Returns:
        The mapping built by ``_gql_dump_to_annotations`` from the
        ``"allPapers"`` entry of the loaded dump.
    """
    if isinstance(data_or_path, (dict, list)):
        # In-memory data: there is no file, hence nothing to decompress.
        compressed = False
    else:
        # Normalise through Path so that plain string paths work too;
        # reading ``.suffix`` straight off a ``str`` raises AttributeError.
        compressed = Path(data_or_path).suffix == ".gz"
    dump = load_gql_dump(data_or_path, compressed=compressed)["allPapers"]
    return _gql_dump_to_annotations(dump)
8790
8891
8992class PaperCollection (UserList ):
9093 def __init__ (self , data = None ):
9194 super ().__init__ (data )
9295
@classmethod
def from_files(cls, path, annotations=None, load_texts=True, load_tables=True, jobs=-1):
    """Build a collection from ``path`` by delegating to ``_from_files``.

    This wrapper always calls ``_from_files`` with ``annotations_path=None``
    and ``load_annotations=False``; every other argument is forwarded
    unchanged.

    Args:
        path: forwarded to ``_from_files`` as its positional ``path``.
        annotations: forwarded as ``annotations`` (``None`` by default).
        load_texts: forwarded as ``load_texts``.
        load_tables: forwarded as ``load_tables``.
        jobs: forwarded as ``jobs``.

    Returns:
        Whatever ``cls._from_files`` returns.
    """
    forwarded = dict(
        annotations=annotations,
        annotations_path=None,
        load_texts=load_texts,
        load_tables=load_tables,
        load_annotations=False,
        jobs=jobs,
    )
    return cls._from_files(path, **forwarded)
101+
102+ @classmethod
103+ def _from_files (cls , path , annotations = None , annotations_path = None , load_texts = True , load_tables = True , load_annotations = True , jobs = - 1 , migrate = False ):
95104 path = Path (path )
96105 if annotations_path is None :
97106 annotations_path = path / "structure-annotations.json"
@@ -102,7 +111,10 @@ def from_files(cls, path, annotations_path=None, load_texts=True, load_tables=Tr
102111 else :
103112 texts = {}
104113
105- annotations = {}
114+ if annotations is None :
115+ annotations = {}
116+ else :
117+ annotations = _load_annotated_papers (annotations )
106118 if load_tables :
107119 if load_annotations :
108120 annotations = _load_annotated_papers (annotations_path )
@@ -131,8 +143,9 @@ def get_by_id(self, paper_id, ignore_version=True):
131143 def cells_gold_tags_legend (cls ):
132144 tags = [
133145 ("Tag" , "description" ),
134- ("model-best" , "model that has results that author most likely would like to have exposed" ),
135- ("model-paper" , "an example of a generic model, (like LSTM)" ),
146+ ("model-best" , "the best performing model introduced in the paper" ),
147+ ("model-paper" , "model introduced in the paper" ),
148+ ("model-ensemble" , "ensemble of models introduced in the paper" ),
136149 ("model-competing" , "model from another paper used for comparison" ),
137150 ("dataset-task" , "Task" ),
138151 ("dataset" , "Dataset" ),
0 commit comments