-
Notifications
You must be signed in to change notification settings - Fork 333
test(sampling_params): repair broken test collection and add verify() coverage #1350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -5,7 +5,6 @@ | |||||
| RegularConstraint, | ||||||
| AllowedTokenIds, | ||||||
| ExponentialDecayLengthPenalty, | ||||||
| DecodeNode, | ||||||
| SamplingParams, | ||||||
| GuidedGrammar, | ||||||
| GuidedJsonSchema, | ||||||
|
|
@@ -14,6 +13,7 @@ | |||||
| ALLOWED_TOKEN_IDS_MAX_LENGTH, | ||||||
| JSON_SCHEMA_MAX_LENGTH, | ||||||
| GRAMMAR_CONSTRAINT_MAX_LENGTH, | ||||||
| MAX_BEST_OF, | ||||||
| ) | ||||||
|
|
||||||
| grammar_str = r"""root ::= (expr "=" term)+ | ||||||
|
|
@@ -117,24 +117,6 @@ def test_exponential_decay_length_penalty_initialization(): | |||||
| penalty.initialize((5, 0.5)) | ||||||
|
|
||||||
|
|
||||||
| def test_decode_node_initialization(): | ||||||
| node = DecodeNode() | ||||||
| data = { | ||||||
| "node_id": 12345678901234567890, # 示例 UUID | ||||||
| "ip": "192.168.1.1", | ||||||
| "rpyc_port": 8080, | ||||||
| "max_new_tokens": 10, | ||||||
| } | ||||||
| node.initialize(data) | ||||||
| assert node.exists is True | ||||||
| assert node.node_id.node_id_high == (12345678901234567890 >> 64) & 0xFFFFFFFFFFFFFFFF | ||||||
| assert node.node_id.node_id_low == 12345678901234567890 & 0xFFFFFFFFFFFFFFFF | ||||||
| assert node.ip[0] == 192 | ||||||
| assert node.ip[1] == 168 | ||||||
| assert node.ip[2] == 1 | ||||||
| assert node.ip[3] == 1 | ||||||
|
|
||||||
|
|
||||||
| def test_sampling_params_initialization(): | ||||||
| params = SamplingParams() | ||||||
| data = { | ||||||
|
|
@@ -161,7 +143,6 @@ def test_sampling_params_initialization(): | |||||
| "allowed_token_ids": [1, 2, 3], | ||||||
| "stop_sequences": [[2, 1], [3, 4]], | ||||||
| "exponential_decay_length_penalty": (1, 1.0), | ||||||
| "move_kv_to_decode_node": None, | ||||||
| } | ||||||
| params.init(None, **data) | ||||||
|
|
||||||
|
|
@@ -173,6 +154,90 @@ def test_sampling_params_initialization(): | |||||
| assert params.stop_sequences.size == 2 | ||||||
|
|
||||||
|
|
||||||
| def _make_params(**overrides): | ||||||
| """Build a SamplingParams whose fields are valid by default, applying overrides. | ||||||
|
|
||||||
| ``do_sample=True`` is used so that the sampling-related fields (temperature, top_p, | ||||||
| top_k) are kept as provided; with greedy decoding ``init`` overrides them to defaults. | ||||||
| """ | ||||||
| data = { | ||||||
| "best_of": 1, | ||||||
| "n": 1, | ||||||
| "do_sample": True, | ||||||
| "presence_penalty": 0.0, | ||||||
| "frequency_penalty": 0.0, | ||||||
| "repetition_penalty": 1.0, | ||||||
| "temperature": 1.0, | ||||||
| "top_p": 1.0, | ||||||
| "top_k": -1, | ||||||
| "max_new_tokens": 16, | ||||||
| "min_new_tokens": 1, | ||||||
| } | ||||||
| data.update(overrides) | ||||||
| params = SamplingParams() | ||||||
| params.init(None, **data) | ||||||
| return params | ||||||
|
|
||||||
|
|
||||||
| def test_verify_accepts_valid_defaults(): | ||||||
| # A minimally specified, valid configuration must pass verification. | ||||||
| _make_params().verify() | ||||||
|
|
||||||
|
|
||||||
| def test_verify_accepts_n_equal_best_of_greater_than_one(): | ||||||
| params = _make_params(best_of=2, n=2) | ||||||
| params.verify() | ||||||
| assert params.n == params.best_of == 2 | ||||||
|
|
||||||
|
|
||||||
| def test_verify_rejects_n_not_equal_best_of(): | ||||||
| # The engine currently only supports n == best_of; a mismatch must be rejected. | ||||||
| with pytest.raises(ValueError): | ||||||
| _make_params(best_of=2, n=1).verify() | ||||||
|
|
||||||
|
|
||||||
| @pytest.mark.parametrize("best_of", [0, -1, MAX_BEST_OF + 1]) | ||||||
| def test_verify_rejects_best_of_out_of_range(best_of): | ||||||
| with pytest.raises(ValueError): | ||||||
| _make_params(best_of=best_of, n=best_of).verify() | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since
Suggested change
|
||||||
|
|
||||||
|
|
||||||
| @pytest.mark.parametrize( | ||||||
| "field, value", | ||||||
| [ | ||||||
| ("presence_penalty", -0.1), | ||||||
| ("frequency_penalty", -0.1), | ||||||
| ("repetition_penalty", 0.5), | ||||||
| ("temperature", -1.0), | ||||||
| ("top_p", 0.0), | ||||||
| ("top_p", 1.5), | ||||||
| ("top_k", 0), | ||||||
| ("top_k", -2), | ||||||
| ("max_new_tokens", 0), | ||||||
| ("min_new_tokens", 0), | ||||||
| ], | ||||||
| ) | ||||||
| def test_verify_rejects_invalid_sampling_fields(field, value): | ||||||
| with pytest.raises(ValueError): | ||||||
| _make_params(**{field: value}).verify() | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since
Suggested change
|
||||||
|
|
||||||
|
|
||||||
| def test_verify_rejects_min_new_tokens_greater_than_max(): | ||||||
| with pytest.raises(ValueError): | ||||||
| _make_params(min_new_tokens=8, max_new_tokens=4).verify() | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since
Suggested change
|
||||||
|
|
||||||
|
|
||||||
| @pytest.mark.parametrize("top_k", [-1, 1, 50]) | ||||||
| def test_verify_accepts_valid_top_k(top_k): | ||||||
| _make_params(top_k=top_k).verify() | ||||||
|
|
||||||
|
|
||||||
| def test_verify_rejects_regular_constraint_with_allowed_token_ids(): | ||||||
| # regular_constraint and allowed_token_ids are mutually exclusive. | ||||||
| with pytest.raises(ValueError): | ||||||
| _make_params(regular_constraint="[a-z]+", allowed_token_ids=[1, 2, 3]).verify() | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since
Suggested change
|
||||||
|
|
||||||
|
|
||||||
| # Mock tokenizer for testing | ||||||
| class MockTokenizer: | ||||||
| def encode(self, text, add_special_tokens=False): | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since
_make_paramsinternally callsparams.init(), which automatically runsself.verify(), any invalid parameter configuration will raise aValueErrorduring initialization. Consequently, the trailing.verify()call is never reached and is dead code. Removing.verify()makes the test cleaner and accurately reflects where the exception is raised.