This commit is contained in:
2025-11-22 17:56:49 +01:00
parent 82b165dd21
commit 8730493857
7 changed files with 575 additions and 242 deletions
+75
View File
@@ -0,0 +1,75 @@
from dataclasses import dataclass
from enum import Enum, auto
from lsprotocol.types import Location, Position, Range
from skillls.types import URI
class SyntaxError(Exception):
    """Base class for SKILL syntax errors raised by the checker.

    NOTE(review): this shadows the builtin ``SyntaxError``; renaming would
    change the public API (it is imported elsewhere), so it is only flagged.
    """

    pass
class ParenMismatchErrorKind(Enum):
    """Direction of a paren imbalance; values double as user-facing messages."""

    TooManyClosed = "Found too many closing parens"
    TooManyOpened = "Found too many open parens"
@dataclass
class ParenMismatchError(SyntaxError):
    """A single unbalanced-parenthesis diagnostic with its source range."""

    # Which direction the imbalance goes (too many open vs. too many closed).
    kind: ParenMismatchErrorKind
    # Source range of the offending paren.
    loc: Range
def _check_for_matching_parens(content: str) -> list[Exception]:
excs: list[Exception] = []
opened = 0
line = 0
col = 0
last_open: Position = Position(0, 0)
last_close: Position = Position(0, 0)
for char in content:
match char:
case "(":
opened += 1
last_open = Position(line, col)
case ")":
opened -= 1
if opened < 0:
excs.append(
ParenMismatchError(
ParenMismatchErrorKind.TooManyClosed,
Range(Position(line, col), Position(line, col + 1)),
)
)
opened = 0
last_close = Position(line, col)
case "\n":
line += 1
col = -1
case _:
pass
col += 1
if opened > 0:
excs.append(
ParenMismatchError(
ParenMismatchErrorKind.TooManyOpened,
Range(last_open, Position(last_open.line, last_open.character + 1)),
)
)
return excs
def check_content_for_errors(clean_content: str) -> None:
    """Run every syntax check over *clean_content* (comment/string-stripped).

    Raises:
        ExceptionGroup: grouping all syntax errors found, if any.
    """
    excs: list[Exception] = []
    excs.extend(_check_for_matching_parens(clean_content))
    if excs:
        # Descriptive message instead of "" so tracebacks/log output are readable.
        raise ExceptionGroup("syntax errors found in document", excs)
+192
View File
@@ -0,0 +1,192 @@
from copy import copy
from dataclasses import dataclass
from logging import getLogger
from pathlib import Path
from pprint import pformat
from lsprotocol.types import DocumentSymbol, Position, Range, SymbolKind
from re import MULTILINE, compile as recompile, finditer
from pygls.workspace import TextDocument
from skillls.checker import check_content_for_errors
from skillls.types import URI, Node, NodeKind
logger = getLogger(__name__)
@dataclass
class ParserCleanerState:
    """Mutable scanner state for clean_content.

    Tracks whether the scanner is currently inside a line comment or a
    double-quoted string.
    """

    in_comment: bool = False
    in_string: bool = False
# Alternation of every scope-introducing keyword, e.g. "let|procedure|proc|foreach".
NODE_KIND_OPTIONS = "|".join(k.value for k in NodeKind)
# Matches the start of a scope: either "(<kw>" (prefix syntax, captured as
# "typ") or "<kw>(" (C-style call syntax, captured as "ctyp").
NAMESPACE_STARTERS = recompile(
    (rf"(\(\s*(?P<typ>{NODE_KIND_OPTIONS})\b|\b(?P<ctyp>{NODE_KIND_OPTIONS})\()"),
    MULTILINE,
)
def clean_content(content: str) -> str:
    """Blank out comments and string contents while preserving layout.

    - a ``;`` outside a string starts a line comment: the ``;`` and the rest
      of the comment are dropped, the terminating newline is kept
    - characters inside double-quoted strings become spaces so columns stay
      aligned; the quote characters themselves are kept, and newlines inside
      strings are kept so line numbers stay aligned
    - everything else is copied through unchanged

    Fixes in this revision: newlines inside strings are preserved (they were
    previously spaced out, shifting every later line number); the escaped-quote
    check no longer wraps to ``content[-1]`` for a quote at index 0; output is
    built with a list + join instead of quadratic string concatenation.
    """
    pieces: list[str] = []
    in_comment = False
    in_string = False
    for cix, char in enumerate(content):
        if char == ";" and not in_comment and not in_string:
            in_comment = True  # the ";" itself is dropped with the comment
        elif char == '"' and not in_comment:
            # A quote preceded by a backslash is escaped and does not toggle
            # string mode.  cix == 0 guards the first character: content[-1]
            # would otherwise wrap around to the last character.
            if cix == 0 or content[cix - 1] != "\\":
                in_string = not in_string
            pieces.append(char)
        elif char == "\n" and in_comment:
            in_comment = False
            pieces.append(char)
        elif not in_comment and not in_string:
            pieces.append(char)
        elif not in_comment and in_string:
            # Keep newlines so later line/column arithmetic stays correct.
            pieces.append("\n" if char == "\n" else " ")
        # else: inside a comment — drop the character entirely.
    return "".join(pieces)
