skill-ls/skillls/main.py

277 lines
7.8 KiB
Python

from collections.abc import Callable, Generator, Sequence
from dataclasses import dataclass, field
from difflib import Differ
from itertools import chain
from logging import DEBUG, INFO, basicConfig, debug, error, getLogger, info, warning
from re import findall, finditer, fullmatch, match as rematch
import re
from time import time
from cattrs import Converter
from lsprotocol.types import (
INLAY_HINT_RESOLVE,
TEXT_DOCUMENT_DID_CHANGE,
TEXT_DOCUMENT_DID_OPEN,
TEXT_DOCUMENT_DID_SAVE,
TEXT_DOCUMENT_DOCUMENT_SYMBOL,
TEXT_DOCUMENT_HOVER,
TEXT_DOCUMENT_INLAY_HINT,
WORKSPACE_INLAY_HINT_REFRESH,
WORKSPACE_SEMANTIC_TOKENS_REFRESH,
CompletionItem,
Diagnostic,
DiagnosticSeverity,
DidChangeTextDocumentParams,
DidOpenTextDocumentParams,
DidSaveTextDocumentParams,
DocumentSymbol,
DocumentSymbolParams,
Hover,
HoverParams,
InlayHint,
InlayHintKind,
InlayHintParams,
MessageType,
NotebookDocumentSyncOptions,
Position,
Range,
SymbolKind,
TextDocumentContentChangeEvent,
TextDocumentContentChangeEvent_Type1,
TextDocumentSyncKind,
)
from pygls.protocol import LanguageServerProtocol, default_converter
from tree_sitter_skill import language as skill_lang
from tree_sitter import Language, Node, Parser, Query, Tree
from pygls.server import LanguageServer
from pygls.workspace import TextDocument
from .cache import Cache
# Tree-sitter handle for the SKILL grammar and one parser shared by the server.
SKILL_LANG = Language(skill_lang())
SKILL_PARSER = Parser(SKILL_LANG)

# Alias for document identifiers.
URI = str

# Send every record (DEBUG and above) to skillls.log in the current working
# directory; filemode "w" truncates the file each time the module is imported.
basicConfig(
    format="%(asctime)s [%(levelname)s]: %(message)s",
    level=DEBUG,
    filemode="w",
    filename="skillls.log",
)
# getLogger() with no name returns the root logger configured above.
logger = getLogger()

# Project-local Cache keyed by str holding CompletionItem values; not read
# anywhere else in this file.
cache: Cache[str, CompletionItem] = Cache()
def in_range(what: Position, area: Range) -> bool:
    """Return True when *what* lies inside *area*, both endpoints inclusive."""
    return area.start <= what <= area.end
def find_end(start: Position, lines: list[str]) -> Position:
    """Locate the ``)`` that ends the parenthesised form scanned from *start*.

    Scans *lines* from *start*, tracking parenthesis depth only outside
    double-quoted string literals; inside a string a backslash escapes the
    next character (so ``\\"`` does not close the string, but ``\\\\"`` does).
    The scan stops at the first ``)`` that leaves the depth at zero. That is
    either the ``)`` matching a ``(`` opened during the scan, or a ``)`` met
    while the depth is already zero. In the second case *start* was inside
    a form, and that form's closing paren is returned.

    Args:
        start: Where scanning begins. Its ``character`` applies only to the
            first scanned line.
        lines: The document split into lines.

    Returns:
        The absolute position of the closing ``)``. Columns are Python string
        indices. NOTE(review): these differ from LSP UTF-16 offsets for
        non-BMP characters. If no closing paren is found, an error is logged
        and the end of the last line is returned, or ``Position(0, 0)`` for an
        empty document.
    """
    depth = 0
    in_str = False
    escaped = False
    for row, line in enumerate(lines[start.line :]):
        # Only the first scanned line starts mid-line. Enumerate from that
        # column so the reported column is absolute, not slice-relative.
        first_col = start.character if row == 0 else 0
        for col, char in enumerate(line[first_col:], start=first_col):
            if in_str:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_str = False
            elif char == '"':
                in_str = True
            elif char == "(":
                depth += 1
            elif char == ")":
                if depth > 0:
                    depth -= 1
                if depth == 0:
                    return Position(start.line + row, col)
    error(f"did not find end for start at {start}")
    if not lines:
        return Position(0, 0)
    # Line indices are 0-based, so the last existing line is len(lines) - 1.
    return Position(len(lines) - 1, len(lines[-1]))
