Skip to content

Commit 893b6ec

Browse files
authored
Merge pull request #107 from danmcp/automaxworkers
Allow max_workers to be passed in after evaluator is created
2 parents 39e0a7e + 52328fd commit 893b6ec

1 file changed

Lines changed: 80 additions & 30 deletions

File tree

src/instructlab/eval/mt_bench.py

Lines changed: 80 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -52,31 +52,49 @@ def __init__(
5252
self.output_dir = output_dir
5353
self.serving_gpus = serving_gpus
5454
self.merge_system_user_message = merge_system_user_message
55+
self.max_workers = self._calc_max_workers(max_workers, serving_gpus)
5556

56-
if max_workers == "auto":
57-
try:
58-
# Not available on all platforms
59-
usable_cpu_count = len(os.sched_getaffinity(0)) # type: ignore[attr-defined]
60-
except AttributeError:
61-
usable_cpu_count = multiprocessing.cpu_count()
62-
if serving_gpus is not None:
63-
# Tune max_workers based on hardware configuration: min(#GPUs being used * 10, #CPU cores)
64-
# Please see https://github.com/instructlab/instructlab/issues/2050 for detailed explanation
65-
self.max_workers = min(max(serving_gpus, 1) * 10, usable_cpu_count)
66-
logger.debug("Auto tuning max_workers to %s", self.max_workers)
57+
def _calc_max_workers(
58+
self, max_workers: int | str | None, serving_gpus: int | None
59+
):
60+
calculated_max_workers = None
61+
if max_workers is not None:
62+
if max_workers == "auto":
63+
try:
64+
# Not available on all platforms
65+
usable_cpu_count = len(os.sched_getaffinity(0)) # type: ignore[attr-defined]
66+
except AttributeError:
67+
usable_cpu_count = multiprocessing.cpu_count()
68+
if serving_gpus is not None:
69+
# Tune max_workers based on hardware configuration: min(#GPUs being used * 10, #CPU cores)
70+
# Please see https://github.com/instructlab/instructlab/issues/2050 for detailed explanation
71+
calculated_max_workers = min(
72+
max(serving_gpus, 1) * 10, usable_cpu_count
73+
)
74+
logger.debug(
75+
"Auto tuning max_workers to %s", calculated_max_workers
76+
)
77+
else:
78+
# Don't be too aggressive when serving_gpus isn't specified. Use half the cpu count.
79+
calculated_max_workers = usable_cpu_count // 2
80+
logger.debug(
81+
"max_workers set to auto but serving_gpus is not specified. Defaulting to (cpu count / 2): %s",
82+
calculated_max_workers,
83+
)
6784
else:
68-
# Don't be too aggressive when serving_gpus isn't specified. Use half the cpu count.
69-
self.max_workers = usable_cpu_count // 2
70-
logger.debug(
71-
"max_workers set to auto but serving_gpus is not specified. Defaulting to (cpu count / 2): %s",
72-
self.max_workers,
73-
)
85+
if isinstance(max_workers, int) and max_workers > 0:
86+
logger.debug("max_workers specified as: %s", max_workers)
87+
calculated_max_workers = max_workers
88+
else:
89+
raise InvalidMaxWorkersError(max_workers)
90+
return calculated_max_workers
91+
92+
def _get_effective_max_workers(self, max_workers, serving_gpus):
93+
if max_workers is not None:
94+
effective_max_workers = self._calc_max_workers(max_workers, serving_gpus)
7495
else:
75-
if isinstance(max_workers, int) and max_workers > 0:
76-
logger.debug("max_workers specified as: %s", max_workers)
77-
self.max_workers = max_workers
78-
else:
79-
raise InvalidMaxWorkersError(max_workers)
96+
effective_max_workers = self.max_workers
97+
return effective_max_workers
8098

8199

82100
class MTBenchEvaluator(AbstractMTBenchEvaluator):
@@ -94,30 +112,46 @@ class MTBenchEvaluator(AbstractMTBenchEvaluator):
94112

95113
name = "mt_bench"
96114

97-
def gen_answers(self, server_url, api_key: str | None = None) -> None:
115+
def gen_answers(
116+
self,
117+
server_url,
118+
api_key: str | None = None,
119+
max_workers: int | str | None = None,
120+
serving_gpus: int | None = None,
121+
) -> None:
98122
"""
99123
Asks questions to model
100124
101125
Attributes
102126
server_url Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
103127
api_key API token for authenticating with model server
128+
max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
129+
serving_gpus Number of gpus allocated for serving. Only used when max_workers="auto" is passed to this call; None then falls back to half the cpu count (the serving_gpus value given to the constructor is not consulted).
104130
"""
105131
logger.debug(locals())
106132
mt_bench_answers.generate_answers(
107133
self.model_name,
108134
server_url,
109135
api_key=api_key,
110136
output_dir=self.output_dir,
111-
max_workers=self.max_workers,
137+
max_workers=self._get_effective_max_workers(max_workers, serving_gpus),
112138
)
113139

114-
def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
140+
def judge_answers(
141+
self,
142+
server_url,
143+
api_key: str | None = None,
144+
max_workers: int | str | None = None,
145+
serving_gpus: int | None = None,
146+
) -> tuple:
115147
"""
116148
Runs MT-Bench judgment
117149
118150
Attributes
119151
server_url Model server endpoint (Ex: http://localhost:8000/v1) for the judge model
120152
api_key API token for authenticating with model server
153+
max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
154+
serving_gpus Number of gpus allocated for serving. Only used when max_workers="auto" is passed to this call; None then falls back to half the cpu count (the serving_gpus value given to the constructor is not consulted).
121155
122156
Returns:
123157
overall_score MT-Bench score for the overall model evaluation
@@ -130,7 +164,7 @@ def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
130164
self.judge_model_name,
131165
server_url,
132166
api_key=api_key,
133-
max_workers=self.max_workers,
167+
max_workers=self._get_effective_max_workers(max_workers, serving_gpus),
134168
output_dir=self.output_dir,
135169
merge_system_user_message=self.merge_system_user_message,
136170
)
@@ -175,13 +209,21 @@ def __init__(
175209
self.taxonomy_git_repo_path = taxonomy_git_repo_path
176210
self.branch = branch
177211

178-
def gen_answers(self, server_url, api_key: str | None = None) -> None:
212+
def gen_answers(
213+
self,
214+
server_url,
215+
api_key: str | None = None,
216+
max_workers: int | str | None = None,
217+
serving_gpus: int | None = None,
218+
) -> None:
179219
"""
180220
Asks questions to model
181221
182222
Attributes
183223
server_url Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
184224
api_key API token for authenticating with model server
225+
max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
226+
serving_gpus Number of gpus allocated for serving. Only used when max_workers="auto" is passed to this call; None then falls back to half the cpu count (the serving_gpus value given to the constructor is not consulted).
185227
"""
186228
logger.debug(locals())
187229
mt_bench_branch_generator.generate(
@@ -197,17 +239,25 @@ def gen_answers(self, server_url, api_key: str | None = None) -> None:
197239
branch=self.branch,
198240
output_dir=self.output_dir,
199241
data_dir=self.output_dir,
200-
max_workers=self.max_workers,
242+
max_workers=self._get_effective_max_workers(max_workers, serving_gpus),
201243
bench_name="mt_bench_branch",
202244
)
203245

204-
def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
246+
def judge_answers(
247+
self,
248+
server_url,
249+
api_key: str | None = None,
250+
max_workers: int | str | None = None,
251+
serving_gpus: int | None = None,
252+
) -> tuple:
205253
"""
206254
Runs MT-Bench-Branch judgment. Judgments can be compared across runs with consistent question_id -> qna file name.
207255
208256
Attributes
209257
server_url Model server endpoint (Ex: http://localhost:8000/v1) for the judge model
210258
api_key API token for authenticating with model server
259+
max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
260+
serving_gpus Number of gpus allocated for serving. Only used when max_workers="auto" is passed to this call; None then falls back to half the cpu count (the serving_gpus value given to the constructor is not consulted).
211261
212262
Returns:
213263
qa_pairs Question and answer pairs (with scores) from the evaluation
@@ -219,7 +269,7 @@ def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
219269
server_url,
220270
api_key=api_key,
221271
branch=self.branch,
222-
max_workers=self.max_workers,
272+
max_workers=self._get_effective_max_workers(max_workers, serving_gpus),
223273
output_dir=self.output_dir,
224274
data_dir=self.output_dir,
225275
bench_name="mt_bench_branch",

0 commit comments

Comments
 (0)