import json

import boto3
import duckdb

from data_formulator.data_loader.external_data_loader import ExternalDataLoader, sanitize_table_name
from typing import Any, Dict, List, Optional

class S3DataLoader(ExternalDataLoader):

    @staticmethod
    def list_params() -> List[Dict[str, Any]]:
        params_list = [
            {"name": "aws_access_key_id", "type": "string", "required": True, "default": "", "description": "AWS access key ID"},
            {"name": "aws_secret_access_key", "type": "string", "required": True, "default": "", "description": "AWS secret access key"},
            {"name": "aws_session_token", "type": "string", "required": False, "default": "", "description": "AWS session token (required for temporary credentials)"},
            {"name": "region_name", "type": "string", "required": True, "default": "us-east-1", "description": "AWS region name"},
            {"name": "bucket", "type": "string", "required": True, "default": "", "description": "S3 bucket name"}
        ]
        return params_list

    def __init__(self, params: Dict[str, Any], duck_db_conn: duckdb.DuckDBPyConnection):
        self.params = params
        self.duck_db_conn = duck_db_conn

        # Extract parameters
        self.aws_access_key_id = params.get("aws_access_key_id", "")
        self.aws_secret_access_key = params.get("aws_secret_access_key", "")
        self.aws_session_token = params.get("aws_session_token", "")
        self.region_name = params.get("region_name", "us-east-1")
        self.bucket = params.get("bucket", "")

        # Install and load the httpfs extension for S3 access
        self.duck_db_conn.install_extension("httpfs")
        self.duck_db_conn.load_extension("httpfs")

        # Set AWS credentials for DuckDB
        self.duck_db_conn.execute(f"SET s3_region='{self.region_name}'")
        self.duck_db_conn.execute(f"SET s3_access_key_id='{self.aws_access_key_id}'")
        self.duck_db_conn.execute(f"SET s3_secret_access_key='{self.aws_secret_access_key}'")
        if self.aws_session_token:  # only needed for temporary credentials
            self.duck_db_conn.execute(f"SET s3_session_token='{self.aws_session_token}'")

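    # A possible alternative to the SET statements in __init__ above (a sketch,
    # assuming DuckDB 0.10+ where the secrets manager is available): the same
    # credentials could be registered as a single S3 secret, e.g.
    #
    #   CREATE OR REPLACE SECRET s3_secret (
    #       TYPE S3,
    #       KEY_ID '<aws_access_key_id>',
    #       SECRET '<aws_secret_access_key>',
    #       REGION '<region_name>'
    #   );
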
    def list_tables(self) -> List[Dict[str, Any]]:
        # Use boto3 to list objects in the bucket
        s3_client = boto3.client(
            's3',
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            aws_session_token=self.aws_session_token if self.aws_session_token else None,
            region_name=self.region_name
        )

        # List objects in the bucket (note: list_objects_v2 returns at most
        # 1,000 keys per call; a paginator would be needed for larger buckets)
        response = s3_client.list_objects_v2(Bucket=self.bucket)

        results = []

        if 'Contents' in response:
            for obj in response['Contents']:
                key = obj['Key']

                # Skip directories and unsupported file types
                if key.endswith('/') or not self._is_supported_file(key):
                    continue

                # Build the S3 URL
                s3_url = f"s3://{self.bucket}/{key}"

                try:
                    # Choose the appropriate read function based on file extension
                    if s3_url.lower().endswith('.parquet'):
                        sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_parquet('{s3_url}') LIMIT 10").df()
                    elif s3_url.lower().endswith('.json') or s3_url.lower().endswith('.jsonl'):
                        sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_json_auto('{s3_url}') LIMIT 10").df()
                    else:  # default to CSV for the remaining supported extensions
                        sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_csv_auto('{s3_url}') LIMIT 10").df()

                    # Get column information
                    columns = [{
                        'name': col,
                        'type': str(sample_df[col].dtype)
                    } for col in sample_df.columns]

                    # Get sample data
                    sample_rows = json.loads(sample_df.to_json(orient="records"))

                    # Estimate row count (approximate for CSV files)
                    row_count = self._estimate_row_count(s3_url)

                    table_metadata = {
                        "row_count": row_count,
                        "columns": columns,
                        "sample_rows": sample_rows
                    }

                    results.append({
                        "name": s3_url,
                        "metadata": table_metadata
                    })
                except Exception as e:
                    # Skip files that can't be read
                    print(f"Error reading {s3_url}: {e}")
                    continue

        return results

    def _is_supported_file(self, key: str) -> bool:
        """Check if the file type is supported by DuckDB."""
        supported_extensions = ['.csv', '.parquet', '.json', '.jsonl']
        return any(key.lower().endswith(ext) for ext in supported_extensions)

    def _estimate_row_count(self, s3_url: str) -> int:
        """Estimate the number of rows in a file."""
        try:
            # For parquet files, we can get the exact count
            if s3_url.lower().endswith('.parquet'):
                count = self.duck_db_conn.execute(f"SELECT COUNT(*) FROM read_parquet('{s3_url}')").fetchone()[0]
                return count

            # For CSV files, sample the file to estimate the average row size
            sample_size = 1000
            sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_csv_auto('{s3_url}') LIMIT {sample_size}").df()

            # Get file size from S3
            s3_client = boto3.client(
                's3',
                aws_access_key_id=self.aws_access_key_id,
                aws_secret_access_key=self.aws_secret_access_key,
                aws_session_token=self.aws_session_token if self.aws_session_token else None,
                region_name=self.region_name
            )

            key = s3_url.replace(f"s3://{self.bucket}/", "")
            response = s3_client.head_object(Bucket=self.bucket, Key=key)
            file_size = response['ContentLength']

            # Extrapolate from the sample: compute the average row size in bytes
            # from the serialized sample, then estimate total rows from file size
            if len(sample_df) > 0:
                sample_bytes = len(sample_df.to_csv(index=False).encode('utf-8'))
                avg_row_size = sample_bytes / len(sample_df)
                estimated_rows = int(file_size / avg_row_size)
                return min(estimated_rows, 1000000)  # Cap at 1 million for UI performance

            return 0
        except Exception as e:
            print(f"Error estimating row count for {s3_url}: {e}")
            return 0

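    # Worked example for _estimate_row_count (illustrative numbers only): if a
    # 1,000-row sample serializes to roughly 64 KB of CSV, avg_row_size is about
    # 64 bytes, so a 6.4 MB object is estimated at roughly 100,000 rows.
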
    def ingest_data(self, table_name: str, name_as: Optional[str] = None, size: int = 1000000):
        if name_as is None:
            name_as = table_name.split('/')[-1].split('.')[0]

        name_as = sanitize_table_name(name_as)

        # Determine file type and use appropriate DuckDB function
        if table_name.lower().endswith('.csv'):
            self.duck_db_conn.execute(f"""
                CREATE OR REPLACE TABLE main.{name_as} AS
                SELECT * FROM read_csv_auto('{table_name}')
                LIMIT {size}
            """)
        elif table_name.lower().endswith('.parquet'):
            self.duck_db_conn.execute(f"""
                CREATE OR REPLACE TABLE main.{name_as} AS
                SELECT * FROM read_parquet('{table_name}')
                LIMIT {size}
            """)
        elif table_name.lower().endswith('.json') or table_name.lower().endswith('.jsonl'):
            self.duck_db_conn.execute(f"""
                CREATE OR REPLACE TABLE main.{name_as} AS
                SELECT * FROM read_json_auto('{table_name}')
                LIMIT {size}
            """)
        else:
            raise ValueError(f"Unsupported file type: {table_name}")

    def view_query_sample(self, query: str) -> List[Dict[str, Any]]:
        return self.duck_db_conn.execute(query).df().head(10).to_dict(orient="records")

    def ingest_data_from_query(self, query: str, name_as: str):
        # Execute the query and get results as a DataFrame
        df = self.duck_db_conn.execute(query).df()
        # Use the base class's method to ingest the DataFrame
        self.ingest_df_to_duckdb(df, name_as)
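
# A minimal usage sketch: every credential and bucket value below is a
# placeholder, and the connection is an in-memory DuckDB database created
# purely for illustration.
if __name__ == "__main__":
    conn = duckdb.connect()  # in-memory database
    loader = S3DataLoader(
        params={
            "aws_access_key_id": "YOUR_ACCESS_KEY_ID",          # placeholder
            "aws_secret_access_key": "YOUR_SECRET_ACCESS_KEY",  # placeholder
            "aws_session_token": "",                            # optional
            "region_name": "us-east-1",
            "bucket": "your-bucket-name",                       # placeholder
        },
        duck_db_conn=conn,
    )
    # Discover readable objects in the bucket and print basic metadata
    for table in loader.list_tables():
        print(table["name"], table["metadata"]["row_count"])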