"""SKILL language server (``skillls``).

Publishes paren / C-style-call / parameter diagnostics, exposes ``let`` and
``procedure`` forms plus global assignments as document symbols, and emits
inlay hints marking local (``|l``) and global (``|g``) variables.
"""

import re
from collections.abc import Generator
from dataclasses import dataclass, field
from logging import INFO, basicConfig, debug, error, getLogger, info, warning
from re import findall, finditer, fullmatch, match as rematch
from time import time

from lsprotocol.types import (
    INLAY_HINT_RESOLVE,
    TEXT_DOCUMENT_DID_CHANGE,
    TEXT_DOCUMENT_DID_OPEN,
    TEXT_DOCUMENT_DID_SAVE,
    TEXT_DOCUMENT_DOCUMENT_SYMBOL,
    TEXT_DOCUMENT_HOVER,
    TEXT_DOCUMENT_INLAY_HINT,
    WORKSPACE_INLAY_HINT_REFRESH,
    WORKSPACE_SEMANTIC_TOKENS_REFRESH,
    CompletionItem,
    Diagnostic,
    DiagnosticSeverity,
    DidChangeTextDocumentParams,
    DidOpenTextDocumentParams,
    DidSaveTextDocumentParams,
    DocumentSymbol,
    DocumentSymbolParams,
    Hover,
    HoverParams,
    InlayHint,
    InlayHintKind,
    InlayHintParams,
    MessageType,
    Position,
    Range,
    SymbolKind,
)
from pygls.server import LanguageServer
from pygls.workspace import TextDocument

from skillls.parsing.iterative import IterativeParser, TokenParser

from .cache import Cache

URI = str

basicConfig(filename="skillls.log", filemode="w", level=INFO)

cache: Cache[str, CompletionItem] = Cache()

# Identifier syntax shared by the patterns below.
_IDENT = r"[a-zA-Z_][a-zA-Z0-9_]*"
_IDENT_RE = re.compile(_IDENT)
# "(name ..." at the start of a parenthesised binding / parameter pair.
_PAIR_NAME_RE = re.compile(rf"\(\s*({_IDENT})")
# Parameter-list marker such as @key, @option(al), @rest.
_MARKER_RE = re.compile(r"@(\w+)")
# Head of a let form, up to and including the "(" that opens its binding
# list: "(let (" or "let((" / "let( (".
_LET_RE = re.compile(r"\(\s*let\s+\(|\blet\(\s*\(")
# Head of a procedure form, up to and including the "(" that opens its
# parameter list: "(procedure name(" or "procedure(name(".
_PROC_RE = re.compile(rf"(?:\(\s*procedure\s+|\bprocedure\(\s*)(?P<name>{_IDENT})\(")
# "name =" assignment; the lookahead keeps "==" comparisons out.
_ASSIGN_RE = re.compile(rf"({_IDENT})\s*=(?!=)")
# C-style call "name(" (group 2), unless preceded by "procedure ".
_CISM_RE = re.compile(rf"(?P<proc>procedure\s+)?({_IDENT})\(")


def in_range(what: Position, area: Range) -> bool:
    """Return True if *what* lies within *area*, both ends inclusive."""
    return (what >= area.start) and (what <= area.end)


def _iter_code_parens(
    lines: list[str], start_line: int = 0, start_col: int = 0
) -> Generator[tuple[int, int, str], None, None]:
    """Yield ``(row, col, char)`` for every ``(`` / ``)`` outside string literals.

    Scanning starts at ``(start_line, start_col)`` and continues to the end of
    *lines*.  String state carries across lines.  Inside a string a backslash
    escapes the next character, so an escaped quote does not end the string
    and an escaped backslash does not escape the quote that follows it.
    """
    in_str = False
    escaped = False
    for row in range(start_line, len(lines)):
        line = lines[row]
        for col in range(start_col if row == start_line else 0, len(line)):
            char = line[col]
            if in_str:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_str = False
            elif char == '"':
                in_str = True
            elif char in "()":
                yield row, col, char


