Skip to content

Commit 0e332ce

Browse files
[RTY-260032]: add modular approach and whitelist-blacklist authorization
1 parent 48a970f commit 0e332ce

3 files changed

Lines changed: 67 additions & 9 deletions

File tree

app/routes.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,19 @@
3434
remove_cache_key,
3535
rev_cache,
3636
)
37-
from app.utils.config import DOMAIN, MAX_RECENT_URLS, CACHE_PURGE_TOKEN, QR_DIR
38-
from app.utils.helper import generate_code, is_valid_url, format_date
37+
from app.utils.config import (
38+
DOMAIN,
39+
MAX_RECENT_URLS,
40+
CACHE_PURGE_TOKEN,
41+
QR_DIR,
42+
)
43+
from app.utils.helper import (
44+
generate_code,
45+
sanitize_url,
46+
is_valid_url,
47+
authorize_url,
48+
format_date,
49+
)
3950
from app.utils.qr import generate_qr_with_logo
4051

4152
# templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
@@ -98,11 +109,18 @@ async def create_short_url(
98109
qr_type: str = Form("short"),
99110
):
100111
session = request.session
112+
original_url = sanitize_url(original_url) # sanitize the URL input
101113

102-
if not original_url or not is_valid_url(original_url):
114+
if not original_url or not is_valid_url(original_url): # validate the URL
103115
session["error"] = "Please enter a valid URL."
104116
return RedirectResponse("/", status_code=status.HTTP_303_SEE_OTHER)
105117

118+
if not authorize_url(
119+
original_url
120+
): # authorize the URL based on whitelist/blacklist
121+
session["error"] = "This domain is not allowed."
122+
return RedirectResponse("/", status_code=status.HTTP_303_SEE_OTHER)
123+
106124
short_code: Optional[str] = get_short_from_cache(original_url)
107125

108126
if not short_code and db.is_connected():
@@ -329,10 +347,14 @@ class ShortenRequest(BaseModel):
329347

330348
@api_v1.post("/shorten")
331349
def shorten_api(payload: ShortenRequest):
332-
original_url = payload.url
350+
original_url = sanitize_url(payload.url)
351+
333352
if not is_valid_url(original_url):
334353
return JSONResponse(status_code=400, content={"error": "INVALID_URL"})
335354

355+
if not authorize_url(original_url):
356+
return JSONResponse(status_code=400, content={"error": "DOMAIN_NOT_ALLOWED"})
357+
336358
short_code = get_short_from_cache(original_url)
337359
if not short_code:
338360
short_code = generate_code()

app/utils/config.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,16 @@ def _get_int(key: str, default: int) -> int:
8282
MAX_URL_LENGTH = 2048
8383

8484
# for making the qr constant
85-
# Base project paths
8685
BASE_DIR = Path(__file__).resolve().parent.parent # app/
8786
PROJECT_ROOT = BASE_DIR.parent # project root
88-
# QR directory constant
89-
QR_DIR = PROJECT_ROOT / "assets" / "images" / "qr"
87+
QR_DIR = PROJECT_ROOT / "assets" / "images" / "qr" # QR directory constant
88+
89+
90+
# for the check of the url which is blacklist or whitelist
91+
whitelist_urls: set[str] = set()
92+
blacklist_urls: set[str] = {
93+
"malware.com",
94+
"phishing.site",
95+
"badsite.test",
96+
"spam.test",
97+
}

app/utils/helper.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,18 @@
33
from datetime import datetime, timezone
44
from zoneinfo import ZoneInfo
55
from typing import Union
6-
from app.utils.config import SHORT_CODE_LENGTH
6+
from app.utils.config import SHORT_CODE_LENGTH, whitelist_urls, blacklist_urls
77
from urllib.parse import urlparse
88
import ipaddress
99

1010

11+
# for sanitization the url
12+
def sanitize_url(url: str) -> str:
13+
return url.strip()
14+
15+
16+
# for validating the url
1117
def is_valid_url(url: str) -> bool:
12-
url = url.strip() # sanitize here
1318

1419
try:
1520
parsed = urlparse(url)
@@ -43,6 +48,29 @@ def is_valid_url(url: str) -> bool:
4348
return False
4449

4550

51+
# for authorizing the url based on whitelist and blacklist
52+
def authorize_url(url: str) -> bool:
53+
"""Check whitelist / blacklist rules."""
54+
hostname = urlparse(url).hostname
55+
56+
if hostname is None:
57+
return False
58+
59+
hostname = hostname.lower()
60+
61+
# block blacklist domains
62+
if any(hostname.endswith(domain) for domain in blacklist_urls):
63+
return False
64+
65+
# allow only whitelist domains (if defined)
66+
if whitelist_urls and not any(
67+
hostname.endswith(domain) for domain in whitelist_urls
68+
):
69+
return False
70+
71+
return True
72+
73+
4674
def generate_code(length: int = SHORT_CODE_LENGTH) -> str:
4775
chars = string.ascii_letters + string.digits
4876
return "".join(random.choice(chars) for _ in range(length))

0 commit comments

Comments
 (0)