@@ -52,31 +52,49 @@ def __init__(
5252 self .output_dir = output_dir
5353 self .serving_gpus = serving_gpus
5454 self .merge_system_user_message = merge_system_user_message
55+ self .max_workers = self ._calc_max_workers (max_workers , serving_gpus )
5556
56- if max_workers == "auto" :
57- try :
58- # Not available on all platforms
59- usable_cpu_count = len (os .sched_getaffinity (0 )) # type: ignore[attr-defined]
60- except AttributeError :
61- usable_cpu_count = multiprocessing .cpu_count ()
62- if serving_gpus is not None :
63- # Tune max_workers based on hardware configuration: min(#GPUs being used * 10, #CPU cores)
64- # Please see https://github.com/instructlab/instructlab/issues/2050 for detailed explanation
65- self .max_workers = min (max (serving_gpus , 1 ) * 10 , usable_cpu_count )
66- logger .debug ("Auto tuning max_workers to %s" , self .max_workers )
57+ def _calc_max_workers (
58+ self , max_workers : int | str | None , serving_gpus : int | None
59+ ):
60+ calculated_max_workers = None
61+ if max_workers is not None :
62+ if max_workers == "auto" :
63+ try :
64+ # Not available on all platforms
65+ usable_cpu_count = len (os .sched_getaffinity (0 )) # type: ignore[attr-defined]
66+ except AttributeError :
67+ usable_cpu_count = multiprocessing .cpu_count ()
68+ if serving_gpus is not None :
69+ # Tune max_workers based on hardware configuration: min(#GPUs being used * 10, #CPU cores)
70+ # Please see https://github.com/instructlab/instructlab/issues/2050 for detailed explanation
71+ calculated_max_workers = min (
72+ max (serving_gpus , 1 ) * 10 , usable_cpu_count
73+ )
74+ logger .debug (
75+ "Auto tuning max_workers to %s" , calculated_max_workers
76+ )
77+ else :
78+ # Don't be too aggressive when serving_gpus isn't specified. Use half the cpu count.
79+ calculated_max_workers = usable_cpu_count // 2
80+ logger .debug (
81+ "max_workers set to auto but serving_gpus is not specified. Defaulting to (cpu count / 2): %s" ,
82+ calculated_max_workers ,
83+ )
6784 else :
68- # Don't be too aggressive when serving_gpus isn't specified. Use half the cpu count.
69- self .max_workers = usable_cpu_count // 2
70- logger .debug (
71- "max_workers set to auto but serving_gpus is not specified. Defaulting to (cpu count / 2): %s" ,
72- self .max_workers ,
73- )
85+ if isinstance (max_workers , int ) and max_workers > 0 :
86+ logger .debug ("max_workers specified as: %s" , max_workers )
87+ calculated_max_workers = max_workers
88+ else :
89+ raise InvalidMaxWorkersError (max_workers )
90+ return calculated_max_workers
91+
def _get_effective_max_workers(self, max_workers, serving_gpus):
    """
    Picks the worker count for a single call, preferring per-call overrides

    Attributes
        max_workers     Per-call override (int or "auto"). None indicates to use the value resolved in the constructor.
        serving_gpus    Per-call number of gpus allocated for serving. None indicates to use the value specified in constructor.

    Returns:
        effective_max_workers   Worker count to use for this call
    """
    if max_workers is None:
        # The constructor stores an already-resolved value in self.max_workers, so
        # there is nothing left to re-tune here even when serving_gpus is supplied.
        return self.max_workers

    # Honor the documented contract: an omitted serving_gpus falls back to the
    # constructor's value instead of silently dropping to the cpu/2 heuristic.
    effective_serving_gpus = (
        serving_gpus if serving_gpus is not None else self.serving_gpus
    )
    return self._calc_max_workers(max_workers, effective_serving_gpus)
8098
8199
82100class MTBenchEvaluator (AbstractMTBenchEvaluator ):
@@ -94,30 +112,46 @@ class MTBenchEvaluator(AbstractMTBenchEvaluator):
94112
95113 name = "mt_bench"
96114
97- def gen_answers (self , server_url , api_key : str | None = None ) -> None :
def gen_answers(
    self,
    server_url,
    api_key: str | None = None,
    max_workers: int | str | None = None,
    serving_gpus: int | None = None,
) -> None:
    """
    Generates answers from the model under evaluation via mt_bench_answers.generate_answers

    Attributes
        server_url      Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
        api_key         API token for authenticating with model server
        max_workers     Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
        serving_gpus    Number of gpus allocated for serving. Used to tune with max_workers=auto.
    """
    # Must stay the first statement so the logged dict holds only the call arguments.
    # NOTE(review): locals() includes api_key, so the token appears in debug logs.
    logger.debug(locals())
    worker_count = self._get_effective_max_workers(max_workers, serving_gpus)
    mt_bench_answers.generate_answers(
        self.model_name,
        server_url,
        api_key=api_key,
        output_dir=self.output_dir,
        max_workers=worker_count,
    )
113139
114- def judge_answers (self , server_url , api_key : str | None = None ) -> tuple :
140+ def judge_answers (
141+ self ,
142+ server_url ,
143+ api_key : str | None = None ,
144+ max_workers : int | str | None = None ,
145+ serving_gpus : int | None = None ,
146+ ) -> tuple :
115147 """
116148 Runs MT-Bench judgment
117149
118150 Attributes
119151 server_url Model server endpoint (Ex: http://localhost:8000/v1) for the judge model
120152 api_key API token for authenticating with model server
153+ max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
154+ serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor.
121155
122156 Returns:
123157 overall_score MT-Bench score for the overall model evaluation
@@ -130,7 +164,7 @@ def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
130164 self .judge_model_name ,
131165 server_url ,
132166 api_key = api_key ,
133- max_workers = self .max_workers ,
167+ max_workers = self ._get_effective_max_workers ( max_workers , serving_gpus ) ,
134168 output_dir = self .output_dir ,
135169 merge_system_user_message = self .merge_system_user_message ,
136170 )
@@ -175,13 +209,21 @@ def __init__(
175209 self .taxonomy_git_repo_path = taxonomy_git_repo_path
176210 self .branch = branch
177211
178- def gen_answers (self , server_url , api_key : str | None = None ) -> None :
212+ def gen_answers (
213+ self ,
214+ server_url ,
215+ api_key : str | None = None ,
216+ max_workers : int | str | None = None ,
217+ serving_gpus : int | None = None ,
218+ ) -> None :
179219 """
180220 Asks questions to model
181221
182222 Attributes
183223 server_url Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
184224 api_key API token for authenticating with model server
225+ max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
226+ serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor.
185227 """
186228 logger .debug (locals ())
187229 mt_bench_branch_generator .generate (
@@ -197,17 +239,25 @@ def gen_answers(self, server_url, api_key: str | None = None) -> None:
197239 branch = self .branch ,
198240 output_dir = self .output_dir ,
199241 data_dir = self .output_dir ,
200- max_workers = self .max_workers ,
242+ max_workers = self ._get_effective_max_workers ( max_workers , serving_gpus ) ,
201243 bench_name = "mt_bench_branch" ,
202244 )
203245
204- def judge_answers (self , server_url , api_key : str | None = None ) -> tuple :
246+ def judge_answers (
247+ self ,
248+ server_url ,
249+ api_key : str | None = None ,
250+ max_workers : int | str | None = None ,
251+ serving_gpus : int | None = None ,
252+ ) -> tuple :
205253 """
206254 Runs MT-Bench-Branch judgment. Judgments can be compared across runs with consistent question_id -> qna file name.
207255
208256 Attributes
209257 server_url Model server endpoint (Ex: http://localhost:8000/v1) for the judge model
210258 api_key API token for authenticating with model server
259+ max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
260+ serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor.
211261
212262 Returns:
213263 qa_pairs Question and answer pairs (with scores) from the evaluation
@@ -219,7 +269,7 @@ def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
219269 server_url ,
220270 api_key = api_key ,
221271 branch = self .branch ,
222- max_workers = self .max_workers ,
272+ max_workers = self ._get_effective_max_workers ( max_workers , serving_gpus ) ,
223273 output_dir = self .output_dir ,
224274 data_dir = self .output_dir ,
225275 bench_name = "mt_bench_branch" ,
0 commit comments