Skip to content

Commit 816dc5a

Browse files
committed
refresh data design v1
1 parent 12b123e commit 816dc5a

25 files changed

Lines changed: 3552 additions & 478 deletions

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
.DS_Store
66
build/
77
dist/
8+
experiment_data/
89

910
## Ignore Visual Studio temporary files, build results, and
1011
## files generated by popular Visual Studio add-ons.
@@ -405,3 +406,7 @@ FodyWeavers.xsd
405406
# JetBrains Rider
406407
*.sln.iml
407408
venv
409+
410+
411+
\.\NUL
412+
NUL

py-src/data_formulator/agents/agent_py_data_rec.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"recap": "..." // string, a short summary of the user's goal.
3333
"display_instruction": "..." // string, the even shorter verb phrase describing the user's goal.
3434
"recommendation": "..." // string, explain why this recommendation is made
35+
"input_tables": [...] // string[], describe names of the input tables that will be used in the transformation.
3536
"output_fields": [...] // string[], describe the desired output fields that the output data should have (i.e., the goal of transformed data), it's a good idea to preserve intermediate fields here
3637
"chart_type": "" // string, one of "point", "bar", "line", "area", "heatmap", "group_bar", "boxplot". "chart_type" should either be inferred from user instruction, or recommended if the user didn't specify any.
3738
"chart_encodings": {
@@ -65,6 +66,7 @@
6566
- if you mention column names from the input or the output data, highlight the text in **bold**.
6667
* the column can either be a column in the input data, or a new column that will be computed in the output data.
6768
* the mention doesn't have to be an exact match; it can be semantically matching, e.g., if you mentioned "average score" in the text while the column to be computed is "Avg_Score", you should still highlight "**average score**" in the text.
69+
- determine "input_tables", the names of a subset of input tables from [CONTEXT] section that will be used to achieve the user's goal.
6870
- "chart_type" must be one of "point", "bar", "line", "area", "heatmap", "group_bar", "boxplot"
6971
- "chart_encodings" should specify which fields should be used to create the visualization
7072
- decide which visual channels should be used to create the visualization appropriate for the chart type.
@@ -157,9 +159,11 @@ def transform_data(df1, df2, ...):
157159
```
158160
159161
note:
160-
- if the user provided one table, then it should be `def transform_data(df1)`, if the user provided multiple tables, then it should be `def transform_data(df1, df2, ...)` and you should consider the join between tables to derive the output.
161-
- **VERY IMPORTANT** the number of arguments in the function must match the number of tables provided, and the order of arguments must match the order of tables provided.
162-
- you can use intuitive table names to refer to the input dataframes, for example, if the user provided two tables city and weather, you can use `transform_data(df_city, df_weather)` to refer to the two dataframes, as long as the number and order of the arguments match the number and order of the tables provided.
162+
- decide the function signature based on the number of tables you decided in the previous step "input_tables":
163+
- if you decide there will only be one input table, then function signature should be `def transform_data(df1)`
164+
- if you decide there will be k input tables, then function signature should be `def transform_data(df_1, df_2, ..., df_k)`.
165+
- instead of using generic names like df1, df2, ..., try to use intuitive table names for function arguments, for example, if you have `input_tables: ["City", "Weather"]`, you can use `transform_data(df_city, df_weather)` to refer to the two dataframes.
166+
- **VERY IMPORTANT** the number of arguments in the function signature must be the same as the number of tables provided in "input_tables", and the order of arguments must match the order of tables provided in "input_tables".
163167
- datetime objects handling:
164168
- if the output field is year, convert it to number, if it is year-month / year-month-day, convert it to string object (e.g., "2020-01" / "2020-01-01").
165169
- if the output is time only: convert hour to number if it's just the hour (e.g., 10), but convert hour:min or h:m:s to string object (e.g., "10:30", "10:30:45")
@@ -205,6 +209,7 @@ def transform_data(df1, df2, ...):
205209
"display_instruction": "Rank students by average scores",
206210
"mode": "infer",
207211
"recommendation": "To rank students based on their average scores, we need to calculate the average score for each student, then sort the data, and finally assign a rank to each student based on their average score.",
212+
"input_tables": ["student_exam"],
208213
"output_fields": ["student", "major", "average_score", "rank"],
209214
"chart_type": "bar",
210215
"chart_encodings": {"x": "student", "y": "average_score"},
@@ -260,15 +265,41 @@ def process_gpt_response(self, input_tables, messages, response):
260265
if len(json_blocks) > 0:
261266
refined_goal = json_blocks[0]
262267
else:
263-
refined_goal = { 'mode': "", 'recommendation': "", 'output_fields': [], 'chart_encodings': {}, 'chart_type': "" }
268+
refined_goal = { 'mode': "", 'recommendation': "", 'input_tables': [], 'output_fields': [], 'chart_encodings': {}, 'chart_type': "" }
264269

265270
code_blocks = extract_code_from_gpt_response(choice.message.content + "\n", "python")
266271

267272
if len(code_blocks) > 0:
268273
code_str = code_blocks[-1]
269274

270275
try:
271-
result = py_sandbox.run_transform_in_sandbox2020(code_str, [pd.DataFrame.from_records(t['rows']) for t in input_tables], self.exec_python_in_subprocess)
276+
# Check if input_tables is available
277+
if not input_tables:
278+
result = {'status': 'error', 'code': code_str, 'content': "No input tables available."}
279+
else:
280+
# Determine which tables to use based on refined_goal
281+
if 'input_tables' in refined_goal and isinstance(refined_goal['input_tables'], list) and len(refined_goal['input_tables']) > 0:
282+
# Use only specified tables - validate all exist
283+
table_name_map = {t['name']: t for t in input_tables}
284+
tables_to_use = []
285+
missing_tables = []
286+
287+
for table_name in refined_goal['input_tables']:
288+
if table_name in table_name_map:
289+
tables_to_use.append(table_name_map[table_name])
290+
else:
291+
missing_tables.append(table_name)
292+
293+
# Error if any specified table is missing
294+
if missing_tables:
295+
available_table_names = [t['name'] for t in input_tables]
296+
result = {'status': 'error', 'code': code_str, 'content': f"Table(s) '{', '.join(missing_tables)}' specified in 'input_tables' not found. Available tables: {', '.join(available_table_names)}"}
297+
else:
298+
result = py_sandbox.run_transform_in_sandbox2020(code_str, [pd.DataFrame.from_records(t['rows']) for t in tables_to_use], self.exec_python_in_subprocess)
299+
else:
300+
# No input_tables specified in refined_goal, use all input_tables
301+
result = py_sandbox.run_transform_in_sandbox2020(code_str, [pd.DataFrame.from_records(t['rows']) for t in input_tables], self.exec_python_in_subprocess)
302+
272303
result['code'] = code_str
273304

274305
if result['status'] == 'ok':

py-src/data_formulator/agents/agent_py_data_transform.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
- if you mention column names from the input or the output data, highlight the text in **bold**.
3535
* the column can either be a column in the input data, or a new column that will be computed in the output data.
3636
* the mention doesn't have to be an exact match; it can be semantically matching, e.g., if you mentioned "average score" in the text while the column to be computed is "Avg_Score", you should still highlight "**average score**" in the text.
37+
- determine "input_tables", the names of a subset of input tables from [CONTEXT] section that will be used to achieve the user's goal.
3738
- determine "output_fields", the desired fields that the output data should have to achieve the user's goal, it's a good idea to include intermediate fields here.
3839
- then decide "chart_encodings", which maps visualization channels (x, y, color, size, opacity, facet, etc.) to a subset of "output_fields" that will be visualized,
3940
- the "chart_encodings" should be created to support the user's "chart_type".
@@ -48,7 +49,7 @@
4849
- e.g., they may mention "use B metric instead" while A metric is in provided fields, in this case, you should update "chart_encodings" to update A metric with B metric.
4950
- guide on statistical analysis:
5051
- when the user asks for forecasting or regression analysis, you should consider the following:
51-
- the output should be a long format table where actual x, y pairs and predicted x, y pairs are included in the X, Y columns, they are differentiated with a third column "is_predicted" that is a boolean field.
52+
- the output should be a long format table where actual x, y pairs and predicted x, y pairs are included in the X, Y columns, they are differentiated with a third column "is_predicted".
5253
- i.e., if the user ask for forecasting based on two columns T and Y, the output should be three columns: T, Y, is_predicted, where
5354
- T, Y columns contain BOTH original values from the data and predicted values from the data.
5455
- is_predicted is a boolean field to indicate whether the x, y pairs are original values from the data or predicted / regression values from the data.
@@ -65,6 +66,7 @@
6566
{
6667
"detailed_instruction": "..." // string, elaborate user instruction with details if the user
6768
"display_instruction": "..." // string, the short verb phrase describing the user's goal.
69+
"input_tables": [...] // string[], describe names of the input tables that will be used in the transformation.
6870
"output_fields": [...] // string[], describe the desired output fields that the output data should have based on the user's goal, it's a good idea to preserve intermediate fields here (i.e., the goal of transformed data)
6971
"chart_encodings": {
7072
"x": "",
@@ -79,8 +81,8 @@
7981
}
8082
```
8183
82-
2. Then, write a python function based on the refined goal, the function input is a dataframe "df" (or multiple dataframes based on tables presented in the [CONTEXT] section) and the output is the transformed dataframe "transformed_df". "transformed_df" should contain all "output_fields" from the refined goal.
83-
The python function must follow the template provided in [TEMPLATE], do not import any other libraries or modify function name. The function should be as simple as possible and easily readable.
84+
2. Then, write a python function based on the refined goal, the function input is a dataframe "df" (or multiple dataframes based on tables described in "input_tables") and the output is the transformed dataframe "transformed_df". "transformed_df" should contain all "output_fields" from the refined goal.
85+
The python function must follow the template provided in [TEMPLATE], only import libraries allowed in the template, do not modify function name. The function should be as simple as possible and easily readable.
8486
If there is no data transformation needed based on "output_fields", the transformation function can simply "return df".
8587
8688
[TEMPLATE]
@@ -97,9 +99,11 @@ def transform_data(df1, df2, ...):
9799
```
98100
99101
note:
100-
- if the user provided one table, then it should be `def transform_data(df1)`, if the user provided multiple tables, then it should be `def transform_data(df1, df2, ...)` and you should consider the join between tables to derive the output.
101-
- **VERY IMPORTANT** the number of arguments in the function must match the number of tables provided, and the order of arguments must match the order of tables provided.
102-
- try to use intuitive table names to refer to the input dataframes, for example, if the user provided two tables city and weather, you can use `transform_data(df_city, df_weather)` to refer to the two dataframes, as long as the number and order of the arguments match the number and order of the tables provided.
102+
- decide the function signature based on the number of tables you decided in the previous step "input_tables":
103+
- if you decide there will only be one input table, then function signature should be `def transform_data(df1)`
104+
- if you decide there will be k input tables, then function signature should be `def transform_data(df_1, df_2, ..., df_k)`.
105+
- instead of using generic names like df1, df2, ..., try to use intuitive table names for function arguments, for example, if you have `input_tables: ["City", "Weather"]`, you can use `transform_data(df_city, df_weather)` to refer to the two dataframes.
106+
- **VERY IMPORTANT** the number of arguments in the function signature must be the same as the number of tables provided in "input_tables", and the order of arguments must match the order of tables provided in "input_tables".
103107
- datetime objects handling:
104108
- if the output field is year, convert it to number, if it is year-month / year-month-day, convert it to string object (e.g., "2020-01" / "2020-01-01").
105109
- if the output is time only: convert hour to number if it's just the hour (e.g., 10), but convert hour:min or h:m:s to string object (e.g., "10:30", "10:30:45")
@@ -202,6 +206,7 @@ def transform_data(df):
202206
203207
{
204208
"detailed_instruction": "Create a scatter plot to compare Seattle and Atlanta temperatures with Seattle temperatures on the x-axis and Atlanta temperatures on the y-axis. Color the points by which city is warmer.",
209+
"input_tables": ["weather_seattle_atlanta"],
205210
"output_fields": ["Date", "Seattle Temperature", "Atlanta Temperature", "Warmer City"],
206211
"chart_encodings": {"x": "Seattle Temperature", "y": "Atlanta Temperature", "color": "Warmer City"},
207212
"reason": "To compare Seattle and Atlanta temperatures with Seattle temperatures on the x-axis and Atlanta temperatures on the y-axis, and color points by which city is warmer, separate temperature fields for Seattle and Atlanta are required. Additionally, a new field 'Warmer City' is needed to indicate which city is warmer."
@@ -212,7 +217,7 @@ def transform_data(df):
212217
import collections
213218
import numpy as np
214219
215-
def transform_data(df):
220+
def transform_data(df_weather_seattle_atlanta):
216221
# Pivot the dataframe to have separate columns for Seattle and Atlanta temperatures
217222
df_pivot = df.pivot(index='Date', columns='City', values='Temperature').reset_index()
218223
df_pivot.columns = ['Date', 'Atlanta Temperature', 'Seattle Temperature']
@@ -260,15 +265,41 @@ def process_gpt_response(self, input_tables, messages, response):
260265
if len(json_blocks) > 0:
261266
refined_goal = json_blocks[0]
262267
else:
263-
refined_goal = {'chart_encodings': {}, 'instruction': '', 'reason': ''}
268+
refined_goal = {'chart_encodings': {}, 'instruction': '', 'reason': '', 'input_tables': []}
264269

265270
code_blocks = extract_code_from_gpt_response(choice.message.content + "\n", "python")
266271

267272
if len(code_blocks) > 0:
268273
code_str = code_blocks[-1]
269274

270275
try:
271-
result = py_sandbox.run_transform_in_sandbox2020(code_str, [pd.DataFrame.from_records(t['rows']) for t in input_tables], self.exec_python_in_subprocess)
276+
# Check if input_tables is available
277+
if not input_tables:
278+
result = {'status': 'error', 'code': code_str, 'content': "No input tables available."}
279+
else:
280+
# Determine which tables to use based on refined_goal
281+
if 'input_tables' in refined_goal and isinstance(refined_goal['input_tables'], list) and len(refined_goal['input_tables']) > 0:
282+
# Use only specified tables - validate all exist
283+
table_name_map = {t['name']: t for t in input_tables}
284+
tables_to_use = []
285+
missing_tables = []
286+
287+
for table_name in refined_goal['input_tables']:
288+
if table_name in table_name_map:
289+
tables_to_use.append(table_name_map[table_name])
290+
else:
291+
missing_tables.append(table_name)
292+
293+
# Error if any specified table is missing
294+
if missing_tables:
295+
available_table_names = [t['name'] for t in input_tables]
296+
result = {'status': 'error', 'code': code_str, 'content': f"Table(s) '{', '.join(missing_tables)}' specified in 'input_tables' not found. Available tables: {', '.join(available_table_names)}"}
297+
else:
298+
result = py_sandbox.run_transform_in_sandbox2020(code_str, [pd.DataFrame.from_records(t['rows']) for t in tables_to_use], self.exec_python_in_subprocess)
299+
else:
300+
# No input_tables specified in refined_goal, use all input_tables
301+
result = py_sandbox.run_transform_in_sandbox2020(code_str, [pd.DataFrame.from_records(t['rows']) for t in input_tables], self.exec_python_in_subprocess)
302+
272303
result['code'] = code_str
273304

274305
if result['status'] == 'ok':

0 commit comments

Comments
 (0)