skill-ls/skillls/main.py

448 lines
16 KiB
Python

from collections.abc import Generator
from dataclasses import dataclass, field
from logging import INFO, basicConfig, debug, error, getLogger, info, warning
from re import findall, finditer, fullmatch, match as rematch
from time import time
from lsprotocol.types import (
INLAY_HINT_RESOLVE,
TEXT_DOCUMENT_DID_CHANGE,
TEXT_DOCUMENT_DID_OPEN,
TEXT_DOCUMENT_DID_SAVE,
TEXT_DOCUMENT_DOCUMENT_SYMBOL,
TEXT_DOCUMENT_HOVER,
TEXT_DOCUMENT_INLAY_HINT,
WORKSPACE_INLAY_HINT_REFRESH,
WORKSPACE_SEMANTIC_TOKENS_REFRESH,
CompletionItem,
Diagnostic,
DiagnosticSeverity,
DidChangeTextDocumentParams,
DidOpenTextDocumentParams,
DidSaveTextDocumentParams,
DocumentSymbol,
DocumentSymbolParams,
Hover,
HoverParams,
InlayHint,
InlayHintKind,
InlayHintParams,
MessageType,
Position,
Range,
SymbolKind,
)
from pygls.server import LanguageServer
from pygls.workspace import TextDocument
from skillls.parsing.iterative import IterativeParser, TokenParser
from .cache import Cache
# Alias for ``str``; the name suggests it is meant to label document URIs.
URI = str

# Send log records of level INFO and above to ``skillls.log``.  ``filemode="w"``
# truncates the file every time the module is imported.
basicConfig(level=INFO, filemode="w", filename="skillls.log")

# Module-wide cache instance, typed as string keys to completion items.
cache: Cache[str, CompletionItem] = Cache()
def in_range(what: Position, area: Range) -> bool:
    """Return whether ``what`` lies inside ``area``, both ends inclusive.

    Positions are ordered by line first and by character second.  The
    comparison is carried out on plain ``(line, character)`` tuples so that it
    does not depend on which rich-comparison operators the position type
    happens to implement.

    Args:
        what: The position to test.
        area: The range whose ``start`` and ``end`` bound the test.

    Returns:
        ``True`` if ``area.start <= what <= area.end``.
    """
    point = (what.line, what.character)
    lower = (area.start.line, area.start.character)
    upper = (area.end.line, area.end.character)
    return lower <= point <= upper
def find_end(start: Position, lines: list[str]) -> Position:
    """Locate the ``)`` that closes the form beginning at ``start``.

    Scanning begins at ``start`` and counts parentheses that occur outside
    double-quoted strings.  The form ends at the ``)`` that brings the nesting
    depth back to zero; a ``)`` seen before any ``(`` is ignored.

    Args:
        start: Position at (or shortly before) the opening ``(`` of the form.
        lines: All lines of the document.

    Returns:
        The absolute position (document line, column within that line) of the
        closing ``)``.  If the form is never closed, an error is logged and
        the end of the last line is returned (``0:0`` for an empty document).
    """
    depth = 0
    in_str = False
    # True directly after an unescaped backslash inside a string, so that
    # ``\"`` does not end the string while ``\\"`` does.
    escaped = False

    for row, line in enumerate(lines[start.line :], start=start.line):
        # Only the first scanned line is entered part-way through; keeping the
        # offset in ``enumerate`` makes ``col`` an absolute column.
        first_col = start.character if row == start.line else 0
        for col, char in enumerate(line[first_col:], start=first_col):
            if in_str:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_str = False
            elif char == '"':
                in_str = True
            elif char == "(":
                depth += 1
            elif char == ")" and depth > 0:
                depth -= 1
                if depth == 0:
                    return Position(row, col)

    error(f"did not fin end for start at {start}")
    if not lines:
        return Position(0, 0)
    return Position(len(lines) - 1, len(lines[-1]))
