277 lines
7.8 KiB
Python
277 lines
7.8 KiB
Python
from collections.abc import Callable, Generator, Sequence
|
|
from dataclasses import dataclass, field
|
|
from difflib import Differ
|
|
from itertools import chain
|
|
from logging import DEBUG, INFO, basicConfig, debug, error, getLogger, info, warning
|
|
from re import findall, finditer, fullmatch, match as rematch
|
|
import re
|
|
from time import time
|
|
from cattrs import Converter
|
|
from lsprotocol.types import (
|
|
INLAY_HINT_RESOLVE,
|
|
TEXT_DOCUMENT_DID_CHANGE,
|
|
TEXT_DOCUMENT_DID_OPEN,
|
|
TEXT_DOCUMENT_DID_SAVE,
|
|
TEXT_DOCUMENT_DOCUMENT_SYMBOL,
|
|
TEXT_DOCUMENT_HOVER,
|
|
TEXT_DOCUMENT_INLAY_HINT,
|
|
WORKSPACE_INLAY_HINT_REFRESH,
|
|
WORKSPACE_SEMANTIC_TOKENS_REFRESH,
|
|
CompletionItem,
|
|
Diagnostic,
|
|
DiagnosticSeverity,
|
|
DidChangeTextDocumentParams,
|
|
DidOpenTextDocumentParams,
|
|
DidSaveTextDocumentParams,
|
|
DocumentSymbol,
|
|
DocumentSymbolParams,
|
|
Hover,
|
|
HoverParams,
|
|
InlayHint,
|
|
InlayHintKind,
|
|
InlayHintParams,
|
|
MessageType,
|
|
NotebookDocumentSyncOptions,
|
|
Position,
|
|
Range,
|
|
SymbolKind,
|
|
TextDocumentContentChangeEvent,
|
|
TextDocumentContentChangeEvent_Type1,
|
|
TextDocumentSyncKind,
|
|
)
|
|
|
|
from pygls.protocol import LanguageServerProtocol, default_converter
|
|
from tree_sitter_skill import language as skill_lang
|
|
from tree_sitter import Language, Node, Parser, Query, Tree
|
|
|
|
from pygls.server import LanguageServer
|
|
from pygls.workspace import TextDocument
|
|
|
|
from .cache import Cache
|
|
|
|
# tree-sitter grammar for SKILL and a single parser shared by the whole server.
SKILL_LANG = Language(skill_lang())
SKILL_PARSER = Parser(SKILL_LANG)

# Alias used to document that a plain string holds a document URI.
URI = str

# Send every record (DEBUG and above) to skillls.log; filemode "w" overwrites
# the file rather than appending to it.
basicConfig(
    format="%(asctime)s [%(levelname)s]: %(message)s",
    level=DEBUG,
    filemode="w",
    filename="skillls.log",
)
logger = getLogger()  # the root logger

# Module-level cache keyed by str with CompletionItem values.
cache: Cache[str, CompletionItem] = Cache()
|
|
|
|
|
|
def in_range(what: Position, area: Range) -> bool:
    """Return True if *what* lies inside *area*, both endpoints included.

    Positions are ordered by ``(line, character)``, i.e. in document order.
    The tuples are built explicitly so the check does not depend on which
    rich-comparison operators ``Position`` happens to implement (``>=`` and
    ``<=`` are not guaranteed to be defined on it).
    """
    point = (what.line, what.character)
    start = (area.start.line, area.start.character)
    end = (area.end.line, area.end.character)
    return start <= point <= end
|
|
|
|
|
|
def find_end(start: Position, lines: list[str]) -> Position:
    """Return the position of the ``)`` that closes the form beginning at *start*.

    Scans *lines* from *start* onward, counting parentheses that are not
    inside double-quoted strings (backslash escapes inside strings are
    honoured). The first ``)`` that brings the nesting depth back to zero is
    returned; a ``)`` met while the depth is already zero is returned at once.

    Args:
        start: zero-based line/character at which scanning begins.
        lines: the document text split into lines.

    Returns:
        The zero-based position of the matching ``)``. If none is found an
        error is logged and the position just past the end of the last line
        is returned (``Position(0, 0)`` for an empty document).
    """
    depth = 0
    in_str = False
    escaped = False

    for row_offset, line in enumerate(lines[start.line :]):
        # Only the first line starts mid-way; later lines start at column 0.
        # Enumerating from col_offset keeps reported columns absolute.
        col_offset = start.character if row_offset == 0 else 0
        # A line break ends any pending escape inside a string.
        escaped = False
        for col, char in enumerate(line[col_offset:], start=col_offset):
            if in_str:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_str = False
            elif char == '"':
                in_str = True
            elif char == "(":
                depth += 1
            elif char == ")":
                if depth > 0:
                    depth -= 1
                if depth == 0:
                    return Position(start.line + row_offset, col)

    error(f"did not find end for start at {start}")
    if not lines:
        return Position(0, 0)
    return Position(len(lines) - 1, len(lines[-1]))