def _matching_paren_on_line(line: str, open_col: int) -> int | None:
    """Return the column of the ``)`` closing the ``(`` at *open_col*.

    Only *line* itself is searched; None is returned when the parenthesis is
    not closed on this line.
    """
    depth = 0
    for _, col, char in _iter_code_parens([line], 0, open_col):
        if char == "(":
            depth += 1
        else:
            depth -= 1
            if depth == 0:
                return col
    return None


def find_end(start: Position, lines: list[str]) -> Position:
    """Return the position of the ``)`` closing the first ``(`` at/after *start*.

    The scan may cross lines and ignores parentheses inside string literals.
    The returned column is absolute within its line.  If no closing
    parenthesis exists, an error is logged and the end of the last line of
    the document is returned (``Position(0, 0)`` for an empty document).
    """
    depth = 0
    for row, col, char in _iter_code_parens(lines, start.line, start.character):
        if char == "(":
            depth += 1
        elif depth > 0:
            depth -= 1
            if depth == 0:
                return Position(row, col)
    error(f"did not find end for start at {start}")
    if not lines:
        return Position(0, 0)
    return Position(len(lines) - 1, len(lines[-1]))


def _variable(name: str, row: int, col: int, whole: Range | None = None) -> DocumentSymbol:
    """Build a Variable symbol for *name* starting at ``(row, col)``.

    The selection range is the name itself; the full range is *whole* when
    given, otherwise also the name.
    """
    name_range = Range(Position(row, col), Position(row, col + len(name)))
    return DocumentSymbol(
        name,
        SymbolKind.Variable,
        whole if whole is not None else name_range,
        name_range,
    )


def _binding_symbols(row: int, line: str, begin: int, stop: int) -> list[DocumentSymbol]:
    """Return a Variable symbol for each let binding in ``line[begin:stop]``.

    A binding is a bare identifier (``x``) or a parenthesised ``(x init...)``
    pair; for a pair the range covers the whole pair and the selection range
    only the name.  Anything else is skipped.
    """
    symbols: list[DocumentSymbol] = []
    col = begin
    while col < stop:
        if line[col] == "(":
            close = _matching_paren_on_line(line, col)
            if close is None or close >= stop:
                break
            if named := _PAIR_NAME_RE.match(line, col, close):
                whole = Range(Position(row, col), Position(row, close + 1))
                symbols.append(_variable(named.group(1), row, named.start(1), whole))
            col = close + 1
        elif ident := _IDENT_RE.match(line, col, stop):
            symbols.append(_variable(ident.group(), row, col))
            col = ident.end()
        else:
            col += 1
    return symbols


def _param_symbols(
    row: int, line: str, begin: int, stop: int
) -> tuple[list[DocumentSymbol], list[DocumentSymbol], list[DocumentSymbol]]:
    """Split the procedure parameter list ``line[begin:stop]`` into symbols.

    Returns ``(args, rest, kwargs)``:

    * identifiers before any ``@`` marker are positional ``args``;
    * the identifier following ``@rest`` goes to ``rest``;
    * after any other marker (``@key``, ``@option``/``@optional``, ...) both
      bare identifiers and ``(name default...)`` pairs go to ``kwargs``;
      a parenthesised pair always goes to ``kwargs``.
    """
    args: list[DocumentSymbol] = []
    rest: list[DocumentSymbol] = []
    kwargs: list[DocumentSymbol] = []
    keyed = False  # a non-@rest marker has been seen
    rest_next = False  # the next identifier is the @rest parameter
    col = begin
    while col < stop:
        char = line[col]
        if char == "@":
            marker = _MARKER_RE.match(line, col, stop)
            if marker is None:
                col += 1
                continue
            rest_next = marker.group(1) == "rest"
            keyed = keyed or not rest_next
            col = marker.end()
        elif char == "(":
            close = _matching_paren_on_line(line, col)
            if close is None or close >= stop:
                break
            if named := _PAIR_NAME_RE.match(line, col, close):
                whole = Range(Position(row, col), Position(row, close + 1))
                kwargs.append(_variable(named.group(1), row, named.start(1), whole))
            col = close + 1
        elif ident := _IDENT_RE.match(line, col, stop):
            symbol = _variable(ident.group(), row, col)
            if rest_next:
                rest.append(symbol)
                rest_next = False
            elif keyed:
                kwargs.append(symbol)
            else:
                args.append(symbol)
            col = ident.end()
        else:
            col += 1
    return args, rest, kwargs