@dataclass(frozen=True)
class Environment:
    """Immutable record tied to a source range.

    Presumably describes a lexical scope in a SKILL document. It is not used
    elsewhere in this file, so this is to be confirmed against callers.
    """

    # Source span the record covers.
    range: Range
@dataclass(frozen=True)
class LetEnvironment(Environment):
    """Environment subclass that also carries a set of local names.

    Presumably models a SKILL ``let`` form and its bindings. This is to be
    confirmed, since nothing in this file constructs it.
    """

    # Locally bound names; each instance gets its own empty set by default.
    # NOTE(review): with frozen=True the generated __hash__ hashes this set,
    # so hash() on an instance raises TypeError. Consider frozenset if
    # instances must be hashable.
    locals: set[str] = field(default_factory=set)
def offset_range(range: Range, lines: int, cols: int = 0) -> Range:
    """Return a new Range equal to *range* shifted by *lines* rows and *cols* columns.

    Both endpoints move by the same amounts; *range* itself is left untouched.
    """
    start, end = range.start, range.end
    shifted_start = Position(start.line + lines, start.character + cols)
    shifted_end = Position(end.line + lines, end.character + cols)
    return Range(shifted_start, shifted_end)
#
# @dataclass(frozen=True)
# class ProcEnvironment(Environment):
# name: str
# args: tuple[DocumentSymbol, ...]
# kwargs: tuple[DocumentSymbol, ...]
# rest: DocumentSymbol | None = None
#
# @property
# def locals(self) -> tuple[DocumentSymbol, ...]:
# ret = [*self.args, *self.kwargs]
# if self.rest:
# ret.append(self.rest)
#
# return tuple(ret)
class SkillLanguageServer(LanguageServer):
    """pygls language server for SKILL that caches a tree-sitter parse per document.

    Attributes:
        contents: Document URI -> the ``TextDocument`` last handed to ``parse``.
        trees: Document URI -> tree-sitter ``Tree`` parsed from that document.
    """

    contents: dict[str, TextDocument]
    trees: dict[str, Tree]

    # Compiled once rather than on every diagnostics pass. It captures every
    # ERROR node tree-sitter inserted while recovering from a syntax error.
    _ERROR_QUERY = SKILL_LANG.query("(ERROR) @error")

    def __init__(
        self,
        name: str,
        version: str,
        loop=None,
        protocol_cls: type[LanguageServerProtocol] = LanguageServerProtocol,
        converter_factory: Callable[[], Converter] = default_converter,
        text_document_sync_kind: TextDocumentSyncKind = TextDocumentSyncKind.Incremental,
        notebook_document_sync: NotebookDocumentSyncOptions | None = None,
        max_workers: int = 2,
    ):
        """Forward every argument unchanged to pygls and start with empty caches."""
        super().__init__(
            name,
            version,
            loop,
            protocol_cls,
            converter_factory,
            text_document_sync_kind,
            notebook_document_sync,
            max_workers,
        )
        self.trees = {}
        self.contents = {}

    def parse(self, doc: TextDocument) -> None:
        """Parse *doc* from scratch; cache the tree and the document under ``doc.uri``."""
        parsed = SKILL_PARSER.parse(doc.source.encode("utf8"), encoding="utf8")
        self.trees[doc.uri] = parsed
        self.contents[doc.uri] = doc

    def update(self, uri: str, changes: list[TextDocumentContentChangeEvent]) -> None:
        """Apply *changes* to the cached document for *uri*, then re-parse it.

        Both ranged (incremental) and whole-document change events are applied,
        in order, via ``TextDocument.apply_change``.

        NOTE: ``self.contents[uri]`` is the object passed to ``parse``. If that
        is the pygls workspace document, pygls' built-in didChange handling has
        already applied these changes to it before user handlers run. Calling
        this from a didChange handler would then apply every edit twice.

        Args:
            uri: URI of a document previously registered via ``parse``.
            changes: Content-change events in the order the client sent them.

        Raises:
            KeyError: If *uri* was never parsed.
        """
        doc = self.contents[uri]
        for change in changes:
            if isinstance(change, TextDocumentContentChangeEvent_Type1):
                logger.debug(f"updating {change.range}")
            # Diffing the whole document is costly; only do it when it gets logged.
            before = doc.lines if logger.isEnabledFor(DEBUG) else None
            doc.apply_change(change)
            if before is not None:
                logger.debug("".join(Differ().compare(before, doc.lines)))
        # Passing old_tree is only valid after Tree.edit() has been told about
        # every edit. No such edit is recorded here, so parse from scratch
        # instead of letting tree-sitter reuse nodes at stale offsets.
        self.trees[uri] = SKILL_PARSER.parse(doc.source.encode("utf8"), encoding="utf8")

    def _get_leaves(self, node: Node) -> list[Node]:
        """Return the childless descendants of *node* (or *node* itself) in source order.

        Uses an explicit stack, so deeply nested trees cannot hit the
        recursion limit.
        """
        leaves: list[Node] = []
        stack = [node]
        while stack:
            current = stack.pop()
            children = current.children
            if children:
                # Push in reverse so the leftmost child is visited first.
                stack.extend(reversed(children))
            else:
                leaves.append(current)
        return leaves

    def _diagnose_errors(self, uri: str) -> list[Diagnostic]:
        """Build one Error-severity diagnostic per ERROR node in *uri*'s tree.

        Returns an empty list when *uri* has not been parsed.
        """
        tree = self.trees.get(uri)
        if tree is None:
            return []
        # captures() maps capture name -> nodes and omits names that matched
        # nothing. An error-free document yields {}, so use .get, not [].
        nodes = self._ERROR_QUERY.captures(tree.root_node).get("error", [])
        diags: list[Diagnostic] = []
        for node in nodes:
            logger.debug(node)
            logger.debug(node.range)
            content = node.text.decode("utf8") if node.text else ""
            # NOTE(review): tree-sitter columns are byte offsets, while LSP
            # columns default to UTF-16 code units. They differ on non-ASCII lines.
            span = Range(
                Position(*node.range.start_point),
                Position(*node.range.end_point),
            )
            # str(node) is presumably the node's S-expression, where an
            # UNEXPECTED child marks a token the grammar could not place.
            if "UNEXPECTED" in str(node):
                msg = f"unexpected '{content}'"
            else:
                msg = "syntax error"
            diags.append(Diagnostic(span, msg, severity=DiagnosticSeverity.Error))
        return diags

    def diagnose(self, uri: str) -> list[Diagnostic]:
        """Collect all diagnostics for *uri*; currently only syntax errors."""
        diags: list[Diagnostic] = []
        diags.extend(self._diagnose_errors(uri))
        return diags
