Skip to content

Commit 36e6627

Browse files
committed
feat: pass location of errors to lsp
1 parent d1dc55c commit 36e6627

4 files changed

Lines changed: 80 additions & 42 deletions

File tree

sqlmesh/core/loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1005,7 +1005,7 @@ def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
10051005
package=package,
10061006
)
10071007
except Exception as e:
1008-
raise ConfigError(f"Failed to load macro file: {path}", e)
1008+
raise ConfigError(f"Failed to load macro file: {e}", path)
10091009

10101010
self._macros_max_mtime = macros_max_mtime
10111011

sqlmesh/core/model/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def make_python_env(
4141
macros: MacroRegistry,
4242
variables: t.Optional[t.Dict[str, t.Any]] = None,
4343
used_variables: t.Optional[t.Set[str]] = None,
44-
path: t.Optional[str | Path] = None,
44+
path: t.Optional[Path] = None,
4545
python_env: t.Optional[t.Dict[str, Executable]] = None,
4646
strict_resolution: bool = True,
4747
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,

sqlmesh/lsp/main.py

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
)
5757
from sqlmesh.lsp.rename import prepare_rename, rename_symbol, get_document_highlights
5858
from sqlmesh.lsp.uri import URI
59+
from sqlmesh.utils.errors import ConfigError
5960
from web.server.api.endpoints.lineage import column_lineage, model_lineage
6061
from web.server.api.endpoints.models import get_models
6162
from typing import Union
@@ -76,11 +77,14 @@ class ContextLoaded:
7677
lsp_context: LSPContext
7778

7879

80+
type ContextFailedError = str | ConfigError
81+
82+
7983
@dataclass
8084
class ContextFailed:
8185
"""State when context failed to load with an error message."""
8286

83-
error_message: str
87+
error_message: ContextFailedError
8488
context: t.Optional[Context] = None
8589

8690

@@ -137,7 +141,7 @@ def _custom_all_models(self, ls: LanguageServer, params: AllModelsRequest) -> Al
137141
except Exception:
138142
pass
139143
try:
140-
context = self._context_get_or_load(uri)
144+
context = self._context_get_or_load(ls, uri)
141145
return LSPContext.get_completions(context, uri, content)
142146
except Exception as e:
143147
from sqlmesh.lsp.completions import get_sql_completions
@@ -148,21 +152,21 @@ def _custom_render_model(
148152
self, ls: LanguageServer, params: RenderModelRequest
149153
) -> RenderModelResponse:
150154
uri = URI(params.textDocumentUri)
151-
context = self._context_get_or_load(uri)
155+
context = self._context_get_or_load(ls, uri)
152156
return RenderModelResponse(models=context.render_model(uri))
153157

154158
def _custom_all_models_for_render(
155159
self, ls: LanguageServer, params: AllModelsForRenderRequest
156160
) -> AllModelsForRenderResponse:
157-
context = self._context_get_or_load()
161+
context = self._context_get_or_load(ls)
158162
return AllModelsForRenderResponse(models=context.list_of_models_for_rendering())
159163

160164
def _custom_format_project(
161165
self, ls: LanguageServer, params: FormatProjectRequest
162166
) -> FormatProjectResponse:
163167
"""Format all models in the current project."""
164168
try:
165-
context = self._context_get_or_load()
169+
context = self._context_get_or_load(ls)
166170
context.context.format()
167171
return FormatProjectResponse()
168172
except Exception as e:
@@ -173,7 +177,7 @@ def _custom_api(
173177
self, ls: LanguageServer, request: ApiRequest
174178
) -> t.Union[ApiResponseGetModels, ApiResponseGetColumnLineage, ApiResponseGetLineage]:
175179
ls.log_trace(f"API request: {request}")
176-
context = self._context_get_or_load()
180+
context = self._context_get_or_load(ls)
177181