def _mixes_key_and_option(params: str) -> bool:
    """Return True if a parameter list uses both ``@key`` and ``@option``."""
    return "@option" in params and "@key" in params


def _proc_headers(
    lines: list[str],
) -> Generator[tuple[int, re.Match[str], int, int], None, None]:
    """Yield ``(row, match, params_open, params_stop)`` per procedure header.

    ``params_open`` is the column of the ``(`` opening the parameter list;
    ``params_stop`` is the column of its closing ``)``, or the line length
    when the list does not close on the same line.
    """
    for row, line in enumerate(lines):
        for found in _PROC_RE.finditer(line):
            params_open = found.end() - 1
            close = _matching_paren_on_line(line, params_open)
            yield row, found, params_open, (len(line) if close is None else close)


@dataclass(frozen=True)
class Environment:
    """A lexical scope covering ``range`` (not used elsewhere in this module)."""

    range: Range


@dataclass(frozen=True)
class LetEnvironment(Environment):
    """Scope of a ``let`` form with the names it binds (currently unused here)."""

    locals: set[str] = field(default_factory=set)


# Draft of a procedure scope model, not wired up yet:
#
# @dataclass(frozen=True)
# class ProcEnvironment(Environment):
#     name: str
#     args: tuple[DocumentSymbol, ...]
#     kwargs: tuple[DocumentSymbol, ...]
#     rest: DocumentSymbol | None = None
#
#     @property
#     def locals(self) -> tuple[DocumentSymbol, ...]:
#         ret = [*self.args, *self.kwargs]
#         if self.rest:
#             ret.append(self.rest)
#
#         return tuple(ret)


