From 003c541bb0b144861d8a7d9a5664b212056d48d7 Mon Sep 17 00:00:00 2001 From: jbloom Date: Wed, 6 May 2026 16:38:34 -0700 Subject: [PATCH 1/4] resolve color scale independent so user chart colors survive tree coloring When color_tree_by was set alongside a user chart that had its own color encoding (e.g. a titer plot colored by cell_line), Vega-Lite's default of sharing color scales across concat views merged the tree's color_value:N scale into the user chart's color scale. The user-chart domain was lost and its marks rendered without color. _concat_for_location now passes color="independent" to resolve_scale alongside the existing strain-axis resolution, so the tree's invented color_value:N scale stays isolated from whatever the user has on their chart. Co-Authored-By: Claude Opus 4.7 --- CHANGELOG.md | 11 +++++++++++ src/tree_annotated_plot/_plot.py | 20 ++++++++++++++----- tests/test_color_tree.py | 34 ++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 55b3848..f262349 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,17 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Fixed + +- When `color_tree_by` is used together with a user chart that has its + own `color` encoding (e.g. a titer plot colored by `cell_line`), the + user chart's marks no longer disappear. The concat container now + resolves the `color` scale as `independent` so the tree's + `color_value:N` scale (with its tree-specific domain) is not merged + with the user chart's color scale. + ## [0.2.1] - 2026-05-06 ### Fixed diff --git a/src/tree_annotated_plot/_plot.py b/src/tree_annotated_plot/_plot.py index ae4ad13..fd5a86f 100644 --- a/src/tree_annotated_plot/_plot.py +++ b/src/tree_annotated_plot/_plot.py @@ -276,22 +276,32 @@ def _concat_for_location( user_chart: alt.TopLevelMixin, location: TreeLocation, ) -> alt.HConcatChart | alt.VConcatChart: - """Concat tree and chart in the order implied by the tree's location.""" + """Concat tree and chart in the order implied by the tree's location. + + The strain axis is resolved independent so the tree and chart can use + different scales on that axis (the tree's branch length vs. the chart's + measurement value), while still sharing the orthogonal strain axis. The + `color` scale is also resolved independent: when ``color_tree_by`` is set + the tree panel emits a `color_value:N` color scale with a tree-specific + domain, and Vega-Lite's default of sharing color across concat views + would merge it with any color encoding on the user's chart, hiding + user-chart marks whose color values aren't in the tree's domain. + """ if location == "left": return alt.hconcat(tree_chart, user_chart, spacing=0).resolve_scale( - y="independent" + y="independent", color="independent" ) if location == "right": return alt.hconcat(user_chart, tree_chart, spacing=0).resolve_scale( - y="independent" + y="independent", color="independent" ) if location == "top": return alt.vconcat(tree_chart, user_chart, spacing=0).resolve_scale( - x="independent" + x="independent", color="independent" ) if location == "bottom": return alt.vconcat(user_chart, tree_chart, spacing=0).resolve_scale( - x="independent" + x="independent", color="independent" ) raise ValueError(f"unreachable: tree_location={location!r}") diff --git a/tests/test_color_tree.py b/tests/test_color_tree.py index 0f6e79d..f62fe74 100644 --- a/tests/test_color_tree.py +++ b/tests/test_color_tree.py @@ -663,6 +663,40 @@ def test_color_tree_by_legend_hides_internal_only_unknown(): assert "unknown" in enc["scale"]["domain"] +def test_color_tree_by_color_scale_independent_from_user_chart(): + """When the tree is colored, the concat container must resolve the color + scale as `independent` so the tree's `color_value:N` scale doesn't merge + with a user-chart color encoding (which would hide user marks whose + color values aren't in the tree's domain).""" + df = pd.DataFrame( + { + "strain": ["A", "B", "C", "D"] * 2, + "titer": [1.0, 2.0, 4.0, 8.0, 1.5, 2.5, 4.5, 8.5], + "cell_line": ["X", "X", "X", "X", "Y", "Y", "Y", "Y"], + } + ) + user_chart = ( + alt.Chart(df) + .mark_line(point=True) + .encode( + x="titer:Q", + y=alt.Y("strain:N"), + color=alt.Color("cell_line:N"), + ) + .properties(width=200, height=200) + ) + out = tree_annotated_plot.plot( + _attr_auspice(), + user_chart, + **_kw(), + color_tree_by="subclade", + ) + spec = out.to_dict() + resolve = spec.get("resolve") or {} + scale = resolve.get("scale") or {} + assert scale.get("color") == "independent" + + # ----------------------------------------------------------------------------- # CLI # ----------------------------------------------------------------------------- From dc54d009c66d15be2469a79d4e8a716cfb37003f Mon Sep 17 00:00:00 2001 From: jbloom Date: Thu, 7 May 2026 09:05:31 -0700 Subject: [PATCH 2/4] fix bug in connecting lines in titer chart `mark_line`, `mark_trail`, and `mark_area` now connect points in tree tip order regardless of other encodings on the chart. Previously, a ser chart with an explicit categorical color-scale `domain` rendered with crisscrossing line segments because Vega-Lite's default connection-order heuristic ignored the strain-axis sort. The package now attaches a `calculate` transform that derives a per-row tip rank and points the `order` channel at that field on these marks. Any user-supplied `order` is left in place. --- CHANGELOG.md | 8 + src/tree_annotated_plot/_plot.py | 81 ++++++++++ tests/test_line_order.py | 258 +++++++++++++++++++++++++++++++ 3 files changed, 347 insertions(+) create mode 100644 tests/test_line_order.py diff --git a/CHANGELOG.md b/CHANGELOG.md index f262349..59dc4c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- `mark_line`, `mark_trail`, and `mark_area` now connect points in tree + tip order regardless of other encodings on the chart. Previously, a + user chart with an explicit categorical color-scale `domain` rendered + with crisscrossing line segments because Vega-Lite's default + connection-order heuristic ignored the strain-axis sort. The package + now attaches a `calculate` transform that derives a per-row tip rank + and points the `order` channel at that field on these marks. Any + user-supplied `order` is left in place. - When `color_tree_by` is used together with a user chart that has its own `color` encoding (e.g. a titer plot colored by `cell_line`), the user chart's marks no longer disappear. The concat container now diff --git a/src/tree_annotated_plot/_plot.py b/src/tree_annotated_plot/_plot.py index fd5a86f..9ad3fb6 100644 --- a/src/tree_annotated_plot/_plot.py +++ b/src/tree_annotated_plot/_plot.py @@ -233,6 +233,28 @@ def _build( ch.axis = alt.Axis(labels=False, ticks=False, domain=False, title=None) n_hits += 1 _check_walker_hits("strain-axis update", n_hits, len(axis_hits), axis) + + # Pin line/trail/area connection order to tip order via Vega-Lite's + # `order` channel. Without this, an explicit categorical `color` scale + # `domain` (or other encoding choices) can shift Vega-Lite's + # connection-order heuristic away from the strain-axis sort, causing + # lines to crisscross. The order channel's `sort` only accepts + # ascending/descending, so we attach a `calculate` transform that + # computes a per-row tip rank and point `order` at that derived + # quantitative field. User-supplied `order` always wins. + rank_expr = ( + f"indexof({json.dumps(list(tip_names))}, " + f"datum[{json.dumps(config.chart_strain_field)}])" + ) + for node in _iter_connection_order_nodes(new_chart, config.chart_strain_field): + enc = _live_attr(node, "encoding") + if _live_attr(enc, "order") is not None: + continue + existing = _live_attr(node, "transform") or [] + node.transform = list(existing) + [ + {"calculate": rank_expr, "as": _TIP_ORDER_RANK_FIELD} + ] + enc.order = alt.Order(f"{_TIP_ORDER_RANK_FIELD}:Q") hoisted_config, hoisted_other = _pop_toplevel_only_attrs(new_chart) combined = _concat_for_location( @@ -1009,6 +1031,65 @@ def _live_attr(obj: Any, name: str) -> Any: return v +# Vega-Lite marks whose drawing order along the discrete axis is determined +# by the order channel (rule 1 of the connection-order fallback chain). For +# these marks pinning the order channel to a per-row tip-order rank ensures +# the line / trail / area connects strains in tip order regardless of how +# Vega-Lite would otherwise resolve the fallback (e.g. when an explicit +# color-scale `domain` shifts it away from the y-axis sort). +_CONNECTION_ORDER_MARKS = frozenset({"line", "trail", "area"}) + +# Derived field name appended to each connection-order mark's transform +# pipeline. Chosen to be unlikely to collide with user data; private prefix +# makes the intent clear if anyone inspects the rendered spec. +_TIP_ORDER_RANK_FIELD = "_tap_strain_order_idx" + + +def _mark_type(node: Any) -> str | None: + """Return the mark type string for a chart node, or None. + + Handles altair's two mark forms: the plain string (e.g. ``"line"`` from + ``chart.mark_line()``) and the ``MarkDef`` object (e.g. from + ``chart.mark_line(point=True)``). Container nodes (LayerChart, + HConcatChart, etc.) have no mark and return None. + """ + mark = _live_attr(node, "mark") + if mark is None: + return None + if isinstance(mark, str): + return mark + type_ = _live_attr(mark, "type") + return type_ if isinstance(type_, str) else None + + +def _iter_connection_order_nodes(node: Any, chart_strain_field: str) -> Iterator[Any]: + """Yield the live chart node for every node whose mark is in + ``_CONNECTION_ORDER_MARKS`` and whose ``x``/``y`` strain encoding + matches ``chart_strain_field``. + + Mirrors `_iter_strain_axis_channels`'s traversal (hconcat / vconcat / + concat / layer / spec descent). Yields the node itself (not just the + encoding) so the caller can attach both a ``calculate`` transform and + the ``order`` channel. + """ + if _mark_type(node) in _CONNECTION_ORDER_MARKS: + enc = _live_attr(node, "encoding") + if enc is not None: + for channel in ("x", "y"): + ch = _live_attr(enc, channel) + if ch is not None and _channel_field(ch) == chart_strain_field: + yield node + break + for attr in ("hconcat", "vconcat", "concat", "layer"): + sub = _live_attr(node, attr) + if isinstance(sub, list): + for s in sub: + yield from _iter_connection_order_nodes(s, chart_strain_field) + spec = _live_attr(node, "spec") + if spec is not None: + yield from _iter_connection_order_nodes(spec, chart_strain_field) + + def _pop_toplevel_only_attrs( chart: alt.TopLevelMixin, ) -> tuple[Any, dict]: diff --git a/tests/test_line_order.py b/tests/test_line_order.py new file mode 100644 index 0000000..ee4cde4 --- /dev/null +++ b/tests/test_line_order.py @@ -0,0 +1,258 @@ +"""Connection order on `mark_line` / `mark_trail` / `mark_area` follows tip order. + +Regression: when the user chart had `mark_line` with strain on a discrete +axis, lines connected points in tip order most of the time — but adding an +explicit categorical color-scale `domain` shifted Vega-Lite's default +connection-order heuristic away from the strain-axis `sort`, causing lines +to crisscross. The package now sets the `order` channel to a per-row +tip-rank field on these marks (unless the user has set their own `order`), +pinning rule 1 of the Vega-Lite fallback chain. + +The tip-rank is computed via a `calculate` transform that runs `indexof` +against the tip-names list — Vega-Lite's `order` channel `sort` only +accepts ascending/descending, so a derived quantitative field is the +direct way to express custom ordering. +""" + +from __future__ import annotations + +from typing import Any + +import altair as alt +import pandas as pd +import pytest + +import tree_annotated_plot +from tree_annotated_plot._plot import _TIP_ORDER_RANK_FIELD + + +def _synthetic_auspice() -> dict: + """4-tip Auspice tree; tip layout order is [A, B, C, D].""" + return { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0}, + "children": [ + { + "name": "int1", + "node_attrs": {"div": 0.01}, + "children": [ + {"name": "A", "node_attrs": {"div": 0.04}}, + {"name": "B", "node_attrs": {"div": 0.05}}, + ], + }, + { + "name": "int2", + "node_attrs": {"div": 0.02}, + "children": [ + {"name": "C", "node_attrs": {"div": 0.03}}, + {"name": "D", "node_attrs": {"div": 0.06}}, + ], + }, + ], + }, + } + + +TIP_ORDER = ["A", "B", "C", "D"] + + +def _df() -> pd.DataFrame: + return pd.DataFrame( + [ + {"strain": s, "serum": serum, "titer": v} + for s, serum, v in [ + ("A", "s1", 100), + ("B", "s1", 200), + ("C", "s1", 400), + ("D", "s1", 800), + ("A", "s2", 150), + ("B", "s2", 250), + ("C", "s2", 350), + ("D", "s2", 700), + ] + ] + ) + + +def _plot_user_panel(chart: alt.TopLevelMixin) -> dict: + """Run plot() with default settings and return the user-panel spec dict.""" + out = tree_annotated_plot.plot( + _synthetic_auspice(), + chart, + chart_strain_field="strain", + tree_strain_field="name", + branch_length="div", + ) + # Default tree_location for y-axis strain is "left" — user panel is hconcat[1]. + return out.to_dict()["hconcat"][1] + + +def _find_order_encodings(spec: Any) -> list[dict]: + """Walk a spec dict and return every `encoding.order` dict found.""" + found: list[dict] = [] + + def walk(node: Any) -> None: + if isinstance(node, dict): + enc = node.get("encoding") + if isinstance(enc, dict) and "order" in enc: + found.append(enc["order"]) + for k, v in node.items(): + if k != "encoding": + walk(v) + elif isinstance(node, list): + for item in node: + walk(item) + + walk(spec) + return found + + +def _find_tip_rank_transforms(spec: Any) -> list[dict]: + """Walk a spec dict and return every calculate transform that computes + the package's tip-rank derived field.""" + found: list[dict] = [] + + def walk(node: Any) -> None: + if isinstance(node, dict): + transforms = node.get("transform") + if isinstance(transforms, list): + for t in transforms: + if isinstance(t, dict) and t.get("as") == _TIP_ORDER_RANK_FIELD: + found.append(t) + for v in node.values(): + walk(v) + elif isinstance(node, list): + for item in node: + walk(item) + + walk(spec) + return found + + +def _assert_tip_rank_wired(spec: dict) -> None: + """Spec contains both an order-channel pointing at the tip-rank field + and a calculate transform that derives that field from the tip names.""" + orders = _find_order_encodings(spec) + assert len(orders) == 1, f"expected one order encoding, got {orders}" + assert orders[0]["field"] == _TIP_ORDER_RANK_FIELD + assert orders[0]["type"] == "quantitative" + transforms = _find_tip_rank_transforms(spec) + assert len(transforms) == 1, f"expected one tip-rank transform, got {transforms}" + expr = transforms[0]["calculate"] + assert "indexof" in expr + # The expression embeds the tip names list verbatim. + for tip in TIP_ORDER: + assert f'"{tip}"' in expr, f"tip {tip!r} not in calculate expression {expr!r}" + + +def test_mark_line_with_color_domain_gets_order_in_tip_order(): + """The original bug: explicit color-scale `domain` no longer crisscrosses lines.""" + chart = ( + alt.Chart(_df()) + .mark_line(point=True) + .encode( + x=alt.X("titer:Q", scale=alt.Scale(type="log")), + y=alt.Y("strain:N"), + color=alt.Color( + "serum:N", + scale=alt.Scale(domain=["s1", "s2"], range=["red", "blue"]), + ), + ) + .properties(width=300, height=200) + ) + _assert_tip_rank_wired(_plot_user_panel(chart)) + + +def test_user_supplied_order_is_preserved(): + """User-supplied `order` always wins; the package never overwrites it.""" + chart = ( + alt.Chart(_df()) + .mark_line() + .encode( + x=alt.X("titer:Q"), + y=alt.Y("strain:N"), + color="serum:N", + order=alt.Order("titer:Q", sort="ascending"), + ) + .properties(width=300, height=200) + ) + spec = _plot_user_panel(chart) + orders = _find_order_encodings(spec) + assert len(orders) == 1 + assert orders[0]["field"] == "titer" + # The package did not attach its tip-rank transform either. + assert _find_tip_rank_transforms(spec) == [] + + +@pytest.mark.parametrize("mark_method", ["mark_line", "mark_area", "mark_trail"]) +def test_connection_order_marks_get_order_in_tip_order(mark_method: str): + """All three connection-order marks (line / area / trail) get `order` injected.""" + base = alt.Chart(_df()).encode( + x=alt.X("titer:Q"), + y=alt.Y("strain:N"), + color="serum:N", + ) + chart = getattr(base, mark_method)().properties(width=300, height=200) + _assert_tip_rank_wired(_plot_user_panel(chart)) + + +def test_layered_chart_only_line_layer_gets_order(): + """Mixed-mark layers: line layer gets `order`, circle layer is untouched.""" + base = alt.Chart(_df()).encode( + x=alt.X("titer:Q"), + y=alt.Y("strain:N"), + color="serum:N", + ) + layered = (base.mark_line() + base.mark_circle()).properties(width=300, height=200) + _assert_tip_rank_wired(_plot_user_panel(layered)) + + +def test_no_connection_order_marks_means_no_order_injected(): + """A chart with only point/circle/etc. marks is a no-op for the walker.""" + chart = ( + alt.Chart(_df()) + .mark_circle() + .encode( + x=alt.X("titer:Q"), + y=alt.Y("strain:N"), + color="serum:N", + ) + .properties(width=300, height=200) + ) + spec = _plot_user_panel(chart) + assert _find_order_encodings(spec) == [] + assert _find_tip_rank_transforms(spec) == [] + + +def test_faceted_chart_gets_order_injection(): + """The walker descends into the inner spec of a faceted chart.""" + df_facet = pd.DataFrame( + [ + {"strain": s, "serum": serum, "panel": p, "titer": v} + for s, serum, p, v in [ + ("A", "s1", "P1", 100), + ("B", "s1", "P1", 200), + ("C", "s1", "P1", 400), + ("D", "s1", "P1", 800), + ("A", "s1", "P2", 150), + ("B", "s1", "P2", 250), + ("C", "s1", "P2", 350), + ("D", "s1", "P2", 700), + ] + ] + ) + base = ( + alt.Chart(df_facet) + .mark_line() + .encode( + x=alt.X("titer:Q"), + y=alt.Y("strain:N"), + color="serum:N", + ) + .properties(width=300, height=200) + ) + faceted = base.facet(facet="panel:N", columns=2) + _assert_tip_rank_wired(_plot_user_panel(faceted)) From a7a16bed9aadf6cc029fb0a8af5fafaccee33d3a Mon Sep 17 00:00:00 2001 From: jbloom Date: Thu, 7 May 2026 11:45:56 -0700 Subject: [PATCH 3/4] adjust placement of tree color legend in examples --- docs/examples.md | 17 +++++++++++++---- scripts/generate_docs_assets.py | 6 ++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/docs/examples.md b/docs/examples.md index f11f009..0df0d6d 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -105,8 +105,12 @@ edge. The H3N2 example above is rendered with `color_tree_by="subclade"`, which colors the tree's branches and tip circles by the -`node_attrs.subclade` value at each node and adds a categorical legend -below the plot. See "Color the tree" below for the full set of options. +`node_attrs.subclade` value at each node. The package's default places +the categorical legend at the bottom of the combined plot; here we pass +`tree_color_legend_format={"orient": "left"}` to push it to the left +side instead, so it doesn't compete with the cohort-selection legend +already sitting below the chart. See "Color the tree" below for the +full set of options. ### Optional: connect leaders all the way to the labels @@ -156,8 +160,11 @@ Colors match what you'd see on the Nextstrain view of the same tree — either from the JSON's palette information when the build provides it, or from the same default palette Auspice uses when it doesn't. Categories are ordered by descending frequency in both cases. Missing -values render in gray, and the legend is drawn at the bottom of the -combined plot. +values render in gray. By default the legend is drawn at the bottom of +the combined plot; the H3N2 examples in this section pass +`tree_color_legend_format={"orient": "left"}` to move it to the left +instead, since the bottom edge already carries the cohort-selection +legend. The example below colors the same H3N2 chart by genotype at HA1 site 158, which has two mutations in the tree (`N158K`, `N158D`) and @@ -277,6 +284,7 @@ tree-annotated-plot \ --scale-bar \ --branch-length-units substitutions \ --color-tree-by subclade \ + --tree-color-legend-format '{"orient":"left"}' \ --output examples/data/h3n2_combined.json ``` @@ -295,6 +303,7 @@ out = tree_annotated_plot.plot( scale_bar=True, branch_length_units="substitutions", color_tree_by="subclade", + tree_color_legend_format={"orient": "left"}, ) ``` diff --git a/scripts/generate_docs_assets.py b/scripts/generate_docs_assets.py index 1b9beec..89f18f9 100644 --- a/scripts/generate_docs_assets.py +++ b/scripts/generate_docs_assets.py @@ -122,7 +122,11 @@ def _render_kikawa() -> None: # Color H3N2 by subclade so the docs SVG matches what users see # on Nextstrain. The Auspice JSON's meta.colorings.subclade has # no `scale` defined, so colors come from the default palette. + # Place the subclade legend on the left so it doesn't compete + # with the cohort-selection legend that already sits at the + # bottom of the chart. plot_kwargs["color_tree_by"] = "subclade" + plot_kwargs["tree_color_legend_format"] = {"orient": "left"} out = tree_annotated_plot.plot( DATA_DIR / f"flu-seqneut-2025to2026_{subtype}.json", chart, @@ -155,6 +159,7 @@ def _render_kikawa() -> None: strain_label_font_size=9, shift_tree_loc=60, color_tree_by="subclade", + tree_color_legend_format={"orient": "left"}, ) _save_pair(out, "h3n2_combined_label_connect") @@ -180,6 +185,7 @@ def _render_kikawa() -> None: scale_bar=True, branch_length_units="substitutions", color_tree_by="genotype:HA1:158", + tree_color_legend_format={"orient": "left"}, ) _save_pair(out, "h3n2_combined_genotype_158") From fcbfd01190b5a90ec2b80a1b0183a37605e9e399 Mon Sep 17 00:00:00 2001 From: jbloom Date: Sat, 9 May 2026 07:06:55 -0700 Subject: [PATCH 4/4] increment version to 0.2.2 --- CHANGELOG.md | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 59dc4c7..36aeaeb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] +## [0.2.2] - 2026-05-09 ### Fixed diff --git a/pyproject.toml b/pyproject.toml index 1f8bac6..74e12ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "tree-annotated-plot" -version = "0.2.1" +version = "0.2.2" description = "Annotate the axis of an Altair / Vega-Lite plot with a phylogenetic tree." readme = "README.md" requires-python = ">=3.13"