178182
parsed_url = urllib.parse.urlparse(request.url)
179183
path_parts = parsed_url.path.strip("/").split("/")
@@ -244,7 +248,7 @@ def _reload_context_and_publish_diagnostics(
244248
else:
245249
# If there's no context, try to create one from scratch
246250
try:
247-
self._ensure_context_for_document(uri)
251+
self._ensure_context_for_document(ls, uri)
248252
# If successful, context_state will be ContextLoaded
249253
if isinstance(self.context_state, ContextLoaded):
250254
ls.show_message(
@@ -335,7 +339,10 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
335339
for ext in ("py", "yml", "yaml"):
336340
config_path = folder_path / f"config.{ext}"
337341
if config_path.exists():
338-
if self._create_lsp_context([folder_path]):
342+
if self._create_lsp_context(
343+
ls,
344+
[folder_path],
345+
):
339346
loaded_sqlmesh_message(ls, folder_path)
340347
return # Exit after successfully loading any config
341348
except Exception as e:
@@ -346,7 +353,7 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
346353
@self.server.feature(types.TEXT_DOCUMENT_DID_OPEN)
347354
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
348355
uri = URI(params.text_document.uri)
349-
context = self._context_get_or_load(uri)
356+
context = self._context_get_or_load(ls, uri)
350357

351358
# Only publish diagnostics if client doesn't support pull diagnostics
352359
if not self.client_supports_pull_diagnostics:
@@ -368,7 +375,7 @@ def formatting(
368375
"""Format the document using SQLMesh `format_model_expressions`."""
369376
try:
370377
uri = URI(params.text_document.uri)
371-
context = self._context_get_or_load(uri)
378+
context = self._context_get_or_load(ls, uri)
372379
document = ls.workspace.get_text_document(params.text_document.uri)
373380
before = document.source
374381

@@ -412,7 +419,7 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov
412419
"""Provide hover information for an object."""
413420
try:
414421
uri = URI(params.text_document.uri)
415-
context = self._context_get_or_load(uri)
422+
context = self._context_get_or_load(ls, uri)
416423
document = ls.workspace.get_text_document(params.text_document.uri)
417424

418425
references = get_references(context, uri, params.position)
@@ -439,10 +446,10 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov
439446
def inlay_hint(
440447
ls: LanguageServer, params: types.InlayHintParams
441448
) -> t.List[types.InlayHint]:
442-
"""Implement type hints for sql columns as inlay hints"""
449+
"""Implement type hints for SQL columns as inlay hints"""
443450
try:
444451
uri = URI(params.text_document.uri)
445-
context = self._context_get_or_load(uri)
452+
context = self._context_get_or_load(ls, uri)
446453

447454
start_line = params.range.start.line
448455
end_line = params.range.end.line
@@ -459,7 +466,7 @@ def goto_definition(
459466
"""Jump to an object's definition."""
460467
try:
461468
uri = URI(params.text_document.uri)
462-
context = self._context_get_or_load(uri)
469+
context = self._context_get_or_load(ls, uri)
463470

464471
references = get_references(context, uri, params.position)
465472
location_links = []
@@ -513,7 +520,7 @@ def find_references(
513520
"""Find all references of a symbol (supporting CTEs, models for now)"""
514521
try:
515522
uri = URI(params.text_document.uri)
516-
context = self._context_get_or_load(uri)
523+
context = self._context_get_or_load(ls, uri)
517524

518525
all_references = get_all_references(context, uri, params.position)
519526

@@ -532,7 +539,7 @@ def prepare_rename_handler(
532539
"""Prepare for rename operation by checking if the symbol can be renamed."""
533540
try:
534541
uri = URI(params.text_document.uri)
535-
context = self._context_get_or_load(uri)
542+
context = self._context_get_or_load(ls, uri)
536543
result = prepare_rename(context, uri, params.position)
537544
return result
538545
except Exception as e:
@@ -546,7 +553,7 @@ def rename_handler(
546553
"""Perform rename operation on the symbol at the given position."""
547554
try:
548555
uri = URI(params.text_document.uri)
549-
context = self._context_get_or_load(uri)
556+
context = self._context_get_or_load(ls, uri)
550557
workspace_edit = rename_symbol(context, uri, params.position, params.new_name)
551558
return workspace_edit
552559
except Exception as e:
@@ -560,7 +567,7 @@ def document_highlight_handler(
560567
"""Highlight all occurrences of the symbol at the given position."""
561568
try:
562569
uri = URI(params.text_document.uri)
563-
context = self._context_get_or_load(uri)
570+
context = self._context_get_or_load(ls, uri)
564571
highlights = get_document_highlights(context, uri, params.position)
565572
return highlights
566573
except Exception as e:
@@ -574,7 +581,7 @@ def diagnostic(
574581
"""Handle diagnostic pull requests from the client."""
575582
try:
576583
uri = URI(params.text_document.uri)
577-
diagnostics, result_id = self._get_diagnostics_for_uri(uri)
584+
diagnostics, result_id = self._get_diagnostics_for_uri(ls, uri)
578585

579586
# Check if client provided a previous result ID
580587
if hasattr(params, "previous_result_id") and params.previous_result_id == result_id:
@@ -604,7 +611,7 @@ def workspace_diagnostic(
604611
) -> types.WorkspaceDiagnosticReport:
605612
"""Handle workspace-wide diagnostic pull requests from the client."""
606613
try:
607-
context = self._context_get_or_load()
614+
context = self._context_get_or_load(ls)
608615

609616
items: t.List[
610617
t.Union[
@@ -617,7 +624,7 @@ def workspace_diagnostic(
617624
for path, target in context.map.items():
618625
if isinstance(target, ModelTarget):
619626
uri = URI.from_path(path)
620-
diagnostics, result_id = self._get_diagnostics_for_uri(uri)
627+
diagnostics, result_id = self._get_diagnostics_for_uri(ls, uri)
621628

622629
# Check if we have a previous result ID for this file
623630
previous_result_id = None
@@ -662,7 +669,7 @@ def code_action(
662669
try:
663670
ls.log_trace(f"Codeactionrequest: {params}")
664671
uri = URI(params.text_document.uri)
665-
context = self._context_get_or_load(uri)
672+
context = self._context_get_or_load(ls, uri)
666673
code_actions = context.get_code_actions(uri, params)
667674
return code_actions
668675

@@ -680,7 +687,7 @@ def completion(
680687
"""Handle completion requests from the client."""
681688
try:
682689
uri = URI(params.text_document.uri)
683-
context = self._context_get_or_load(uri)
690+
context = self._context_get_or_load(ls, uri)
684691

685692
# Get the document content
686693
content = None
@@ -749,30 +756,35 @@ def completion(
749756
get_sql_completions(None, URI(params.text_document.uri))
750757
return None
751758

752-
def _get_diagnostics_for_uri(self, uri: URI) -> t.Tuple[t.List[types.Diagnostic], int]:
759+
def _get_diagnostics_for_uri(
760+
self, ls: LanguageServer, uri: URI
761+
) -> t.Tuple[t.List[types.Diagnostic], int]:
753762
"""Get diagnostics for a specific URI, returning (diagnostics, result_id).
754763
755764
Since we no longer track version numbers, we always return 0 as the result_id.
756765
This means pull diagnostics will always fetch fresh results.
757766
"""
758767
try:
759-
context = self._context_get_or_load(uri)
768+
context = self._context_get_or_load(ls, uri)
760769
diagnostics = context.lint_model(uri)
761770
return LSPContext.diagnostics_to_lsp_diagnostics(diagnostics), 0
762771
except Exception:
763772
return [], 0
764773

765-
def _context_get_or_load(self, document_uri: t.Optional[URI] = None) -> LSPContext:
774+
def _context_get_or_load(
775+
self, ls: LanguageServer, document_uri: t.Optional[URI] = None
776+
) -> LSPContext:
766777
if isinstance(self.context_state, ContextFailed):
767778
raise RuntimeError(self.context_state.error_message)
768779
if isinstance(self.context_state, NoContext):
769-
self._ensure_context_for_document(document_uri)
780+
self._ensure_context_for_document(ls, document_uri)
770781
if not isinstance(self.context_state, ContextLoaded):
771782
raise RuntimeError("Context is not loaded")
772783
return self.context_state.lsp_context
773784

774785
def _ensure_context_for_document(
775786
self,
787+
ls: LanguageServer,
776788
document_uri: t.Optional[URI] = None,
777789
) -> None:
778790
"""
@@ -784,12 +796,14 @@ def _ensure_context_for_document(
784796
if document_path.is_file() and document_path.suffix in (".sql", ".py"):
785797
document_folder = document_path.parent
786798
if document_folder.is_dir():
787-
self._ensure_context_in_folder(document_folder)
799+
self._ensure_context_in_folder(ls, document_folder)
788800
return
789801

790-
self._ensure_context_in_folder()
802+
self._ensure_context_in_folder(ls)
791803

792-
def _ensure_context_in_folder(self, folder_path: t.Optional[Path] = None) -> None:
804+
def _ensure_context_in_folder(
805+
self, ls: LanguageServer, folder_path: t.Optional[Path] = None
806+
) -> None:
793807
if not isinstance(self.context_state, NoContext):
794808
return
795809

@@ -798,7 +812,7 @@ def _ensure_context_in_folder(self, folder_path: t.Optional[Path] = None) -> Non
798812
for ext in ("py", "yml", "yaml"):
799813
config_path = workspace_folder / f"config.{ext}"
800814
if config_path.exists():
801-
if self._create_lsp_context([workspace_folder]):
815+
if self._create_lsp_context(ls, [workspace_folder]):
802816
return
803817

804818
# Then , check the provided folder recursively
@@ -809,7 +823,7 @@ def _ensure_context_in_folder(self, folder_path: t.Optional[Path] = None) -> Non
809823
for ext in ("py", "yml", "yaml"):
810824
config_path = path / f"config.{ext}"
811825
if config_path.exists():
812-
if self._create_lsp_context([path]):
826+
if self._create_lsp_context(ls, [path]):
813827
return
814828

815829
path = path.parent
@@ -821,7 +835,9 @@ def _ensure_context_in_folder(self, folder_path: t.Optional[Path] = None) -> Non
821835
+ (f" or in {folder_path}" if folder_path else "")
822836
)
823837

824-
def _create_lsp_context(self, paths: t.List[Path]) -> t.Optional[LSPContext]:
838+
def _create_lsp_context(
839+
self, ls: LanguageServer, paths: t.List[Path]
840+
) -> t.Optional[LSPContext]:
825841
"""Create a new LSPContext instance using the configured context class.
826842
827843
On success, sets self.context_state to ContextLoaded and returns the created context.
@@ -857,16 +873,33 @@ def _create_lsp_context(self, paths: t.List[Path]) -> t.Optional[LSPContext]:
857873
types.MessageType.Error,
858874
)
859875
self.has_raised_loading_error = True
860-
861876
self.server.log_trace(f"Error creating context: {e}")
862-
# Store the error in context state so subsequent requests show the actual error
863-
# Try to preserve any partially loaded context if it exists
877+
error_message = e if isinstance(e, ConfigError) else str(e)
878+
if isinstance(error_message, ConfigError) and error_message.location is not None:
879+
uri = URI.from_path(error_message.location)
880+
ls.publish_diagnostics(
881+
uri.value,
882+
[
883+
types.Diagnostic(
884+
range=types.Range(
885+
start=types.Position(line=0, character=0),
886+
end=types.Position(line=0, character=0),
887+
),
888+
message=str(error_message),
889+
severity=types.DiagnosticSeverity.Error,
890+
)
891+
],
892+
)
893+
894+
# Store the error in context state such that later requests can
895+
# show the actual error. Try to preserve any partially loaded context
896+
# if it exists
864897
context = None
865898
if isinstance(self.context_state, ContextLoaded):
866899
context = self.context_state.lsp_context.context
867900
elif isinstance(self.context_state, ContextFailed) and self.context_state.context:
868901
context = self.context_state.context
869-
self.context_state = ContextFailed(error_message=str(e), context=context)
902+
self.context_state = ContextFailed(error_message=error_message, context=context)
870903
return None
871904

872905
@staticmethod

sqlmesh/utils/errors.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@ class SQLMeshError(Exception):
2424

2525

2626
class ConfigError(SQLMeshError):
27-
pass
27+
location: t.Optional[Path] = None
28+
29+
def __init__(self, message: str | Exception, location: t.Optional[Path] = None) -> None:
30+
super().__init__(message)
31+
if location:
32+
self.location = Path(location) if isinstance(location, str) else location
2833

2934

3035
class MissingDependencyError(SQLMeshError):
@@ -188,12 +193,12 @@ class SignalEvalError(SQLMeshError):
188193

189194
def raise_config_error(
190195
msg: str,
191-
location: t.Optional[str | Path] = None,
196+
location: t.Optional[Path] = None,
192197
error_type: t.Type[ConfigError] = ConfigError,
193198
) -> None:
194199
if location:
195200
raise error_type(f"{msg} at '{location}'")
196-
raise error_type(msg)
201+
raise error_type(msg, location=location)
197202

198203

199204
def raise_for_status(response: Response) -> None:

0 commit comments

Comments
 (0)