# Module-level server instance; the feature handlers below register on it.
server = SkillLanguageServer(name="skillls", version="v0.3")
# @server.feature(TEXT_DOCUMENT_DID_SAVE)
@server.feature(TEXT_DOCUMENT_DID_OPEN)
def on_open(ls: SkillLanguageServer, params: DidOpenTextDocumentParams) -> None:
    """Parse a newly opened document and publish its syntax diagnostics.

    Args:
        ls: The server instance pygls dispatched this notification to.
        params: didOpen parameters; only ``text_document.uri`` is read.
    """
    # Use the dispatched server rather than the module-level global.
    doc = ls.workspace.get_text_document(params.text_document.uri)
    ls.parse(doc)
    ls.publish_diagnostics(doc.uri, ls.diagnose(doc.uri))
@server.feature(TEXT_DOCUMENT_DID_CHANGE)
def on_change(ls: SkillLanguageServer, params: DidChangeTextDocumentParams) -> None:
    """Re-parse an edited document and publish fresh syntax diagnostics.

    pygls' built-in didChange handler applies ``params.content_changes`` to
    the workspace ``TextDocument`` before this user handler runs. That same
    object is what ``on_open`` cached via ``ls.parse``. Re-applying the
    changes (as ``ls.update`` does) would therefore apply every edit twice,
    so this handler re-parses the already-updated workspace document instead.

    Args:
        ls: The server instance pygls dispatched this notification to.
        params: didChange parameters; only ``text_document.uri`` is read.
    """
    uri = params.text_document.uri
    ls.parse(ls.workspace.get_text_document(uri))
    ls.publish_diagnostics(uri, ls.diagnose(uri))
@server.feature(TEXT_DOCUMENT_DOCUMENT_SYMBOL)
def doc_symbols(
    ls: SkillLanguageServer,
    params: DocumentSymbolParams,
) -> list[DocumentSymbol]:
    """Answer textDocument/documentSymbol; currently a stub that reports no symbols.

    TODO: return ls.procs + ls.lets + ls.defs + ls.globals once those
    attributes exist (SkillLanguageServer does not define them yet).
    """
    symbols: list[DocumentSymbol] = []
    return symbols
def main() -> None:
    """Run the module-level server over stdin/stdout (pygls ``start_io``)."""
    server.start_io()