|
4 | 4 |
|
5 | 5 | """ |
6 | 6 |
|
| 7 | +import sys |
7 | 8 | import unittest |
8 | | -from sshpubkeys import SSHKey |
| 9 | + |
| 10 | +from sshpubkeys import AuthorizedKeysFile, InvalidOptionsError, SSHKey |
| 11 | + |
| 12 | +from .authorized_keys import items as list_of_authorized_keys |
| 13 | +from .invalid_authorized_keys import items as list_of_invalid_authorized_keys |
| 14 | +from .invalid_keys import keys as list_of_invalid_keys |
| 15 | +from .invalid_options import options as list_of_invalid_options |
9 | 16 | from .valid_keys import keys as list_of_valid_keys |
10 | 17 | from .valid_keys_rfc4716 import keys as list_of_valid_keys_rfc4716 |
11 | 18 | from .invalid_keys import keys as list_of_invalid_keys |
12 | 19 | from .valid_options import options as list_of_valid_options |
13 | 20 | from .invalid_options import options as list_of_invalid_options |
| 21 | +from .authorized_keys import items as list_of_authorized_keys |
| 22 | +from .invalid_authorized_keys import items as list_of_invalid_authorized_keys |
| 23 | +from io import StringIO |
| 24 | + |
| 25 | +if sys.version_info.major == 2: |
| 26 | + from io import BytesIO as StringIO |
| 27 | +else: |
| 28 | + from io import StringIO |
| 29 | + |
| 30 | + |
| 31 | +DEFAULT_KEY = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIEGODBKRjsFB/1v3pDRGpA6xR+QpOJg9vat0brlbUNDD" |
14 | 32 |
|
15 | 33 |
|
16 | 34 | class TestMisc(unittest.TestCase): |
@@ -52,6 +70,32 @@ def check_invalid_option(self, option, expected_error): |
52 | 70 | ssh = SSHKey() |
53 | 71 | self.assertRaises(expected_error, ssh.parse_options, option) |
54 | 72 |
|
| 73 | + def test_disallow_options(self): |
| 74 | + ssh = SSHKey(disallow_options=True) |
| 75 | + key = """command="dump /home",no-pty,no-port-forwarding """ + DEFAULT_KEY |
| 76 | + self.assertRaises(InvalidOptionsError, ssh.parse, key) |
| 77 | + |
| 78 | + |
| 79 | +class TestAuthorizedKeys(unittest.TestCase): |
| 80 | + |
| 81 | + def check_valid_file(self, file_str, valid_keys_count): |
| 82 | + file_obj = StringIO(file_str) |
| 83 | + key_file = AuthorizedKeysFile(file_obj) |
| 84 | + for item in key_file.keys: |
| 85 | + self.assertIsInstance(item, SSHKey) |
| 86 | + self.assertEqual(len(key_file.keys), valid_keys_count) |
| 87 | + |
| 88 | + def check_invalid_file(self, file_str, expected_error): |
| 89 | + file_obj = StringIO(file_str) |
| 90 | + self.assertRaises(expected_error, AuthorizedKeysFile, file_obj) |
| 91 | + |
| 92 | + def test_disallow_options(self): |
| 93 | + file_obj = StringIO("""command="dump /home",no-pty,no-port-forwarding """ + DEFAULT_KEY) |
| 94 | + self.assertRaises(InvalidOptionsError, AuthorizedKeysFile, file_obj, disallow_options=True) |
| 95 | + file_obj.seek(0) |
| 96 | + key_file = AuthorizedKeysFile(file_obj) |
| 97 | + self.assertEqual(len(key_file.keys), 1) |
| 98 | + |
55 | 99 |
|
56 | 100 | def loop_options(options): |
57 | 101 | """ Loop over list of options and dynamically create tests """ |
@@ -106,11 +150,29 @@ def ch(pubkey, expected_error, **kwargs): |
106 | 150 | setattr(TestKeys, "test_%s_mode_%s" % (prefix_tmp, mode), ch(pubkey, expected_error, **kwargs)) |
107 | 151 |
|
108 | 152 |
|
| 153 | +def loop_authorized_keys(keyset): |
| 154 | + def ch(file_str, valid_keys_count): |
| 155 | + return lambda self: self.check_valid_file(file_str, valid_keys_count) |
| 156 | + for i, items in enumerate(keyset): |
| 157 | + prefix_tmp = "%s_%s" % (items[0], i) |
| 158 | + setattr(TestAuthorizedKeys, "test_%s" % prefix_tmp, ch(items[1], items[2])) |
| 159 | + |
| 160 | + |
| 161 | +def loop_invalid_authorized_keys(keyset): |
| 162 | + def ch(file_str, expected_error, **kwargs): |
| 163 | + return lambda self: self.check_invalid_file(file_str, expected_error, **kwargs) |
| 164 | + for i, items in enumerate(keyset): |
| 165 | + prefix_tmp = "%s_%s" % (items[0], i) |
| 166 | + setattr(TestAuthorizedKeys, "test_invalid_%s" % prefix_tmp, ch(items[1], items[2])) |
| 167 | + |
| 168 | + |
109 | 169 | loop_valid(list_of_valid_keys, "valid_key") |
110 | 170 | loop_valid(list_of_valid_keys_rfc4716, "valid_key_rfc4716") |
111 | 171 | loop_invalid(list_of_invalid_keys, "invalid_key") |
112 | 172 | loop_options(list_of_valid_options) |
113 | 173 | loop_invalid_options(list_of_invalid_options) |
| 174 | +loop_authorized_keys(list_of_authorized_keys) |
| 175 | +loop_invalid_authorized_keys(list_of_invalid_authorized_keys) |
114 | 176 |
|
115 | 177 | if __name__ == '__main__': |
116 | 178 | unittest.main() |
0 commit comments