Commit 3739f28

Feature/dataset notebook (#11)
* Add Dataset notebook
* Added configurable dataset_variable_position to dataset plots
* Add to datasets.ipynb
* Fixed the review issues:
  1. Notebook comment - Clarified that to_array() returns (variable, ...) but xpx() reorders it
  2. Negative indexing - Now supports Python-style: -1=last, -2=second-to-last, etc.
  3. Import - Moved _options to module level
1 parent 473c12c commit 3739f28

4 files changed

Lines changed: 294 additions & 2 deletions

File tree

docs/examples/datasets.ipynb

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Dataset Plotting\n",
+    "\n",
+    "Plot multiple variables from an xarray Dataset with automatic or custom slot assignment."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import xarray as xr\n",
+    "\n",
+    "from xarray_plotly import config, xpx\n",
+    "\n",
+    "config.notebook()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create a Dataset with multiple variables\n",
+    "time = np.arange(50)\n",
+    "cities = [\"NYC\", \"LA\", \"Chicago\"]\n",
+    "\n",
+    "ds = xr.Dataset(\n",
+    "    {\n",
+    "        \"temperature\": ([\"time\", \"city\"], 20 + 5 * np.random.randn(50, 3).cumsum(axis=0) / 10),\n",
+    "        \"humidity\": ([\"time\", \"city\"], 50 + 10 * np.random.randn(50, 3).cumsum(axis=0) / 10),\n",
+    "        \"pressure\": ([\"time\", \"city\"], 1013 + np.random.randn(50, 3).cumsum(axis=0)),\n",
+    "    },\n",
+    "    coords={\"time\": time, \"city\": cities},\n",
+    ")\n",
+    "ds"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Plot All Variables\n",
+    "\n",
+    "When you call a plot method on a Dataset without specifying `var`, all variables are combined into a single DataArray with a new `\"variable\"` dimension:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# All variables: time -> x, variable -> color, city -> line_dash\n",
+    "xpx(ds).line()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Control Where \"variable\" Goes\n",
+    "\n",
+    "The `\"variable\"` dimension can be assigned to any slot:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Variables as facet columns\n",
+    "xpx(ds).line(facet_col=\"variable\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Variables as rows, cities as columns\n",
+    "xpx(ds).line(facet_row=\"variable\", facet_col=\"city\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Configure Default \"variable\" Position\n",
+    "\n",
+    "By default, `\"variable\"` is placed as the **second** dimension so it maps to `color`. This keeps your first dimension (e.g., time) on the x-axis.\n",
+    "\n",
+    "You can change this globally with `config.set_options()`:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Default: position=1 (second) -> variable goes to color\n",
+    "# Note: to_array() puts \"variable\" first, but xpx() reorders it to position 1\n",
+    "print(\"Raw to_array() dims:\", ds.to_array().dims)  # (variable, time, city)\n",
+    "print(\"After xpx reorder: (time, variable, city)\")  # time->x, variable->color\n",
+    "xpx(ds).line(title=\"Default: variable as color (position=1)\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Position 0: variable goes first (x-axis) - usually not what you want!\n",
+    "with config.set_options(dataset_variable_position=0):\n",
+    "    fig = xpx(ds).line(title=\"position=0: variable on x-axis (probably not desired)\")\n",
+    "fig"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Position -1: variable goes last -> city gets color, variable gets line_dash\n",
+    "with config.set_options(dataset_variable_position=-1):\n",
+    "    fig = xpx(ds).line(title=\"position=-1: variable as line_dash\")\n",
+    "fig"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Plot a Single Variable\n",
+    "\n",
+    "Use `var=\"name\"` to plot just one variable:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "xpx(ds).line(var=\"temperature\", title=\"Temperature Only\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Different Plot Types"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Bar chart - latest values by city\n",
+    "xpx(ds.isel(time=-1)).bar(x=\"city\", color=\"variable\", barmode=\"group\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Box plot - distribution by variable\n",
+    "xpx(ds).box(x=\"variable\", color=\"city\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Area chart\n",
+    "xpx(ds).area(var=\"humidity\", title=\"Humidity Over Time\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Scatter\n",
+    "xpx(ds).scatter(var=\"temperature\", title=\"Temperature Scatter\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Pie chart - snapshot at one time\n",
+    "xpx(ds.isel(time=-1)).pie(var=\"temperature\", names=\"city\", title=\"Temperature Distribution\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Combining Slot Assignments\n",
+    "\n",
+    "Mix explicit assignments with auto-assignment:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Explicit: variable -> facet_col, city -> color; time auto-assigns to x\n",
+    "xpx(ds).line(facet_col=\"variable\", color=\"city\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Skip color slot with None\n",
+    "xpx(ds).line(var=\"temperature\", color=None)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.12.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
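
Review note: the notebook comments describe dimensions filling plot slots in order (time -> x, variable -> color, city -> line_dash), with explicit keyword assignments and `None` changing that mapping. The sketch below illustrates one way such a mapping could behave. It is not xarray_plotly's implementation: `assign_slots` and `LINE_SLOTS` are hypothetical names, and the slot order beyond `x`, `color`, `line_dash` is an assumption.

```python
def assign_slots(dims, slot_order, **explicit):
    """Map dimension names to plot slots (illustrative only).

    Assumed rules: explicit keyword assignments win, a slot set to None is
    skipped, and the remaining dims fill the remaining slots in order.
    """
    assignment = {}
    skipped = set()
    for slot, dim in explicit.items():
        if dim is None:
            skipped.add(slot)  # caller asked to leave this slot empty
        else:
            assignment[slot] = dim

    used_dims = set(assignment.values())
    free_dims = [d for d in dims if d not in used_dims]
    free_slots = [s for s in slot_order if s not in assignment and s not in skipped]

    # Pair leftover dims with leftover slots, first come first served.
    for slot, dim in zip(free_slots, free_dims):
        assignment[slot] = dim
    return assignment


# Hypothetical slot order for a line plot; only the first three are taken
# from the notebook comments.
LINE_SLOTS = ("x", "color", "line_dash", "facet_col", "facet_row")

all_vars = ("time", "variable", "city")
print(assign_slots(all_vars, LINE_SLOTS))
print(assign_slots(all_vars, LINE_SLOTS, facet_col="variable"))
print(assign_slots(("time", "city"), LINE_SLOTS, color=None))
```

Under these assumptions the three calls correspond to the "Plot All Variables", "Variables as facet columns" and "Skip color slot with None" cells.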

mkdocs.yml

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ nav:
   - Getting Started: getting-started.ipynb
   - Examples:
     - Plot Types: examples/plot-types.ipynb
+    - Dataset Plotting: examples/datasets.ipynb
     - Dimensions & Facets: examples/dimensions.ipynb
     - Plotly Express Options: examples/kwargs.ipynb
     - Figure Customization: examples/figure.ipynb

xarray_plotly/accessor.py

Lines changed: 20 additions & 2 deletions
@@ -7,6 +7,7 @@
 
 from xarray_plotly import plotting
 from xarray_plotly.common import SlotValue, auto
+from xarray_plotly.config import _options
 
 
 class DataArrayPlotlyAccessor:
@@ -349,9 +350,26 @@ def __dir__(self) -> list[str]:
         return list(self.__all__) + list(super().__dir__())
 
     def _get_dataarray(self, var: str | None) -> DataArray:
-        """Get DataArray from Dataset, either single var or all via to_array()."""
+        """Get DataArray from Dataset, either single var or all via to_array().
+
+        When combining all variables, "variable" is placed at the position
+        specified by config.dataset_variable_position (default 1, second position).
+        Supports Python-style negative indexing: -1 = last, -2 = second-to-last, etc.
+        """
         if var is None:
-            return self._ds.to_array(dim="variable")
+            da = self._ds.to_array(dim="variable")
+            pos = _options.dataset_variable_position
+            # Move "variable" to configured position
+            if len(da.dims) > 1 and pos != 0:
+                dims = list(da.dims)
+                dims.remove("variable")
+                # Use Python-style indexing (handles negative indices correctly)
+                # Clamp to valid range: -1 -> last, -2 -> second-to-last, etc.
+                n = len(dims)
+                insert_pos = max(0, n + pos + 1) if pos < 0 else min(pos, n)
+                dims.insert(insert_pos, "variable")
+                da = da.transpose(*dims)
+            return da
         return self._ds[var]
 
     def line(
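
Review note: the clamping expression in `_get_dataarray` is easier to check in isolation. The sketch below copies only the `insert_pos` arithmetic from the hunk above and applies it to plain tuples, so it runs without xarray. `insert_position` and `place_variable` are hypothetical helper names introduced for illustration; they are not part of the package.

```python
def insert_position(pos: int, n: int) -> int:
    """Clamp a Python-style index to an insertion point in a list of length n.

    Same arithmetic as the diff: negative values count from the end
    (-1 -> n, i.e. append), and out-of-range values are clamped.
    """
    return max(0, n + pos + 1) if pos < 0 else min(pos, n)


def place_variable(dims: tuple, pos: int) -> tuple:
    """Return dims with "variable" moved to the requested position."""
    rest = [d for d in dims if d != "variable"]
    rest.insert(insert_position(pos, len(rest)), "variable")
    return tuple(rest)


raw = ("variable", "time", "city")  # order produced by to_array()
for pos in (0, 1, -1, -2, 5, -10):
    print(pos, place_variable(raw, pos))
```

With two remaining dimensions, `pos=1` and `pos=-2` land on the same slot, and any out-of-range value is pulled back to the first or last position rather than raising.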

xarray_plotly/config.py

Lines changed: 12 additions & 0 deletions
@@ -58,6 +58,9 @@ class Options:
         label_include_units: Append units to labels. Default True.
         label_unit_format: Format string for units. Use `{units}` as placeholder.
         slot_orders: Slot orders per plot type. Keys are plot types, values are tuples.
+        dataset_variable_position: Position of "variable" dim when plotting all Dataset
+            variables. Default 1 (second position, typically color). Set to 0 for first
+            position (x-axis), or -1 for last position.
     """
 
     label_use_long_name: bool = True
@@ -67,6 +70,7 @@ class Options:
     slot_orders: dict[str, tuple[str, ...]] = field(
         default_factory=lambda: dict(DEFAULT_SLOT_ORDERS)
     )
+    dataset_variable_position: int = 1
 
     def to_dict(self) -> dict[str, Any]:
         """Return options as a dictionary."""
@@ -76,6 +80,7 @@ def to_dict(self) -> dict[str, Any]:
             "label_include_units": self.label_include_units,
             "label_unit_format": self.label_unit_format,
             "slot_orders": self.slot_orders,
+            "dataset_variable_position": self.dataset_variable_position,
         }
 
 
@@ -106,6 +111,7 @@ def set_options(
     label_include_units: bool | None = None,
     label_unit_format: str | None = None,
     slot_orders: dict[str, tuple[str, ...]] | None = None,
+    dataset_variable_position: int | None = None,
 ) -> Generator[None, None, None]:
     """Set xarray_plotly options globally or as a context manager.
 
@@ -115,6 +121,8 @@ def set_options(
         label_include_units: Append units to labels.
         label_unit_format: Format string for units. Use `{units}` as placeholder.
         slot_orders: Slot orders per plot type.
+        dataset_variable_position: Position of "variable" dim when plotting all Dataset
+            variables. Default 1 (second, typically color). Use 0 for first, -1 for last.
 
     Yields:
         None when used as a context manager.
@@ -136,6 +144,7 @@ def set_options(
         "label_include_units": _options.label_include_units,
         "label_unit_format": _options.label_unit_format,
         "slot_orders": dict(_options.slot_orders),
+        "dataset_variable_position": _options.dataset_variable_position,
     }
 
     # Apply new values (modify in place to keep reference)
@@ -149,6 +158,8 @@ def set_options(
         _options.label_unit_format = label_unit_format
     if slot_orders is not None:
         _options.slot_orders = dict(slot_orders)
+    if dataset_variable_position is not None:
+        _options.dataset_variable_position = dataset_variable_position
 
     try:
         yield
@@ -159,6 +170,7 @@ def set_options(
         _options.label_include_units = old_values["label_include_units"]
         _options.label_unit_format = old_values["label_unit_format"]
         _options.slot_orders = old_values["slot_orders"]
+        _options.dataset_variable_position = old_values["dataset_variable_position"]
 
 
 def notebook(renderer: str = "notebook") -> None:
