from collections.abc import Callable, Generator, Sequence
from dataclasses import dataclass, field
from difflib import Differ
from itertools import chain
from logging import DEBUG, INFO, basicConfig, debug, error, getLogger, info, warning
from re import findall, finditer, fullmatch, match as rematch
import re
from time import time

from cattrs import Converter
from lsprotocol.types import (
    INLAY_HINT_RESOLVE,
    TEXT_DOCUMENT_DID_CHANGE,
    TEXT_DOCUMENT_DID_OPEN,
    TEXT_DOCUMENT_DID_SAVE,
    TEXT_DOCUMENT_DOCUMENT_SYMBOL,
    TEXT_DOCUMENT_HOVER,
    TEXT_DOCUMENT_INLAY_HINT,
    WORKSPACE_INLAY_HINT_REFRESH,
    WORKSPACE_SEMANTIC_TOKENS_REFRESH,
    CompletionItem,
    Diagnostic,
    DiagnosticSeverity,
    DidChangeTextDocumentParams,
    DidOpenTextDocumentParams,
    DidSaveTextDocumentParams,
    DocumentSymbol,
    DocumentSymbolParams,
    Hover,
    HoverParams,
    InlayHint,
    InlayHintKind,
    InlayHintParams,
    MessageType,
    NotebookDocumentSyncOptions,
    Position,
    Range,
    SymbolKind,
    TextDocumentContentChangeEvent,
    TextDocumentContentChangeEvent_Type1,
    TextDocumentSyncKind,
)
from pygls.protocol import LanguageServerProtocol, default_converter
from tree_sitter_skill import language as skill_lang
from tree_sitter import Language, Node, Parser, Query, Tree
from pygls.server import LanguageServer
from pygls.workspace import TextDocument

from .cache import Cache

SKILL_LANG = Language(skill_lang())
SKILL_PARSER = Parser(SKILL_LANG)

# Compiled once at import time rather than on every diagnostics pass.
_ERROR_QUERY = SKILL_LANG.query("(ERROR) @error")

URI = str

basicConfig(
    filename="skillls.log",
    filemode="w",
    level=DEBUG,
    format="%(asctime)s [%(levelname)s]: %(message)s",
)
logger = getLogger()

cache: Cache[str, CompletionItem] = Cache()


def in_range(what: Position, area: Range) -> bool:
    """Return True if *what* lies inside *area*, both endpoints inclusive."""
    return (what >= area.start) and (what <= area.end)


def find_end(start: Position, lines: list[str]) -> Position:
    """Find the closing parenthesis that balances the text starting at *start*.

    Scans *lines* from ``start`` onward, counting ``(`` and ``)`` that occur
    outside double-quoted strings. Inside a string, a backslash escapes the
    next character, so ``\\"`` does not end the string.

    Returns the position of the ``)`` that brings the nesting depth back to
    zero. If a ``)`` is met before any ``(``, its position is returned.
    Columns are absolute indices into the strings in *lines*.

    If no such parenthesis exists, an error is logged and the position at
    the end of the last line is returned (``Position(0, 0)`` for no lines).
    """
    # NOTE(review): SKILL ';' line comments are not skipped, so parentheses
    # inside comments are counted. Confirm whether this matters for callers.
    depth = 0
    in_str = False
    escaped = False
    for row_offset, text in enumerate(lines[start.line :]):
        # On the first row, skip everything before the start column. The
        # offset is kept so that returned columns are absolute.
        col_offset = start.character if row_offset == 0 else 0
        for col, char in enumerate(text[col_offset:], start=col_offset):
            if in_str:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_str = False
                continue
            if char == '"':
                in_str = True
            elif char == "(":
                depth += 1
            elif char == ")":
                if depth > 0:
                    depth -= 1
                if depth == 0:
                    return Position(start.line + row_offset, col)
    error(f"did not find end for start at {start}")
    if not lines:
        return Position(0, 0)
    # The last valid line index is len(lines) - 1, not len(lines).
    return Position(len(lines) - 1, len(lines[-1]))


@dataclass(frozen=True)
class Environment:
    """Base class for a lexical scope covering a region of a document."""

    # Region of the document covered by this scope.
    range: Range


@dataclass(frozen=True)
class LetEnvironment(Environment):
    """Scope that also records the names it binds locally."""

    # Names bound inside this scope.
    locals: set[str] = field(default_factory=set)


def offset_range(range: Range, lines: int, cols: int = 0) -> Range:
    """Return a copy of *range* shifted by *lines* rows and *cols* columns.

    The same offsets are applied to both the start and the end position.
    """
    return Range(
        Position(
            range.start.line + lines,
            range.start.character + cols,
        ),
        Position(
            range.end.line + lines,
            range.end.character + cols,
        ),
    )


#
# @dataclass(frozen=True)
# class ProcEnvironment(Environment):
#     name: str
#     args: tuple[DocumentSymbol, ...]
#     kwargs: tuple[DocumentSymbol, ...]
#     rest: DocumentSymbol | None = None
#
#     @property
#     def locals(self) -> tuple[DocumentSymbol, ...]:
#         ret = [*self.args, *self.kwargs]
#         if self.rest:
#             ret.append(self.rest)
#
#         return tuple(ret)


class SkillLanguageServer(LanguageServer):
    """pygls language server that keeps a tree-sitter parse tree per document."""

    # Documents by URI. These are the workspace's own TextDocument objects
    # (see on_open and update), not private copies.
    contents: dict[str, TextDocument]
    # Most recent tree-sitter parse of each document, by URI.
    trees: dict[str, Tree]

    def __init__(
        self,
        name: str,
        version: str,
        loop=None,
        protocol_cls: type[LanguageServerProtocol] = LanguageServerProtocol,
        converter_factory: Callable[[], Converter] = default_converter,
        text_document_sync_kind: TextDocumentSyncKind = TextDocumentSyncKind.Incremental,
        notebook_document_sync: NotebookDocumentSyncOptions | None = None,
        max_workers: int = 2,
    ):
        super().__init__(
            name,
            version,
            loop,
            protocol_cls,
            converter_factory,
            text_document_sync_kind,
            notebook_document_sync,
            max_workers,
        )
        self.trees = {}
        self.contents = {}

    def parse(self, doc: TextDocument) -> None:
        """Parse *doc* from scratch and cache both its tree and the document."""
        parsed = SKILL_PARSER.parse(doc.source.encode("utf8"), encoding="utf8")
        self.trees[doc.uri] = parsed
        self.contents[doc.uri] = doc

    def update(self, uri: str, changes: list[TextDocumentContentChangeEvent]) -> None:
        """Re-parse the document at *uri* after a didChange notification.

        pygls' built-in ``textDocument/didChange`` handler applies every
        content change to the workspace document before user handlers run.
        Re-applying *changes* here would apply each edit twice to the same
        TextDocument object, so *changes* is used only for logging. It is
        kept as a parameter for interface compatibility.

        The document is re-parsed from scratch. Passing ``old_tree`` to
        tree-sitter is only correct after ``Tree.edit`` has been called with
        the byte and point ranges of every edit, and this server does not
        track those ranges.
        """
        logger.debug(f"reparsing {uri} after {len(changes)} change(s)")
        self.parse(self.workspace.get_text_document(uri))

    def _get_leaves(self, node: Node) -> list[Node]:
        """Return the leaf nodes under *node* in document order.

        *node* itself is returned when it has no children.
        """
        if node.children:
            return [leaf for child in node.children for leaf in self._get_leaves(child)]
        return [node]

    def _diagnose_errors(self, uri: str) -> list[Diagnostic]:
        """Return one Error diagnostic per tree-sitter ERROR node in *uri*.

        Returns an empty list if *uri* has not been parsed.
        """
        tree = self.trees.get(uri)
        if tree is None:
            return []
        # Query.captures returns a dict keyed by capture name, and names with
        # no captures are absent. An error-free document therefore has no
        # "error" key.
        nodes = _ERROR_QUERY.captures(tree.root_node).get("error", [])

        diags: list[Diagnostic] = []
        for node in nodes:
            if node.type != "ERROR":
                continue
            logger.debug(node)
            logger.debug(node.range)
            content = node.text.decode("utf8") if node.text else ""
            # NOTE(review): tree-sitter columns are byte offsets, while LSP
            # columns default to UTF-16 code units. They only agree on
            # ASCII-only lines.
            err_range = Range(
                Position(*node.range.start_point),
                Position(*node.range.end_point),
            )
            # Presumably str(node) is the node's S-expression, which marks
            # stray tokens with UNEXPECTED. Verify against tree-sitter output.
            if "UNEXPECTED" in str(node):
                msg = f"unexpected '{content}'"
            else:
                msg = "syntax error"
            diags.append(
                Diagnostic(
                    err_range,
                    msg,
                    severity=DiagnosticSeverity.Error,
                ),
            )

        return diags

    def diagnose(self, uri: str) -> list[Diagnostic]:
        """Collect all diagnostics for *uri* (currently only syntax errors)."""
        diags: list[Diagnostic] = []
        diags.extend(self._diagnose_errors(uri))
        return diags


server = SkillLanguageServer("skillls", "v0.3")


# @server.feature(TEXT_DOCUMENT_DID_SAVE)
@server.feature(TEXT_DOCUMENT_DID_OPEN)
def on_open(ls: SkillLanguageServer, params: DidOpenTextDocumentParams) -> None:
    """Parse a newly opened document and publish its diagnostics."""
    doc = ls.workspace.get_text_document(params.text_document.uri)
    ls.parse(doc)
    diags = ls.diagnose(doc.uri)
    ls.publish_diagnostics(doc.uri, diags)


@server.feature(TEXT_DOCUMENT_DID_CHANGE)
def on_change(ls: SkillLanguageServer, params: DidChangeTextDocumentParams) -> None:
    """Re-parse an edited document and republish its diagnostics."""
    ls.update(params.text_document.uri, changes=params.content_changes)
    diags = ls.diagnose(params.text_document.uri)
    ls.publish_diagnostics(params.text_document.uri, diags)


@server.feature(TEXT_DOCUMENT_DOCUMENT_SYMBOL)
def doc_symbols(
    ls: SkillLanguageServer,
    params: DocumentSymbolParams,
) -> list[DocumentSymbol]:
    """Answer documentSymbol requests. This is a stub that returns no symbols."""
    # return ls.procs + ls.lets + ls.defs + ls.globals
    return []


def main():
    """Run the language server over stdin/stdout."""
    server.start_io()