def build_node_hierarchy(nodes: list[Node]) -> list[Node]:
    """Fold a flat list of nodes into a containment tree.

    Each node is attached under the first already-placed root whose range
    should contain it (Node.add_child recurses to the deepest fit);
    otherwise it becomes a new root.  The input list is not mutated.

    Fix: the accumulator was named ``sorted``, shadowing the builtin;
    ``copy()`` replaced with the equivalent ``list()`` builtin.
    """
    pending = list(nodes)
    roots: list[Node] = []
    while pending:
        candidate = pending.pop(0)
        for root in roots:
            if root.should_contain(candidate):
                root.add_child(candidate)
                break
        else:
            roots.append(candidate)
    return roots
def find_scopes(content_cleaned: str, scope_prefix: str = "") -> list[Node]:
    """Locate every scope-introducing form in *content_cleaned*.

    For each NAMESPACE_STARTERS match, the balancing ")" is found, a Node
    covering that range is created, and the scope's parameter names are
    collected as DocumentSymbols.  The flat list is finally folded into a
    containment tree via build_node_hierarchy.

    NOTE(review): positions are computed from the cleaned content; they are
    only correct if clean_content preserved the original layout — confirm.
    """
    ret: list[Node] = []
    for found in NAMESPACE_STARTERS.finditer(content_cleaned):
        # Walk forward from the match to find the balancing ")".
        partial = content_cleaned[found.end() :]
        open_brackets = 1
        offset = 0
        for offset, char in enumerate(partial):
            match char:
                case "(":
                    open_brackets += 1
                case ")":
                    open_brackets -= 1
                    if open_brackets == 0:
                        break
                case _:
                    pass
        # Start position: count the lines before the match.
        # NOTE(review): pre_lines[-1] raises IndexError when the match is at
        # offset 0 — confirm inputs always have leading content.
        pre_lines = content_cleaned[: found.start()].splitlines()
        start_line = len(pre_lines) - (
            1 if pre_lines[-1] != "" and pre_lines[-1].strip() == "" else 0
        )
        start_char = len(pre_lines[-1])
        inner_lines = content_cleaned[
            found.start() : found.end() + offset + 1
        ].splitlines()
        end_line = start_line + len(inner_lines) - 1
        end_char = len(inner_lines[-1])
        # Whichever named group matched carries the scope keyword.
        kind = NodeKind(found.group("typ") or found.group("ctyp"))
        loc = Range(Position(start_line, start_char), Position(end_line, end_char))
        node = Node(
            # Scope name is "<prefix>.<kind>_<index among same-kind scopes>".
            node=f"{scope_prefix}.{kind.value}_{len([n for n in ret if n.kind == kind])}",
            kind=kind,
            location=loc,
        )
        ret.append(node)
        # Advance to the "(" opening the parameter list, keeping
        # start_line/start_char in sync.  ("next" shadows the builtin;
        # left unchanged in this documentation-only pass.)
        next = found.end()
        # allowed scoped locals syntax
        # function(pos1 pos2)
        # function(pos1 (pos2 default))
        # function(pos1 @rest args)
        # function(pos1 @key (kwarg1 default1) (kwarg2 default2))
        while content_cleaned[next] != "(":
            if content_cleaned[next] == "\n":
                start_line += 1
                start_char = 0
                next += 1
            start_char += 1
            next += 1
        # Consume consecutive parameter tokens; stop at the first gap the
        # pattern cannot explain (i.e. non-parameter text).
        last = 0
        for positional in finditer(
            r"(?P<leading>\s*)(?P<local>\w+|\(\w+\b[^)]*\))(?P<trailing>\s*)",
            content_cleaned[next:],
        ):
            if positional.start() != last:
                logger.debug(
                    f"found ({positional}), but last ({last}) != ({positional.start()})"
                )
                break
            last = positional.end()
            # Newline counts advance the running line counter below.
            leading_nls = positional.group("leading").count("\n")
            inner_nls = positional.group("local").count("\n")
            trailing_nls = positional.group("trailing").count("\n")
            # For "(name default)" forms the first word is the local's name.
            local_name = positional.group("local").split()[0]
            local = DocumentSymbol(
                name=local_name,
                kind=SymbolKind.Variable,
                range=Range(
                    Position(
                        start_line + leading_nls,
                        len(positional.group("leading")) + start_char,
                    ),
                    Position(
                        start_line + leading_nls,
                        len(positional.group("leading")) + start_char + len(local_name),
                    ),
                ),
                selection_range=Range(
                    Position(
                        start_line + leading_nls,
                        len(positional.group("leading")) + start_char,
                    ),
                    Position(
                        start_line + leading_nls,
                        len(positional.group("leading")) + start_char + len(local_name),
                    ),
                ),
            )
            node.symbols[local_name] = local
            start_line += leading_nls + inner_nls + trailing_nls
            start_char += len(positional.group(0))
        # other cases
        logger.debug(pformat(node))
    return build_node_hierarchy(ret)