|
|
|
|
|
|
@dataclass(eq=True, frozen=True)
class Environment:
    """Immutable base record for an environment (presumably a lexical scope)
    attached to a span of the document.
    """

    range: Range  # document span this environment covers
|
|
|
|
|
|
@dataclass(eq=True, frozen=True)
class LetEnvironment(Environment):
    """Environment that additionally records a set of local names.

    NOTE(review): ``frozen`` only stops ``locals`` from being rebound; the set
    itself stays mutable, and because a set is unhashable, ``hash()`` on an
    instance raises TypeError despite the generated ``__hash__``.
    """

    # Names local to this environment; each instance gets its own empty set.
    locals: set[str] = field(default_factory=set)
|
|
|
|
|
|
def offset_range(range: Range, lines: int, cols: int = 0) -> Range:
    """Return a new Range equal to *range* moved by *lines* and *cols*.

    The same line and character offsets are added to both the start and the
    end position. NOTE(review): Position appears to validate its fields as
    unsigned integers, so a shift producing a negative coordinate is expected
    to raise — confirm.
    """

    def shifted(pos: Position) -> Position:
        # Apply both offsets to a single endpoint.
        return Position(pos.line + lines, pos.character + cols)

    return Range(shifted(range.start), shifted(range.end))
|
|
|
|
|
|
#
|
|
# @dataclass(frozen=True)
|
|
# class ProcEnvironment(Environment):
|
|
# name: str
|
|
# args: tuple[DocumentSymbol, ...]
|
|
# kwargs: tuple[DocumentSymbol, ...]
|
|
# rest: DocumentSymbol | None = None
|
|
#
|
|
# @property
|
|
# def locals(self) -> tuple[DocumentSymbol, ...]:
|
|
# ret = [*self.args, *self.kwargs]
|
|
# if self.rest:
|
|
# ret.append(self.rest)
|
|
#
|
|
# return tuple(ret)
|
|
|
|
|
|
class SkillLanguageServer(LanguageServer):
    """pygls language server for SKILL that keeps one tree-sitter parse per document.

    Attributes:
        contents: maps a document URI to the ``TextDocument`` last parsed for it.
        trees: maps a document URI to the tree-sitter ``Tree`` parsed from it.
    """

    contents: dict[str, TextDocument]
    trees: dict[str, Tree]

    # Captures every ERROR node; compiled once instead of on every diagnostics pass.
    _ERROR_QUERY = SKILL_LANG.query("(ERROR) @error")

    def __init__(
        self,
        name: str,
        version: str,
        loop=None,
        protocol_cls: type[LanguageServerProtocol] = LanguageServerProtocol,
        converter_factory: Callable[[], Converter] = default_converter,
        text_document_sync_kind: TextDocumentSyncKind = TextDocumentSyncKind.Incremental,
        notebook_document_sync: NotebookDocumentSyncOptions | None = None,
        max_workers: int = 2,
    ):
        """Forward every argument to ``LanguageServer``; start with no parsed documents."""
        super().__init__(
            name,
            version,
            loop,
            protocol_cls,
            converter_factory,
            text_document_sync_kind,
            notebook_document_sync,
            max_workers,
        )
        self.trees = {}
        self.contents = {}

    def parse(self, doc: TextDocument) -> None:
        """Parse *doc* from scratch and store both the document and its tree by URI."""
        parsed = SKILL_PARSER.parse(doc.source.encode("utf8"), encoding="utf8")
        self.trees[doc.uri] = parsed
        self.contents[doc.uri] = doc

    def update(self, uri: str, changes: list[TextDocumentContentChangeEvent]) -> None:
        """Re-parse *uri* after a ``textDocument/didChange`` notification.

        pygls' built-in didChange handler applies *changes* to the workspace's
        ``TextDocument`` before this server's feature handlers run, and that is
        the same object stored in ``contents``. The changes are therefore only
        logged here; applying them again would duplicate every edit.

        The document is re-parsed in full. Passing the previous tree as
        ``old_tree`` is only valid after ``Tree.edit`` has been called with each
        edit's byte offsets, which this server does not track.
        """
        for change in changes:
            if isinstance(change, TextDocumentContentChangeEvent_Type1):
                logger.debug(f"updating {change.range}")
            else:
                logger.debug("updating whole document")

        # The workspace copy is authoritative and already reflects *changes*.
        self.parse(self.workspace.get_text_document(uri))

    def _get_leaves(self, node: Node) -> list[Node]:
        """Return the leaf nodes below *node* left to right (*node* itself if childless)."""
        if node.children:
            return [leaf for child in node.children for leaf in self._get_leaves(child)]

        return [node]

    def _diagnose_errors(self, uri: str) -> list[Diagnostic]:
        """Return one Error-severity diagnostic per tree-sitter ERROR node of *uri*.

        Returns an empty list when *uri* has not been parsed yet.
        """
        tree = self.trees.get(uri)
        if tree is None:
            return []

        # captures() maps capture name -> nodes and omits names with no match,
        # so a document without syntax errors has no "error" key at all.
        nodes = self._ERROR_QUERY.captures(tree.root_node).get("error", [])

        diags: list[Diagnostic] = []
        for node in nodes:
            if node.type != "ERROR":
                continue

            logger.debug(node)
            logger.debug(node.range)
            content = node.text.decode("utf8") if node.text else ""
            # NOTE(review): tree-sitter columns are byte offsets while LSP
            # characters default to UTF-16 code units; ranges can drift on
            # lines containing non-ASCII text.
            diag_range = Range(
                Position(*node.range.start_point), Position(*node.range.end_point)
            )

            if "UNEXPECTED" in str(node):
                msg = f"unexpected '{content}'"
            else:
                # An empty message renders as a blank diagnostic in clients.
                msg = "syntax error"

            diags.append(
                Diagnostic(
                    diag_range,
                    msg,
                    severity=DiagnosticSeverity.Error,
                ),
            )

        return diags

    def diagnose(self, uri: str) -> list[Diagnostic]:
        """Collect all diagnostics for *uri* (currently syntax errors only)."""
        diags: list[Diagnostic] = []

        diags.extend(self._diagnose_errors(uri))

        return diags
