Skip to content

Commit 52328fd

Browse files
author
Dan McPherson
committed
Allow max_workers to be passed in after evaluator is created
The reason this is necessary is that serving_gpus might not be available when the evaluator is constructed. A typical flow might be:
- Create evaluator
- Launch server  # This is when serving_gpus is probably calculated
- generate
- judge

The new logic allows max_workers and serving_gpus to be passed in as the generate and judge steps occur. Once all the known callers have been updated (eval and training) we can remove the two attrs from the constructors and make the logic a little simpler.

Signed-off-by: Dan McPherson <dmcphers@redhat.com>
1 parent 83f9d95 commit 52328fd

1 file changed

Lines changed: 80 additions & 30 deletions

File tree

src/instructlab/eval/mt_bench.py

Lines changed: 80 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -52,31 +52,49 @@ def __init__(
5252
self.output_dir = output_dir
5353
self.serving_gpus = serving_gpus
5454
self.merge_system_user_message = merge_system_user_message
55+
self.max_workers = self._calc_max_workers(max_workers, serving_gpus)
5556

56-
if max_workers == "auto":
57-
try:
58-
# Not available on all platforms
59-
usable_cpu_count = len(os.sched_getaffinity(0)) # type: ignore[attr-defined]
60-
except AttributeError:
61-
usable_cpu_count = multiprocessing.cpu_count()
62-
if serving_gpus is not None:
63-
# Tune max_workers based on hardware configuration: min(#GPUs being used * 10, #CPU cores)
64-
# Please see https://github.com/instructlab/instructlab/issues/2050 for detailed explanation
65-
self.max_workers = min(max(serving_gpus, 1) * 10, usable_cpu_count)
66-
logger.debug("Auto tuning max_workers to %s", self.max_workers)
57+
def _calc_max_workers(
58+
self, max_workers: int | str | None, serving_gpus: int | None
59+
):
60+
calculated_max_workers = None
61+
if max_workers is not None:
62+
if max_workers == "auto":
63+
try:
64+
# Not available on all platforms
65+
usable_cpu_count = len(os.sched_getaffinity(0)) # type: ignore[attr-defined]
66+
except AttributeError:
67+
usable_cpu_count = multiprocessing.cpu_count()
68+
if serving_gpus is not None:
69+
# Tune max_workers based on hardware configuration: min(#GPUs being used * 10, #CPU cores)
70+
# Please see https://github.com/instructlab/instructlab/issues/2050 for detailed explanation
71+
calculated_max_workers = min(
72+
max(serving_gpus, 1) * 10, usable_cpu_count
73+
)
74+
logger.debug(
75+
"Auto tuning max_workers to %s", calculated_max_workers
76+
)
77+
else:
78+
# Don't be too aggressive when serving_gpus isn't specified. Use half the cpu count.
79+
calculated_max_workers = usable_cpu_count // 2
80+
logger.debug(
81+
"max_workers set to auto but serving_gpus is not specified. Defaulting to (cpu count / 2): %s",
82+
calculated_max_workers,
83+
)
6784
else:
68-
# Don't be too aggressive when serving_gpus isn't specified. Use half the cpu count.
69-
self.max_workers = usable_cpu_count // 2
70-
logger.debug(
71-
"max_workers set to auto but serving_gpus is not specified. Defaulting to (cpu count / 2): %s",
72-
self.max_workers,
73-
)
85+
if isinstance(max_workers, int) and max_workers > 0:
86+
logger.debug("max_workers specified as: %s", max_workers)
87+
calculated_max_workers = max_workers
88+
else:
89+
raise InvalidMaxWorkersError(max_workers)
90+
return calculated_max_workers
91+
92+
def _get_effective_max_workers(self, max_workers, serving_gpus):
93+
if max_workers is not None:
94+
effective_max_workers = self._calc_max_workers(max_workers, serving_gpus)
7495
else:
75-
if isinstance(max_workers, int) and max_workers > 0:
76-
logger.debug("max_workers specified as: %s", max_workers)
77-
self.max_workers = max_workers
78-
else:
79-
raise InvalidMaxWorkersError(max_workers)
96+
effective_max_workers = self.max_workers
97+
return effective_max_workers
8098

8199

82100
class MTBenchEvaluator(AbstractMTBenchEvaluator):
@@ -94,30 +112,46 @@ class MTBenchEvaluator(AbstractMTBenchEvaluator):
94112

95113
name = "mt_bench"
96114

97-
def gen_answers(self, server_url, api_key: str | None = None) -> None:
115+
def gen_answers(
116+
self,
117+
server_url,
118+
api_key: str | None = None,
119+
max_workers: int | str | None = None,
120+
serving_gpus: int | None = None,
121+
) -> None:
98122
"""
99123
Asks questions to model
100124
101125
Attributes
102126
server_url Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
103127
api_key API token for authenticating with model server
128+
max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
129+
serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor.
104130
"""
105131
logger.debug(locals())
106132
mt_bench_answers.generate_answers(
107133
self.model_name,
108134
server_url,
109135
api_key=api_key,
110136
output_dir=self.output_dir,
111-
max_workers=self.max_workers,
137+
max_workers=self._get_effective_max_workers(max_workers, serving_gpus),
112138
)
113139

114-
def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
140+
def judge_answers(
141+
self,
142+
server_url,
143+
api_key: str | None = None,
144+
max_workers: int | str | None = None,
145+
serving_gpus: int | None = None,
146+
) -> tuple:
115147
"""
116148
Runs MT-Bench judgment
117149
118150
Attributes
119151
server_url Model server endpoint (Ex: http://localhost:8000/v1) for the judge model
120152
api_key API token for authenticating with model server
153+
max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
154+
serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor.
121155
122156
Returns:
123157
overall_score MT-Bench score for the overall model evaluation
@@ -130,7 +164,7 @@ def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
130164
self.judge_model_name,
131165
server_url,
132166
api_key=api_key,
133-
max_workers=self.max_workers,
167+
max_workers=self._get_effective_max_workers(max_workers, serving_gpus),
134168
output_dir=self.output_dir,
135169
merge_system_user_message=self.merge_system_user_message,
136170
)
@@ -175,13 +209,21 @@ def __init__(
175209
self.taxonomy_git_repo_path = taxonomy_git_repo_path
176210
self.branch = branch
177211

178-
def gen_answers(self, server_url, api_key: str | None = None) -> None:
212+
def gen_answers(
213+
self,
214+
server_url,
215+
api_key: str | None = None,
216+
max_workers: int | str | None = None,
217+
serving_gpus: int | None = None,
218+
) -> None:
179219
"""
180220
Asks questions to model
181221
182222
Attributes
183223
server_url Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
184224
api_key API token for authenticating with model server
225+
max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
226+
serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor.
185227
"""
186228
logger.debug(locals())
187229
mt_bench_branch_generator.generate(
@@ -197,17 +239,25 @@ def gen_answers(self, server_url, api_key: str | None = None) -> None:
197239
branch=self.branch,
198240
output_dir=self.output_dir,
199241
data_dir=self.output_dir,
200-
max_workers=self.max_workers,
242+
max_workers=self._get_effective_max_workers(max_workers, serving_gpus),
201243
bench_name="mt_bench_branch",
202244
)
203245

204-
def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
246+
def judge_answers(
247+
self,
248+
server_url,
249+
api_key: str | None = None,
250+
max_workers: int | str | None = None,
251+
serving_gpus: int | None = None,
252+
) -> tuple:
205253
"""
206254
Runs MT-Bench-Branch judgment. Judgments can be compared across runs with consistent question_id -> qna file name.
207255
208256
Attributes
209257
server_url Model server endpoint (Ex: http://localhost:8000/v1) for the judge model
210258
api_key API token for authenticating with model server
259+
max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
260+
serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor.
211261
212262
Returns:
213263
qa_pairs Question and answer pairs (with scores) from the evaluation
@@ -219,7 +269,7 @@ def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
219269
server_url,
220270
api_key=api_key,
221271
branch=self.branch,
222-
max_workers=self.max_workers,
272+
max_workers=self._get_effective_max_workers(max_workers, serving_gpus),
223273
output_dir=self.output_dir,
224274
data_dir=self.output_dir,
225275
bench_name="mt_bench_branch",

0 commit comments

Comments (0)