5656)
5757from sqlmesh .lsp .rename import prepare_rename , rename_symbol , get_document_highlights
5858from sqlmesh .lsp .uri import URI
59+ from sqlmesh .utils .errors import ConfigError
5960from web .server .api .endpoints .lineage import column_lineage , model_lineage
6061from web .server .api .endpoints .models import get_models
6162from typing import Union
@@ -76,11 +77,14 @@ class ContextLoaded:
7677 lsp_context : LSPContext
7778
7879
80+ type ContextFailedError = str | ConfigError
81+
82+
7983@dataclass
8084class ContextFailed :
8185 """State when context failed to load with an error message."""
8286
83- error_message : str
87+ error_message : ContextFailedError
8488 context : t .Optional [Context ] = None
8589
8690
@@ -137,7 +141,7 @@ def _custom_all_models(self, ls: LanguageServer, params: AllModelsRequest) -> Al
137141 except Exception :
138142 pass
139143 try :
140- context = self ._context_get_or_load (uri )
144+ context = self ._context_get_or_load (ls , uri )
141145 return LSPContext .get_completions (context , uri , content )
142146 except Exception as e :
143147 from sqlmesh .lsp .completions import get_sql_completions
@@ -148,21 +152,21 @@ def _custom_render_model(
148152 self , ls : LanguageServer , params : RenderModelRequest
149153 ) -> RenderModelResponse :
150154 uri = URI (params .textDocumentUri )
151- context = self ._context_get_or_load (uri )
155+ context = self ._context_get_or_load (ls , uri )
152156 return RenderModelResponse (models = context .render_model (uri ))
153157
154158 def _custom_all_models_for_render (
155159 self , ls : LanguageServer , params : AllModelsForRenderRequest
156160 ) -> AllModelsForRenderResponse :
157- context = self ._context_get_or_load ()
161+ context = self ._context_get_or_load (ls )
158162 return AllModelsForRenderResponse (models = context .list_of_models_for_rendering ())
159163
160164 def _custom_format_project (
161165 self , ls : LanguageServer , params : FormatProjectRequest
162166 ) -> FormatProjectResponse :
163167 """Format all models in the current project."""
164168 try :
165- context = self ._context_get_or_load ()
169+ context = self ._context_get_or_load (ls )
166170 context .context .format ()
167171 return FormatProjectResponse ()
168172 except Exception as e :
@@ -173,7 +177,7 @@ def _custom_api(
173177 self , ls : LanguageServer , request : ApiRequest
174178 ) -> t .Union [ApiResponseGetModels , ApiResponseGetColumnLineage , ApiResponseGetLineage ]:
175179 ls .log_trace (f"API request: { request } " )
176- context = self ._context_get_or_load ()
180+ context = self ._context_get_or_load (ls )
177181
178182 parsed_url = urllib .parse .urlparse (request .url )
179183 path_parts = parsed_url .path .strip ("/" ).split ("/" )
@@ -244,7 +248,7 @@ def _reload_context_and_publish_diagnostics(
244248 else :
245249 # If there's no context, try to create one from scratch
246250 try :
247- self ._ensure_context_for_document (uri )
251+ self ._ensure_context_for_document (ls , uri )
248252 # If successful, context_state will be ContextLoaded
249253 if isinstance (self .context_state , ContextLoaded ):
250254 ls .show_message (
@@ -335,7 +339,10 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
335339 for ext in ("py" , "yml" , "yaml" ):
336340 config_path = folder_path / f"config.{ ext } "
337341 if config_path .exists ():
338- if self ._create_lsp_context ([folder_path ]):
342+ if self ._create_lsp_context (
343+ ls ,
344+ [folder_path ],
345+ ):
339346 loaded_sqlmesh_message (ls , folder_path )
340347 return # Exit after successfully loading any config
341348 except Exception as e :
@@ -346,7 +353,7 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
346353 @self .server .feature (types .TEXT_DOCUMENT_DID_OPEN )
347354 def did_open (ls : LanguageServer , params : types .DidOpenTextDocumentParams ) -> None :
348355 uri = URI (params .text_document .uri )
349- context = self ._context_get_or_load (uri )
356+ context = self ._context_get_or_load (ls , uri )
350357
351358 # Only publish diagnostics if client doesn't support pull diagnostics
352359 if not self .client_supports_pull_diagnostics :
@@ -368,7 +375,7 @@ def formatting(
368375 """Format the document using SQLMesh `format_model_expressions`."""
369376 try :
370377 uri = URI (params .text_document .uri )
371- context = self ._context_get_or_load (uri )
378+ context = self ._context_get_or_load (ls , uri )
372379 document = ls .workspace .get_text_document (params .text_document .uri )
373380 before = document .source
374381
@@ -412,7 +419,7 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov
412419 """Provide hover information for an object."""
413420 try :
414421 uri = URI (params .text_document .uri )
415- context = self ._context_get_or_load (uri )
422+ context = self ._context_get_or_load (ls , uri )
416423 document = ls .workspace .get_text_document (params .text_document .uri )
417424
418425 references = get_references (context , uri , params .position )
@@ -439,10 +446,10 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov
439446 def inlay_hint (
440447 ls : LanguageServer , params : types .InlayHintParams
441448 ) -> t .List [types .InlayHint ]:
442- """Implement type hints for sql columns as inlay hints"""
449+ """Implement type hints for SQL columns as inlay hints"""
443450 try :
444451 uri = URI (params .text_document .uri )
445- context = self ._context_get_or_load (uri )
452+ context = self ._context_get_or_load (ls , uri )
446453
447454 start_line = params .range .start .line
448455 end_line = params .range .end .line
@@ -459,7 +466,7 @@ def goto_definition(
459466 """Jump to an object's definition."""
460467 try :
461468 uri = URI (params .text_document .uri )
462- context = self ._context_get_or_load (uri )
469+ context = self ._context_get_or_load (ls , uri )
463470
464471 references = get_references (context , uri , params .position )
465472 location_links = []
@@ -513,7 +520,7 @@ def find_references(
513520 """Find all references of a symbol (supporting CTEs, models for now)"""
514521 try :
515522 uri = URI (params .text_document .uri )
516- context = self ._context_get_or_load (uri )
523+ context = self ._context_get_or_load (ls , uri )
517524
518525 all_references = get_all_references (context , uri , params .position )
519526
@@ -532,7 +539,7 @@ def prepare_rename_handler(
532539 """Prepare for rename operation by checking if the symbol can be renamed."""
533540 try :
534541 uri = URI (params .text_document .uri )
535- context = self ._context_get_or_load (uri )
542+ context = self ._context_get_or_load (ls , uri )
536543 result = prepare_rename (context , uri , params .position )
537544 return result
538545 except Exception as e :
@@ -546,7 +553,7 @@ def rename_handler(
546553 """Perform rename operation on the symbol at the given position."""
547554 try :
548555 uri = URI (params .text_document .uri )
549- context = self ._context_get_or_load (uri )
556+ context = self ._context_get_or_load (ls , uri )
550557 workspace_edit = rename_symbol (context , uri , params .position , params .new_name )
551558 return workspace_edit
552559 except Exception as e :
@@ -560,7 +567,7 @@ def document_highlight_handler(
560567 """Highlight all occurrences of the symbol at the given position."""
561568 try :
562569 uri = URI (params .text_document .uri )
563- context = self ._context_get_or_load (uri )
570+ context = self ._context_get_or_load (ls , uri )
564571 highlights = get_document_highlights (context , uri , params .position )
565572 return highlights
566573 except Exception as e :
@@ -574,7 +581,7 @@ def diagnostic(
574581 """Handle diagnostic pull requests from the client."""
575582 try :
576583 uri = URI (params .text_document .uri )
577- diagnostics , result_id = self ._get_diagnostics_for_uri (uri )
584+ diagnostics , result_id = self ._get_diagnostics_for_uri (ls , uri )
578585
579586 # Check if client provided a previous result ID
580587 if hasattr (params , "previous_result_id" ) and params .previous_result_id == result_id :
@@ -604,7 +611,7 @@ def workspace_diagnostic(
604611 ) -> types .WorkspaceDiagnosticReport :
605612 """Handle workspace-wide diagnostic pull requests from the client."""
606613 try :
607- context = self ._context_get_or_load ()
614+ context = self ._context_get_or_load (ls )
608615
609616 items : t .List [
610617 t .Union [
@@ -617,7 +624,7 @@ def workspace_diagnostic(
617624 for path , target in context .map .items ():
618625 if isinstance (target , ModelTarget ):
619626 uri = URI .from_path (path )
620- diagnostics , result_id = self ._get_diagnostics_for_uri (uri )
627+ diagnostics , result_id = self ._get_diagnostics_for_uri (ls , uri )
621628
622629 # Check if we have a previous result ID for this file
623630 previous_result_id = None
@@ -662,7 +669,7 @@ def code_action(
662669 try :
663670 ls .log_trace (f"Codeactionrequest: { params } " )
664671 uri = URI (params .text_document .uri )
665- context = self ._context_get_or_load (uri )
672+ context = self ._context_get_or_load (ls , uri )
666673 code_actions = context .get_code_actions (uri , params )
667674 return code_actions
668675
@@ -680,7 +687,7 @@ def completion(
680687 """Handle completion requests from the client."""
681688 try :
682689 uri = URI (params .text_document .uri )
683- context = self ._context_get_or_load (uri )
690+ context = self ._context_get_or_load (ls , uri )
684691
685692 # Get the document content
686693 content = None
@@ -749,30 +756,35 @@ def completion(
749756 get_sql_completions (None , URI (params .text_document .uri ))
750757 return None
751758
752- def _get_diagnostics_for_uri (self , uri : URI ) -> t .Tuple [t .List [types .Diagnostic ], int ]:
759+ def _get_diagnostics_for_uri (
760+ self , ls : LanguageServer , uri : URI
761+ ) -> t .Tuple [t .List [types .Diagnostic ], int ]:
753762 """Get diagnostics for a specific URI, returning (diagnostics, result_id).
754763
755764 Since we no longer track version numbers, we always return 0 as the result_id.
756765 This means pull diagnostics will always fetch fresh results.
757766 """
758767 try :
759- context = self ._context_get_or_load (uri )
768+ context = self ._context_get_or_load (ls , uri )
760769 diagnostics = context .lint_model (uri )
761770 return LSPContext .diagnostics_to_lsp_diagnostics (diagnostics ), 0
762771 except Exception :
763772 return [], 0
764773
765- def _context_get_or_load (self , document_uri : t .Optional [URI ] = None ) -> LSPContext :
774+ def _context_get_or_load (
775+ self , ls : LanguageServer , document_uri : t .Optional [URI ] = None
776+ ) -> LSPContext :
766777 if isinstance (self .context_state , ContextFailed ):
767778 raise RuntimeError (self .context_state .error_message )
768779 if isinstance (self .context_state , NoContext ):
769- self ._ensure_context_for_document (document_uri )
780+ self ._ensure_context_for_document (ls , document_uri )
770781 if not isinstance (self .context_state , ContextLoaded ):
771782 raise RuntimeError ("Context is not loaded" )
772783 return self .context_state .lsp_context
773784
774785 def _ensure_context_for_document (
775786 self ,
787+ ls : LanguageServer ,
776788 document_uri : t .Optional [URI ] = None ,
777789 ) -> None :
778790 """
@@ -784,12 +796,14 @@ def _ensure_context_for_document(
784796 if document_path .is_file () and document_path .suffix in (".sql" , ".py" ):
785797 document_folder = document_path .parent
786798 if document_folder .is_dir ():
787- self ._ensure_context_in_folder (document_folder )
799+ self ._ensure_context_in_folder (ls , document_folder )
788800 return
789801
790- self ._ensure_context_in_folder ()
802+ self ._ensure_context_in_folder (ls )
791803
792- def _ensure_context_in_folder (self , folder_path : t .Optional [Path ] = None ) -> None :
804+ def _ensure_context_in_folder (
805+ self , ls : LanguageServer , folder_path : t .Optional [Path ] = None
806+ ) -> None :
793807 if not isinstance (self .context_state , NoContext ):
794808 return
795809
@@ -798,7 +812,7 @@ def _ensure_context_in_folder(self, folder_path: t.Optional[Path] = None) -> Non
798812 for ext in ("py" , "yml" , "yaml" ):
799813 config_path = workspace_folder / f"config.{ ext } "
800814 if config_path .exists ():
801- if self ._create_lsp_context ([workspace_folder ]):
815+ if self ._create_lsp_context (ls , [workspace_folder ]):
802816 return
803817
804818 # Then , check the provided folder recursively
@@ -809,7 +823,7 @@ def _ensure_context_in_folder(self, folder_path: t.Optional[Path] = None) -> Non
809823 for ext in ("py" , "yml" , "yaml" ):
810824 config_path = path / f"config.{ ext } "
811825 if config_path .exists ():
812- if self ._create_lsp_context ([path ]):
826+ if self ._create_lsp_context (ls , [path ]):
813827 return
814828
815829 path = path .parent
@@ -821,7 +835,9 @@ def _ensure_context_in_folder(self, folder_path: t.Optional[Path] = None) -> Non
821835 + (f" or in { folder_path } " if folder_path else "" )
822836 )
823837
824- def _create_lsp_context (self , paths : t .List [Path ]) -> t .Optional [LSPContext ]:
838+ def _create_lsp_context (
839+ self , ls : LanguageServer , paths : t .List [Path ]
840+ ) -> t .Optional [LSPContext ]:
825841 """Create a new LSPContext instance using the configured context class.
826842
827843 On success, sets self.context_state to ContextLoaded and returns the created context.
@@ -857,16 +873,33 @@ def _create_lsp_context(self, paths: t.List[Path]) -> t.Optional[LSPContext]:
857873 types .MessageType .Error ,
858874 )
859875 self .has_raised_loading_error = True
860-
861876 self .server .log_trace (f"Error creating context: { e } " )
862- # Store the error in context state so subsequent requests show the actual error
863- # Try to preserve any partially loaded context if it exists
877+ error_message = e if isinstance (e , ConfigError ) else str (e )
878+ if isinstance (error_message , ConfigError ) and error_message .location is not None :
879+ uri = URI .from_path (error_message .location )
880+ ls .publish_diagnostics (
881+ uri .value ,
882+ [
883+ types .Diagnostic (
884+ range = types .Range (
885+ start = types .Position (line = 0 , character = 0 ),
886+ end = types .Position (line = 0 , character = 0 ),
887+ ),
888+ message = str (error_message ),
889+ severity = types .DiagnosticSeverity .Error ,
890+ )
891+ ],
892+ )
893+
894+ # Store the error in context state such that later requests can
895+ # show the actual error. Try to preserve any partially loaded context
896+ # if it exists
864897 context = None
865898 if isinstance (self .context_state , ContextLoaded ):
866899 context = self .context_state .lsp_context .context
867900 elif isinstance (self .context_state , ContextFailed ) and self .context_state .context :
868901 context = self .context_state .context
869- self .context_state = ContextFailed (error_message = str ( e ) , context = context )
902+ self .context_state = ContextFailed (error_message = error_message , context = context )
870903 return None
871904
872905 @staticmethod
0 commit comments