|
|
|
|
|
|
# The single server instance; the feature handlers below register on it.
server = SkillLanguageServer(name="skillls", version="v0.3")
|
|
|
|
|
|
# @server.feature(TEXT_DOCUMENT_DID_SAVE)
|
|
@server.feature(TEXT_DOCUMENT_DID_OPEN)
def on_open(
    ls: SkillLanguageServer,
    params: DidOpenTextDocumentParams | DidSaveTextDocumentParams,
) -> None:
    """Parse a newly opened document and publish its syntax diagnostics.

    Only ``params.text_document.uri`` is read, so the same handler also works
    if it is additionally registered for ``TEXT_DOCUMENT_DID_SAVE``.
    """
    # Use the server instance pygls passes in rather than the module global.
    doc = ls.workspace.get_text_document(params.text_document.uri)
    ls.parse(doc)

    diags = ls.diagnose(doc.uri)
    ls.publish_diagnostics(doc.uri, diags)
|
|
|
|
|
|
@server.feature(TEXT_DOCUMENT_DID_CHANGE)
def on_change(ls: SkillLanguageServer, params: DidChangeTextDocumentParams) -> None:
    """Refresh the parse of a changed document and re-publish its diagnostics."""
    uri = params.text_document.uri
    ls.update(uri, changes=params.content_changes)
    ls.publish_diagnostics(uri, ls.diagnose(uri))
|
|
|
|
|
|
@server.feature(TEXT_DOCUMENT_DOCUMENT_SYMBOL)
def doc_symbols(
    ls: SkillLanguageServer,
    params: DocumentSymbolParams,
) -> list[DocumentSymbol]:
    """Handle ``textDocument/documentSymbol``.

    Placeholder: symbol extraction is not implemented yet, so every request
    is answered with an empty list.
    """
    # TODO: gather procedures, let-bindings, definitions and globals here.
    symbols: list[DocumentSymbol] = []
    return symbols
|
|
|
|
|
|
def main() -> None:
    """Entry point: run the SKILL language server over standard I/O."""
    server.start_io()
|