|
4 | 4 |
|
5 | 5 | """ |
6 | 6 |
|
| 7 | +import sys |
7 | 8 | import unittest |
8 | | -from sshpubkeys import SSHKey |
| 9 | + |
| 10 | +from sshpubkeys import AuthorizedKeysFile, InvalidOptionsError, SSHKey |
| 11 | + |
| 12 | +from .authorized_keys import items as list_of_authorized_keys |
| 13 | +from .invalid_authorized_keys import items as list_of_invalid_authorized_keys |
| 14 | +from .invalid_keys import keys as list_of_invalid_keys |
| 15 | +from .invalid_options import options as list_of_invalid_options |
9 | 16 | from .valid_keys import keys as list_of_valid_keys |
10 | 17 | from .valid_keys_rfc4716 import keys as list_of_valid_keys_rfc4716 |
11 | | -from .invalid_keys import keys as list_of_invalid_keys |
12 | 18 | from .valid_options import options as list_of_valid_options |
13 | | -from .invalid_options import options as list_of_invalid_options |
| 19 | + |
| 20 | +if sys.version_info.major == 2: |
| 21 | + from io import BytesIO as StringIO |
| 22 | +else: |
| 23 | + from io import StringIO |
| 24 | + |
| 25 | + |
| 26 | +DEFAULT_KEY = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIEGODBKRjsFB/1v3pDRGpA6xR+QpOJg9vat0brlbUNDD" |
14 | 27 |
|
15 | 28 |
|
16 | 29 | class TestMisc(unittest.TestCase): |
@@ -52,6 +65,32 @@ def check_invalid_option(self, option, expected_error): |
52 | 65 | ssh = SSHKey() |
53 | 66 | self.assertRaises(expected_error, ssh.parse_options, option) |
54 | 67 |
|
| 68 | + def test_disallow_options(self): |
| 69 | + ssh = SSHKey(disallow_options=True) |
| 70 | + key = """command="dump /home",no-pty,no-port-forwarding """ + DEFAULT_KEY |
| 71 | + self.assertRaises(InvalidOptionsError, ssh.parse, key) |
| 72 | + |
| 73 | + |
| 74 | +class TestAuthorizedKeys(unittest.TestCase): |
| 75 | + |
| 76 | + def check_valid_file(self, file_str, valid_keys_count): |
| 77 | + file_obj = StringIO(file_str) |
| 78 | + key_file = AuthorizedKeysFile(file_obj) |
| 79 | + for item in key_file.keys: |
| 80 | + self.assertIsInstance(item, SSHKey) |
| 81 | + self.assertEqual(len(key_file.keys), valid_keys_count) |
| 82 | + |
| 83 | + def check_invalid_file(self, file_str, expected_error): |
| 84 | + file_obj = StringIO(file_str) |
| 85 | + self.assertRaises(expected_error, AuthorizedKeysFile, file_obj) |
| 86 | + |
| 87 | + def test_disallow_options(self): |
| 88 | + file_obj = StringIO("""command="dump /home",no-pty,no-port-forwarding """ + DEFAULT_KEY) |
| 89 | + self.assertRaises(InvalidOptionsError, AuthorizedKeysFile, file_obj, disallow_options=True) |
| 90 | + file_obj.seek(0) |
| 91 | + key_file = AuthorizedKeysFile(file_obj) |
| 92 | + self.assertEqual(len(key_file.keys), 1) |
| 93 | + |
55 | 94 |
|
56 | 95 | def loop_options(options): |
57 | 96 | """ Loop over list of options and dynamically create tests """ |
@@ -106,11 +145,29 @@ def ch(pubkey, expected_error, **kwargs): |
106 | 145 | setattr(TestKeys, "test_%s_mode_%s" % (prefix_tmp, mode), ch(pubkey, expected_error, **kwargs)) |
107 | 146 |
|
108 | 147 |
|
| 148 | +def loop_authorized_keys(keyset): |
| 149 | + def ch(file_str, valid_keys_count): |
| 150 | + return lambda self: self.check_valid_file(file_str, valid_keys_count) |
| 151 | + for i, items in enumerate(keyset): |
| 152 | + prefix_tmp = "%s_%s" % (items[0], i) |
| 153 | + setattr(TestAuthorizedKeys, "test_%s" % prefix_tmp, ch(items[1], items[2])) |
| 154 | + |
| 155 | + |
| 156 | +def loop_invalid_authorized_keys(keyset): |
| 157 | + def ch(file_str, expected_error, **kwargs): |
| 158 | + return lambda self: self.check_invalid_file(file_str, expected_error, **kwargs) |
| 159 | + for i, items in enumerate(keyset): |
| 160 | + prefix_tmp = "%s_%s" % (items[0], i) |
| 161 | + setattr(TestAuthorizedKeys, "test_invalid_%s" % prefix_tmp, ch(items[1], items[2])) |
| 162 | + |
| 163 | + |
109 | 164 | loop_valid(list_of_valid_keys, "valid_key") |
110 | 165 | loop_valid(list_of_valid_keys_rfc4716, "valid_key_rfc4716") |
111 | 166 | loop_invalid(list_of_invalid_keys, "invalid_key") |
112 | 167 | loop_options(list_of_valid_options) |
113 | 168 | loop_invalid_options(list_of_invalid_options) |
| 169 | +loop_authorized_keys(list_of_authorized_keys) |
| 170 | +loop_invalid_authorized_keys(list_of_invalid_authorized_keys) |
114 | 171 |
|
115 | 172 | if __name__ == '__main__': |
116 | 173 | unittest.main() |
0 commit comments