def parse_file(file: TextDocument) -> list[Node]:
    """Parse *file* into its scope hierarchy.

    Strips comments/strings, validates the result, and extracts scopes
    using the file's stem as the scope-name prefix.

    Raises:
        ExceptionGroup: when the document contains syntax errors.
    """
    cleaned = clean_content(file.source)
    check_content_for_errors(cleaned)
    return find_scopes(cleaned, scope_prefix=Path(file.path).stem)
+127 -107
View File
@@ -1,59 +1,41 @@
from collections.abc import Callable, Generator, Sequence
from collections.abc import Callable
from dataclasses import dataclass, field
from difflib import Differ
from itertools import chain
from logging import DEBUG, INFO, basicConfig, debug, error, getLogger, info, warning
from re import findall, finditer, fullmatch, match as rematch
import re
from time import time
from logging import DEBUG, basicConfig, getLogger
from pathlib import Path
from typing import Any
from cattrs import Converter
from lsprotocol.types import (
INLAY_HINT_RESOLVE,
TEXT_DOCUMENT_DID_CHANGE,
TEXT_DOCUMENT_DID_CLOSE,
TEXT_DOCUMENT_DID_OPEN,
TEXT_DOCUMENT_DID_SAVE,
INITIALIZE,
TEXT_DOCUMENT_DOCUMENT_SYMBOL,
TEXT_DOCUMENT_HOVER,
TEXT_DOCUMENT_INLAY_HINT,
WORKSPACE_INLAY_HINT_REFRESH,
WORKSPACE_SEMANTIC_TOKENS_REFRESH,
CompletionItem,
Diagnostic,
DiagnosticSeverity,
DidChangeTextDocumentParams,
DidCloseTextDocumentParams,
DidOpenTextDocumentParams,
DidSaveTextDocumentParams,
DocumentSymbol,
DocumentSymbolParams,
Hover,
HoverParams,
InitializeParams,
InlayHint,
InlayHintKind,
InlayHintParams,
MessageType,
NotebookDocumentSyncOptions,
Position,
Range,
SymbolKind,
TextDocumentContentChangeEvent,
TextDocumentContentChangeEvent_Type1,
PublishDiagnosticsNotification,
PublishDiagnosticsParams,
ShowMessageParams,
TextDocumentSyncKind,
)
from pygls.lsp.server import LanguageServer
from pygls.protocol import LanguageServerProtocol, default_converter
from tree_sitter_skill import language as skill_lang
from tree_sitter import Language, Node, Parser, Query, Tree
from pygls.server import LanguageServer
from pygls.workspace import TextDocument
from .cache import Cache
SKILL_LANG = Language(skill_lang())
SKILL_PARSER = Parser(SKILL_LANG)
URI = str
from skillls.checker import ParenMismatchError, ParenMismatchErrorKind
from skillls.helpers import parse_file
from skillls.types import URI, Node
basicConfig(
filename="skillls.log",
@@ -61,92 +43,130 @@ basicConfig(
level=DEBUG,
format="%(asctime)s [%(levelname)s]: %(message)s",
)
logger = getLogger()
cache: Cache[str, CompletionItem] = Cache()
def in_range(what: Position, area: Range) -> bool:
    """Return True when *what* lies within *area*, endpoints included."""
    return area.start <= what <= area.end
def find_end(start: Position, lines: list[str]) -> Position:
count = 0
in_str: bool = False
last = ""
for row, line in enumerate(lines[start.line :]):
if row == 0:
line = line[start.character :]
row += start.character
for col, char in enumerate(line[start.character :] if row == 0 else line):
match char:
case "(":
if not in_str:
count += 1
case ")":
if not in_str:
if count > 0:
count -= 1
if count == 0:
return Position(start.line + row, col)
case '"':
if not (in_str and last == "\\"):
in_str = not in_str
case _:
last = char
last = char
error(f"did not fin end for start at {start}")
return Position(len(lines), len(lines[-1]))
@dataclass(frozen=True)
class Environment:
    """An immutable lexical region of a document."""

    # Source span covered by this environment.
    range: Range


@dataclass(frozen=True)
class LetEnvironment(Environment):
    """An Environment introduced by a ``let`` form, carrying its local names."""

    # Names bound by the let; a fresh empty set per instance.
    locals: set[str] = field(default_factory=set)
def offset_range(range: Range, lines: int, cols: int = 0) -> Range:
    """Return a copy of *range* shifted down by *lines* rows and right by *cols* columns.

    (The parameter name ``range`` shadows the builtin but is part of the
    public signature, so it is kept.)
    """
    shifted_start = Position(range.start.line + lines, range.start.character + cols)
    shifted_end = Position(range.end.line + lines, range.end.character + cols)
    return Range(shifted_start, shifted_end)
logger = getLogger(__name__)
class SkillLanguageServer(LanguageServer):
ws_files: set[URI]
opened_files: set[URI]
scopes: dict[URI, list[Node]]
errs: dict[URI, ExceptionGroup]
def __init__(
self,
name: str,
version: str,
loop=None,
protocol_cls: type[LanguageServerProtocol] = LanguageServerProtocol,
converter_factory: Callable[[], Converter] = default_converter,
text_document_sync_kind: TextDocumentSyncKind = TextDocumentSyncKind.Incremental,
notebook_document_sync: NotebookDocumentSyncOptions | None = None,
max_workers: int = 2,
):
super().__init__(
name,
version,
loop,
protocol_cls,
converter_factory,
text_document_sync_kind,
notebook_document_sync,
max_workers,
)
super().__init__(name, version, text_document_sync_kind, notebook_document_sync)
self.ws_files = set()
self.opened_files = set()
self.scopes = {}
self.errs = {}
def update_diagnostics(self) -> None:
for uri in self.opened_files:
diags: list[Diagnostic] = []
if eg := self.errs.get(uri):
for exc in eg.exceptions:
match exc:
case ParenMismatchError():
diags.append(
Diagnostic(
message=f"[skill_ls] {Path.from_uri(uri).name}:{exc.loc.start.line} {exc.kind.value}",
severity=DiagnosticSeverity.Error,
range=exc.loc,
)
)
# if diags:
self.text_document_publish_diagnostics(
PublishDiagnosticsParams(
uri=uri,
version=self.workspace.get_text_document(uri).version,
diagnostics=diags,
)
)
# Module-level server instance; the @server.feature handlers below register on it.
server = SkillLanguageServer("SkillLS", "0.2.0")
@server.feature(INITIALIZE)
def lsp_initialize(server: SkillLanguageServer, params: InitializeParams) -> None:
    """Discover and parse every SKILL file (*.il, *.ocn) in the workspace root."""
    init_options: dict[str, Any] = params.initialization_options or {}
    logger.info("done init")
    logger.debug(init_options)
    ws_dir = server.workspace.root_path
    logger.debug(ws_dir)
    if not ws_dir:
        return
    root_dir = Path(ws_dir)
    candidates = [*root_dir.rglob("*.il"), *root_dir.rglob("*.ocn")]
    for candidate in candidates:
        uri = candidate.as_uri()
        logger.debug(uri)
        server.ws_files.add(uri)
        try:
            server.scopes[uri] = parse_file(server.workspace.get_text_document(uri))
        except ExceptionGroup as eg:
            server.errs[uri] = eg
        else:
            # Parsing succeeded: drop any stale error record for this file.
            if server.errs.get(uri):
                del server.errs[uri]
@server.feature(TEXT_DOCUMENT_DID_OPEN)
def on_open(server: SkillLanguageServer, params: DidOpenTextDocumentParams) -> None:
    """Track the newly opened document and push current diagnostics."""
    uri = params.text_document.uri
    server.opened_files.add(uri)
    server.update_diagnostics()
@server.feature(TEXT_DOCUMENT_DID_CLOSE)
def on_close(server: SkillLanguageServer, params: DidCloseTextDocumentParams) -> None:
    """Stop tracking a closed document.

    Uses ``discard`` rather than ``remove`` so an unbalanced close
    notification (close without a matching open) cannot raise KeyError
    and crash the handler.
    """
    server.opened_files.discard(params.text_document.uri)
@server.feature(TEXT_DOCUMENT_DID_CHANGE)
def on_change(server: SkillLanguageServer, params: DidChangeTextDocumentParams) -> None:
    """Re-parse the edited document, record any syntax errors, republish diagnostics."""
    uri = params.text_document.uri
    document = server.workspace.get_text_document(uri)
    try:
        server.scopes[uri] = parse_file(document)
    except ExceptionGroup as eg:
        server.errs[uri] = eg
    else:
        # Parsing succeeded: drop any stale error record for this file.
        if server.errs.get(uri):
            del server.errs[uri]
    server.update_diagnostics()
@server.feature(TEXT_DOCUMENT_INLAY_HINT)
def on_inlay(server: SkillLanguageServer, params: InlayHintParams) -> list[InlayHint]:
    """Return scope-name inlay hints for the requested document.

    Fix: hints were previously collected across ALL opened files, so a
    client asking about one document received hints whose positions came
    from other documents.  Inlay-hint requests are per-document, so only
    the scopes of ``params.text_document.uri`` are used.
    """
    uri = params.text_document.uri
    return [
        InlayHint(
            label=node.node,
            kind=InlayHintKind.Type,
            padding_left=True,
            position=node.location.end,
        )
        for node in server.scopes.get(uri, [])
    ]
@server.feature(TEXT_DOCUMENT_DOCUMENT_SYMBOL)
def on_symbols(
    server: SkillLanguageServer, params: DocumentSymbolParams
) -> list[DocumentSymbol] | None:
    """Return the document-symbol tree for the requested document.

    Fix: ``server.scopes[uri]`` raised KeyError for a document that has not
    been parsed yet (request racing ahead of didOpen, or a file that failed
    to parse); the signature already allows ``None``, so return that instead.
    """
    nodes = server.scopes.get(params.text_document.uri)
    if nodes is None:
        return None
    return [node.as_doc_symbol() for node in nodes]
def main():
+66
View File
@@ -0,0 +1,66 @@
from dataclasses import dataclass, field
from enum import Enum, auto
from lsprotocol.types import DocumentSymbol, Range, SymbolKind
URI = str
class NodeKind(Enum):
    """SKILL forms that introduce a new scope; values are the source keywords."""

    LET = "let"
    PROCEDURE = "procedure"
    PROC = "proc"
    FOREACH = "foreach"
@dataclass
class Node:
    """A named scope (let/procedure/…) in a SKILL document.

    Attributes:
        node: hierarchical scope name, e.g. "file.procedure_0".
        kind: which scope-introducing form produced this node.
        location: full source range of the scope.
        children: scopes strictly contained inside this one.
        symbols: symbols declared directly in this scope, keyed by name.
    """

    node: str
    kind: NodeKind
    location: Range
    children: list["Node"] = field(default_factory=list)
    symbols: dict[str, DocumentSymbol] = field(default_factory=dict)

    @property
    def all_symbols(self) -> list[DocumentSymbol]:
        """All symbols of this scope plus, recursively, those of its children."""
        return [
            *self.symbols.values(),
            *(sym for child in self.children for sym in child.all_symbols),
        ]

    def should_contain(self, other: "Node") -> bool:
        """Return True when *other*'s range lies strictly inside this node's range."""
        start_after = (other.location.start.line > self.location.start.line) or (
            (other.location.start.line == self.location.start.line)
            and (other.location.start.character > self.location.start.character)
        )
        # Bug fix: the same-line comparison previously used
        # self.location.START.character, so containment on a shared end line
        # was decided against the wrong column.
        ends_before = (other.location.end.line < self.location.end.line) or (
            (other.location.end.line == self.location.end.line)
            and (other.location.end.character < self.location.end.character)
        )
        return start_after and ends_before

    def add_child(self, new_child: "Node") -> None:
        """Insert *new_child* under the deepest descendant whose range contains it."""
        for existing_child in self.children:
            if existing_child.should_contain(new_child):
                existing_child.add_child(new_child)
                break
        else:
            self.children.append(new_child)

    def as_doc_symbol(self) -> DocumentSymbol:
        """Convert this scope (and its subtree) into an LSP DocumentSymbol."""
        return DocumentSymbol(
            name=self.node,
            kind=SymbolKind.Namespace,
            range=self.location,
            selection_range=self.location,
            children=list(self.symbols.values())
            + [child.as_doc_symbol() for child in self.children],
        )
@dataclass
class DocumentSymbols:
uri: str
tree: Node