@dataclass(frozen=True)
class Environment:
    """Immutable record holding the document range an environment spans."""

    # Span of the document covered by this environment.
    range: Range
@dataclass(frozen=True)
class LetEnvironment(Environment):
    """An :class:`Environment` that additionally carries a set of local names."""

    # Local names; every instance gets its own empty set unless one is given.
    locals: set[str] = field(default_factory=set)
#
# @dataclass(frozen=True)
# class ProcEnvironment(Environment):
# name: str
# args: tuple[DocumentSymbol, ...]
# kwargs: tuple[DocumentSymbol, ...]
# rest: DocumentSymbol | None = None
#
# @property
# def locals(self) -> tuple[DocumentSymbol, ...]:
# ret = [*self.args, *self.kwargs]
# if self.rest:
# ret.append(self.rest)
#
# return tuple(ret)
class SkillLanguageServer(LanguageServer):
    """Language server that extracts SKILL symbols with line-based regex scans.

    The symbol lists hold the result of the most recent :meth:`parse` call, so
    the server effectively tracks a single document at a time.  Nothing in
    this class fills ``defs``.
    """

    lets: list[DocumentSymbol]
    procs: list[DocumentSymbol]
    defs: list[DocumentSymbol]
    globals: list[DocumentSymbol]

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to ``LanguageServer`` and create empty state."""
        super().__init__(*args, **kwargs)
        # Created per instance: a mutable class-level default would be shared
        # by every server object.
        self.lets = []
        self.procs = []
        self.defs = []
        self.globals = []
        # Diagnostics last published by diagnose(), keyed by document URI, so
        # that later passes can add to them instead of replacing them.
        self._diagnostics: dict[str, list[Diagnostic]] = {}

    @property
    def envs(self) -> tuple[DocumentSymbol, ...]:
        """All scope-introducing symbols: procedures first, then lets."""
        return (*self.procs, *self.lets)

    def _diagnose_parens(self, doc: TextDocument) -> Generator[Diagnostic, None, None]:
        """Yield one diagnostic per unbalanced parenthesis in ``doc``.

        Parentheses inside double-quoted strings are skipped.  Comments are
        not recognised, so parentheses and quotes inside them are counted.
        Stray ``)`` are reported as they are met; ``(`` that are still open at
        the end of the document are reported afterwards.
        """
        unclosed: list[tuple[int, int]] = []
        in_str = False
        # True directly after an unescaped backslash inside a string, so that
        # ``\"`` does not end the string while ``\\"`` does.
        escaped = False

        for row, line in enumerate(doc.lines):
            for col, char in enumerate(line):
                if in_str:
                    if escaped:
                        escaped = False
                    elif char == "\\":
                        escaped = True
                    elif char == '"':
                        in_str = False
                    continue
                if char == '"':
                    in_str = True
                elif char == "(":
                    unclosed.append((row, col))
                elif char == ")":
                    if unclosed:
                        unclosed.pop()
                    else:
                        yield Diagnostic(
                            range=Range(Position(row, col), Position(row, col)),
                            message="unopened ) encountered",
                        )

        for row, col in unclosed:
            yield Diagnostic(
                range=Range(Position(row, col), Position(row, col)),
                message="unclosed ( encountered",
            )

    def _diagnose_cisms(self, doc: TextDocument) -> Generator[Diagnostic, None, None]:
        """Yield a hint for every C-style call ``name(`` found in ``doc``.

        Matches that begin with ``procedure`` plus whitespace, or with a
        ``;`` comment, are consumed by the optional ``proc`` group and are
        not reported.
        """
        for row, line in enumerate(doc.lines):
            for m in finditer(
                r"(?P<proc>procedure\s+|;.*)?([a-zA-Z_][a-zA-Z_0-9]*)\(", line
            ):
                if m.group("proc"):
                    continue
                name = m.group(2)
                yield Diagnostic(
                    range=Range(Position(row, m.start()), Position(row, m.end())),
                    message=f"change `{name}(` to `( {name}`",
                    severity=DiagnosticSeverity.Hint,
                )

    def diagnose(self, doc: TextDocument) -> None:
        """Compute and publish the parenthesis and C-ism diagnostics of ``doc``.

        The published list is remembered per URI so that :meth:`_parse_proc`
        can publish its own errors together with it.
        """
        diags: list[Diagnostic] = [
            *self._diagnose_parens(doc),
            *self._diagnose_cisms(doc),
        ]
        self._diagnostics[doc.uri] = diags
        self.publish_diagnostics(doc.uri, diags)

    def parse(self, doc: TextDocument) -> None:
        """Rebuild ``lets``, ``procs`` and ``globals`` from ``doc``.

        Lets are parsed before assignments because :meth:`_parse_assigns`
        consults ``self.lets``.
        """
        self.lets = []
        self._parse_let(doc.lines)
        self.procs = []
        self._parse_proc(doc.lines, doc.uri)
        self.globals = []
        self._parse_assigns(doc.lines)

    def _parse_assigns(self, lines: list[str]) -> None:
        """Append a global symbol for each assignment that is not a let local.

        An assignment is an identifier followed by ``=`` and at least one
        whitespace character.  It is skipped when it lies inside a let whose
        children include the assigned name.
        """
        for row, line in enumerate(lines):
            for found in finditer(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s+", line):
                token = found.group(1)
                token_range = Range(
                    Position(row, found.start()),
                    Position(row, found.start() + len(token)),
                )
                is_let_local = any(
                    in_range(token_range.start, let.range)
                    and token in (child.name for child in (let.children or []))
                    for let in self.lets
                )
                if is_let_local:
                    continue
                self.globals.append(
                    DocumentSymbol(
                        name=token,
                        kind=SymbolKind.Variable,
                        range=token_range,
                        selection_range=token_range,
                    )
                )

    def _parse_let(self, lines: list[str]) -> None:
        """Append one namespace symbol per ``let`` form, with its locals as children.

        Both ``(let (...)`` and ``let( (...)`` spellings are recognised.

        NOTE: the binding list is taken from a single line with a greedy
        pattern, so it must close on the line of the ``let`` and the captured
        text can run past the binding list when more ``)`` follow on that line.
        """
        for row, line in enumerate(lines):
            for found in finditer(r"(\(\s*let\s+|\blet\(\s+)\((.*)\)", line):
                start = Position(row, found.start())
                end = find_end(start, lines)
                children: list[DocumentSymbol] = []
                self.lets.append(
                    DocumentSymbol(
                        name="let",
                        kind=SymbolKind.Namespace,
                        range=Range(start, end),
                        selection_range=Range(start, end),
                        children=children,
                    )
                )

                # Column at which the captured binding text starts in ``line``.
                bindings_col = found.start(2)
                for local_var in finditer(
                    r"([a-zA-Z_][a-zA-Z0-9_]*|\([a-zA-Z_][a-zA-Z0-9_]*\s+.+\))",
                    found.group(2),
                ):
                    text = local_var.group(1)
                    first = bindings_col + local_var.start()
                    last = bindings_col + local_var.end()

                    if not text.startswith("("):
                        # Bare name: the whole match is the variable.
                        children.append(
                            DocumentSymbol(
                                name=text,
                                kind=SymbolKind.Variable,
                                range=Range(Position(row, first), Position(row, last)),
                                selection_range=Range(
                                    Position(row, first), Position(row, last)
                                ),
                            )
                        )
                        continue

                    # ``(name value)`` binding: the range covers the whole
                    # binding, the selection only the name after the ``(``.
                    m = fullmatch(r"\(([a-zA-Z_][a-zA-Z0-9_]*)\s+.+\)", text)
                    if m is None:
                        continue
                    name = m.group(1)
                    name_col = first + 1
                    children.append(
                        DocumentSymbol(
                            name=name,
                            kind=SymbolKind.Variable,
                            range=Range(Position(row, first), Position(row, last)),
                            selection_range=Range(
                                Position(row, name_col),
                                Position(row, name_col + len(name)),
                            ),
                        )
                    )

    def _parse_params(self, row: int, col: int, params: str) -> list[DocumentSymbol]:
        """Turn the text of a parameter list into variable symbols.

        Args:
            row: Line of the parameter list.
            col: Column at which ``params`` starts on that line.
            params: Text between the parentheses of the parameter list.

        Returns:
            Positional parameters, then the ``@rest`` parameter, then the
            parameters of ``@optional``/``@option``/``@key`` sections.
        """
        args: list[DocumentSymbol] = []
        kwargs: list[DocumentSymbol] = []
        rest: list[DocumentSymbol] = []

        for part in finditer(
            r"(@(?:optional|option|key)(?:\s\(\w+\s+.+\))+|@rest \w+|\w+\s*)",
            params,
        ):
            text = part.group(1)
            part_col = col + part.start()

            if text.startswith("@rest"):
                # The name is the last thing matched, so measure from the end.
                name = text.split()[1]
                name_end = col + part.end()
                name_range = Range(
                    Position(row, name_end - len(name)),
                    Position(row, name_end),
                )
                rest.append(
                    DocumentSymbol(
                        name=name,
                        kind=SymbolKind.Variable,
                        range=name_range,
                        selection_range=name_range,
                    )
                )
            elif text.startswith("@"):
                # Range: the whole ``(name default)`` pair; selection: the name.
                for kwarg in finditer(r"(\((\w+)\s+[^\)]+\))", text):
                    kwargs.append(
                        DocumentSymbol(
                            name=kwarg.group(2),
                            kind=SymbolKind.Variable,
                            range=Range(
                                Position(row, part_col + kwarg.start()),
                                Position(row, part_col + kwarg.end()),
                            ),
                            selection_range=Range(
                                Position(row, part_col + kwarg.start(2)),
                                Position(row, part_col + kwarg.end(2)),
                            ),
                        )
                    )
            else:
                for arg in finditer(r"(\w+)", text):
                    arg_range = Range(
                        Position(row, part_col + arg.start()),
                        Position(row, part_col + arg.end()),
                    )
                    args.append(
                        DocumentSymbol(
                            name=arg.group(1),
                            kind=SymbolKind.Variable,
                            range=arg_range,
                            selection_range=arg_range,
                        )
                    )

        return args + rest + kwargs

    def _parse_proc(self, lines: list[str], uri: str) -> None:
        """Append one function symbol per procedure definition found in ``lines``.

        Both ``(procedure name(...)`` and ``procedure( name(...)`` spellings
        are recognised; the parameter list must close on the same line.  A
        definition mixing ``@key`` with ``@option`` is reported as an error
        and skipped, while the remaining definitions are still parsed.  The
        errors are published together with the diagnostics remembered by
        :meth:`diagnose`, because publishing replaces whatever was previously
        shown for the document.
        """
        errors: list[Diagnostic] = []

        for row, line in enumerate(lines):
            for found in finditer(
                r"(\(\s*procedure|\bprocedure\()(\s+)([a-zA-Z_][a-zA-Z0-9_]*)\((.*)\)",
                line,
            ):
                start = Position(row, found.start())
                params = found.group(4)

                if "@option" in params and "@key" in params:
                    errors.append(
                        Diagnostic(
                            range=Range(start, Position(row, len(line))),
                            message="`@key` and `@option` used in same definition",
                            severity=DiagnosticSeverity.Error,
                        )
                    )
                    continue

                self.procs.append(
                    DocumentSymbol(
                        name=found.group(3),
                        kind=SymbolKind.Function,
                        range=Range(start, find_end(start, lines)),
                        # Select the procedure name so the selection stays
                        # inside the symbol's range.
                        selection_range=Range(
                            Position(row, found.start(3)),
                            Position(row, found.end(3)),
                        ),
                        children=self._parse_params(row, found.start(4), params),
                    )
                )

        if errors:
            self.publish_diagnostics(
                uri, [*self._diagnostics.get(uri, []), *errors]
            )

    def _hint_let(self) -> Generator[InlayHint, None, None]:
        """Yield a ``|l`` hint after the name of every let local."""
        for let in self.lets:
            for child in let.children or []:
                yield InlayHint(position=child.selection_range.end, label="|l")

    def _hint_proc(self) -> Generator[InlayHint, None, None]:
        """Yield a ``|l`` hint after the name of every procedure parameter."""
        for proc in self.procs:
            debug("inlay hints for procedure %s", proc.name)
            for child in proc.children or []:
                yield InlayHint(position=child.selection_range.end, label="|l")

    def _hint_globals(self) -> Generator[InlayHint, None, None]:
        """Yield a ``|g`` hint after every recorded global assignment."""
        for glbl in self.globals:
            yield InlayHint(position=glbl.selection_range.end, label="|g")

    def hint(self, doc: TextDocument, area: Range) -> list[InlayHint]:
        """Return the inlay hints for procedures, lets and globals, in that order.

        ``doc`` and ``area`` are currently unused: the hints come from the
        symbols of the last :meth:`parse` call and are not filtered by range.
        """
        return [
            *self._hint_proc(),
            *self._hint_let(),
            *self._hint_globals(),
        ]
# Module-level server instance; the ``@server.feature`` handlers in this module
# register on it.  The two arguments are presumably the server name and
# version — confirm against the pygls ``LanguageServer`` constructor.
server = SkillLanguageServer("skillls", "v0.3")
@server.feature(TEXT_DOCUMENT_DID_SAVE)
@server.feature(TEXT_DOCUMENT_DID_OPEN)
@server.feature(TEXT_DOCUMENT_DID_CHANGE)
def on_open(
    ls: SkillLanguageServer,
    params: DidOpenTextDocumentParams
    | DidChangeTextDocumentParams
    | DidSaveTextDocumentParams,
) -> None:
    """Re-diagnose and re-parse a document after it is opened, changed or saved.

    Args:
        ls: The server instance handling the notification.
        params: Notification parameters; only ``text_document.uri`` is read.
    """
    doc = ls.workspace.get_text_document(params.text_document.uri)
    # diagnose() returns None, so the earlier ``if not ls.diagnose(doc)`` guard
    # always passed; the calls are now made unconditionally and in the same
    # order.
    ls.diagnose(doc)
    ls.parse(doc)
    # Ask the client to request inlay hints again now that the symbols were
    # rebuilt.  The result of the request is not awaited here.
    ls.lsp.send_request_async(WORKSPACE_INLAY_HINT_REFRESH)
@server.feature(TEXT_DOCUMENT_INLAY_HINT)
def inlay_hints(ls: SkillLanguageServer, params: InlayHintParams) -> list[InlayHint]:
    """Return the inlay hints for the requested document and range.

    Args:
        ls: The server instance handling the request.
        params: Request parameters; ``text_document.uri`` and ``range`` are read.

    Returns:
        Whatever ``ls.hint`` produces for the document and range.
    """
    # Resolve the document through the instance passed to the handler instead
    # of the module-level ``server`` global.
    doc = ls.workspace.get_text_document(params.text_document.uri)
    return ls.hint(doc, params.range)
@server.feature(TEXT_DOCUMENT_DOCUMENT_SYMBOL)
def doc_symbols(
    ls: SkillLanguageServer,
    params: DocumentSymbolParams,
) -> list[DocumentSymbol]:
    """Return every symbol the server currently holds.

    The result is a new list with procedures first, followed by lets, defs
    and globals.  ``params`` is not consulted.
    """
    return [*ls.procs, *ls.lets, *ls.defs, *ls.globals]
def main() -> None:
    """Entry point: run the module-level server via its ``start_io`` method."""
    server.start_io()