class SkillLanguageServer(LanguageServer):
    """Language server holding the symbol tables of the last parsed document.

    ``parse`` rebuilds ``lets``, ``procs`` and ``globals`` for one document;
    ``defs`` is declared but not populated anywhere in this module.
    """

    lets: list[DocumentSymbol]
    procs: list[DocumentSymbol]
    defs: list[DocumentSymbol]
    globals: list[DocumentSymbol]
    # URI whose symbols are currently stored; None until the first parse().
    _parsed_uri: str | None = None

    def __init__(self, *args: object, **kwargs: object) -> None:
        super().__init__(*args, **kwargs)
        # Per-instance lists: class-level ``[]`` defaults would be shared by
        # every instance until parse() replaced them.
        self.lets = []
        self.procs = []
        self.defs = []
        self.globals = []

    @property
    def envs(self) -> tuple[DocumentSymbol, ...]:
        """All scope-introducing symbols: procedures first, then lets."""
        return (
            *self.procs,
            *self.lets,
        )

    def _diagnose_parens(self, doc: TextDocument) -> Generator[Diagnostic, None, None]:
        """Yield an error for every unmatched parenthesis outside strings."""
        unclosed: list[tuple[int, int]] = []
        for row, col, char in _iter_code_parens(doc.lines):
            if char == "(":
                unclosed.append((row, col))
            elif unclosed:
                unclosed.pop()
            else:
                yield Diagnostic(
                    Range(Position(row, col), Position(row, col + 1)),
                    "unopened ) encountered",
                    severity=DiagnosticSeverity.Error,
                )
        for row, col in unclosed:
            yield Diagnostic(
                Range(Position(row, col), Position(row, col + 1)),
                "unclosed ( encountered",
                severity=DiagnosticSeverity.Error,
            )

    def _diagnose_cisms(self, doc: TextDocument) -> Generator[Diagnostic, None, None]:
        """Yield a hint for each C-style ``name(`` call.

        ``procedure name(`` headers are exempt.
        """
        for row, line in enumerate(doc.lines):
            for m in _CISM_RE.finditer(line):
                if m.group("proc"):
                    continue
                yield Diagnostic(
                    Range(Position(row, m.start()), Position(row, m.end())),
                    f"change `{m.group(2)}(` to `( {m.group(2)}`",
                    DiagnosticSeverity.Hint,
                )

    def _diagnose_procs(self, doc: TextDocument) -> Generator[Diagnostic, None, None]:
        """Yield an error for each procedure mixing ``@key`` and ``@option``."""
        lines = doc.lines
        for row, found, params_open, params_stop in _proc_headers(lines):
            line = lines[row]
            if _mixes_key_and_option(line[params_open + 1 : params_stop]):
                yield Diagnostic(
                    Range(Position(row, found.start()), Position(row, len(line))),
                    "`@key` and `@option` used in same definition",
                    severity=DiagnosticSeverity.Error,
                )

    # TODO: _diagnose_vars(doc) -> Generator[Diagnostic, None, None]

    def diagnose(self, doc: TextDocument) -> None:
        """Compute every diagnostic for *doc* and publish them in one batch.

        publishDiagnostics replaces the previous set for a URI, so all sources
        must be collected here and sent together.
        """
        diags: list[Diagnostic] = [
            *self._diagnose_parens(doc),
            *self._diagnose_cisms(doc),
            *self._diagnose_procs(doc),
        ]
        self.publish_diagnostics(doc.uri, diags)

    def parse(self, doc: TextDocument) -> None:
        """Rebuild the let / procedure / global symbol tables from *doc*."""
        lines = doc.lines  # TextDocument.lines is recomputed on every access
        self.lets = []
        self._parse_let(lines)
        self.procs = []
        self._parse_proc(lines, doc.uri)
        # Must run last: it consults self.lets and self.procs for scoping.
        self.globals = []
        self._parse_assigns(lines)
        self._parsed_uri = doc.uri

    def ensure_parsed(self, doc: TextDocument) -> None:
        """Parse *doc* unless its symbols are the ones currently stored."""
        if self._parsed_uri != doc.uri:
            self.parse(doc)

    def _parse_assigns(self, lines: list[str]) -> None:
        """Record ``name = ...`` assignments that are not to a local as globals.

        An assignment is local when it lies inside a let or procedure whose
        bindings/parameters include that name.
        """
        scopes = self.envs
        for row, line in enumerate(lines):
            for found in _ASSIGN_RE.finditer(line):
                token = found.group(1)
                token_range = Range(
                    Position(row, found.start(1)),
                    Position(row, found.end(1)),
                )
                if any(
                    in_range(token_range.start, scope.range)
                    and any(child.name == token for child in scope.children or ())
                    for scope in scopes
                ):
                    continue
                self.globals.append(
                    DocumentSymbol(token, SymbolKind.Variable, token_range, token_range)
                )

    def _parse_let(self, lines: list[str]) -> None:
        """Append a Namespace symbol to ``self.lets`` for each let form.

        The symbol spans the whole form.  Its children are the bindings that
        appear on the same line as the ``let`` keyword.
        """
        for row, line in enumerate(lines):
            for found in _LET_RE.finditer(line):
                bindings_open = found.end() - 1
                bindings_close = _matching_paren_on_line(line, bindings_open)
                bindings_stop = len(line) if bindings_close is None else bindings_close
                start = Position(row, found.start())
                whole = Range(start, find_end(start, lines))
                self.lets.append(
                    DocumentSymbol(
                        "let",
                        SymbolKind.Namespace,
                        whole,
                        whole,
                        children=_binding_symbols(
                            row, line, bindings_open + 1, bindings_stop
                        ),
                    )
                )

    def _parse_proc(self, lines: list[str], uri: str) -> None:
        """Append a Function symbol to ``self.procs`` for each procedure.

        Children are the positional, ``@rest`` and keyword/optional parameters,
        in that order.  A definition mixing ``@key`` and ``@option`` is skipped.
        That conflict is reported by ``diagnose``.  Publishing from here would
        overwrite the other diagnostics.  *uri* is unused and kept for
        backward compatibility.
        """
        for row, found, params_open, params_stop in _proc_headers(lines):
            line = lines[row]
            if _mixes_key_and_option(line[params_open + 1 : params_stop]):
                continue
            args, rest, kwargs = _param_symbols(row, line, params_open + 1, params_stop)
            start = Position(row, found.start())
            self.procs.append(
                DocumentSymbol(
                    found.group("name"),
                    kind=SymbolKind.Function,
                    range=Range(start, find_end(start, lines)),
                    # The name only: LSP requires selection_range within range.
                    selection_range=Range(
                        Position(row, found.start("name")),
                        Position(row, found.end("name")),
                    ),
                    children=args + rest + kwargs,
                )
            )

    def _hint_let(self) -> Generator[InlayHint, None, None]:
        """Yield a ``|l`` hint after each let-bound name."""
        for let in self.lets:
            if let.children:
                for child in let.children:
                    yield InlayHint(child.selection_range.end, "|l")

    def _hint_proc(self) -> Generator[InlayHint, None, None]:
        """Yield a ``|l`` hint after each procedure parameter name."""
        for proc in self.procs:
            debug(proc)
            if proc.children:
                for child in proc.children:
                    yield InlayHint(child.selection_range.end, "|l")

    def _hint_globals(self) -> Generator[InlayHint, None, None]:
        """Yield a ``|g`` hint after each global assignment target."""
        for glbl in self.globals:
            yield InlayHint(glbl.selection_range.end, "|g")

    def hint(self, doc: TextDocument, area: Range) -> list[InlayHint]:
        """Return the inlay hints for *doc* whose position lies inside *area*."""
        self.ensure_parsed(doc)
        candidates = (*self._hint_proc(), *self._hint_let(), *self._hint_globals())
        return [inlay for inlay in candidates if in_range(inlay.position, area)]


server = SkillLanguageServer("skillls", "v0.3")


@server.feature(TEXT_DOCUMENT_DID_SAVE)
@server.feature(TEXT_DOCUMENT_DID_OPEN)
@server.feature(TEXT_DOCUMENT_DID_CHANGE)
def on_open(
    ls: SkillLanguageServer,
    params: DidOpenTextDocumentParams
    | DidChangeTextDocumentParams
    | DidSaveTextDocumentParams,
) -> None:
    """Re-diagnose and re-parse a document when it is opened, changed or saved."""
    doc = ls.workspace.get_text_document(params.text_document.uri)
    ls.diagnose(doc)
    ls.parse(doc)
    # Ask the client to re-request inlay hints for the new symbol tables.
    ls.lsp.send_request_async(WORKSPACE_INLAY_HINT_REFRESH)


@server.feature(TEXT_DOCUMENT_INLAY_HINT)
def inlay_hints(ls: SkillLanguageServer, params: InlayHintParams) -> list[InlayHint]:
    """Serve textDocument/inlayHint for the requested range."""
    doc = ls.workspace.get_text_document(params.text_document.uri)
    return ls.hint(doc, params.range)


@server.feature(TEXT_DOCUMENT_DOCUMENT_SYMBOL)
def doc_symbols(
    ls: SkillLanguageServer,
    params: DocumentSymbolParams,
) -> list[DocumentSymbol]:
    """Serve textDocument/documentSymbol: procedures, lets, defs, then globals."""
    ls.ensure_parsed(ls.workspace.get_text_document(params.text_document.uri))
    return ls.procs + ls.lets + ls.defs + ls.globals


def main():
    """Run the language server over stdin/stdout."""
    server.start_io()