skill-ls/skillls/main.py

483 lines
18 KiB
Python

from collections.abc import Generator
from dataclasses import dataclass, field
from itertools import chain
from logging import INFO, basicConfig, debug, error, getLogger, info, warning
from re import findall, finditer, fullmatch, match as rematch
from time import time
from lsprotocol.types import (
INLAY_HINT_RESOLVE,
TEXT_DOCUMENT_DID_CHANGE,
TEXT_DOCUMENT_DID_OPEN,
TEXT_DOCUMENT_DID_SAVE,
TEXT_DOCUMENT_DOCUMENT_SYMBOL,
TEXT_DOCUMENT_HOVER,
TEXT_DOCUMENT_INLAY_HINT,
WORKSPACE_INLAY_HINT_REFRESH,
WORKSPACE_SEMANTIC_TOKENS_REFRESH,
CompletionItem,
Diagnostic,
DiagnosticSeverity,
DidChangeTextDocumentParams,
DidOpenTextDocumentParams,
DidSaveTextDocumentParams,
DocumentSymbol,
DocumentSymbolParams,
Hover,
HoverParams,
InlayHint,
InlayHintKind,
InlayHintParams,
MessageType,
Position,
Range,
SymbolKind,
)
from pygls.server import LanguageServer
from pygls.workspace import TextDocument
from skillls.builtins.common import SkillDataType
from skillls.parsing.iterative import IterativeParser, TokenParser
from .cache import Cache
# Alias so that signatures can say "URI" where a document URI string is meant.
URI = str
# Route log records of level INFO and above to the file "skillls.log";
# filemode "w" makes logging start from an empty file on every run.
basicConfig(filename="skillls.log", filemode="w", level=INFO)
# Module-wide cache instance, typed as str keys -> CompletionItem values.
# NOTE(review): nothing in the visible part of this file reads or writes it.
cache: Cache[str, CompletionItem] = Cache()
def in_range(what: Position, area: Range) -> bool:
    """Return True if ``what`` lies inside ``area``, both ends inclusive.

    Positions are compared as ``(line, character)`` tuples, so only attribute
    access is needed and ``Position`` does not have to implement the ``>=``
    and ``<=`` operators itself.

    Args:
        what: The position to test.
        area: The range whose ``start`` and ``end`` bound the test.

    Returns:
        ``True`` when ``area.start <= what <= area.end`` in document order.
    """
    point = (what.line, what.character)
    lower = (area.start.line, area.start.character)
    upper = (area.end.line, area.end.character)
    return lower <= point <= upper
def find_end(start: Position, lines: list[str]) -> Position:
    """Locate the parenthesis that closes the form beginning at ``start``.

    Scanning starts at ``start`` and tracks the nesting depth of ``(`` and
    ``)`` that occur outside double-quoted strings. String state carries over
    from one line to the next.

    Args:
        start: Position at (or shortly before) the opening parenthesis.
        lines: All lines of the document; ``start.line`` indexes into it.

    Returns:
        The absolute position of the ``)`` that brings the nesting depth back
        to zero. If no such parenthesis exists, an error is logged and the end
        of the last line is returned (``0:0`` for an empty document).
    """
    depth = 0
    in_str = False
    escaped = False  # True directly after an unescaped backslash in a string
    for row, line in enumerate(lines[start.line :], start=start.line):
        # Only the first scanned line is entered part-way through; keeping the
        # column offset in ``enumerate`` makes every reported column absolute.
        first_col = start.character if row == start.line else 0
        for col, char in enumerate(line[first_col:], start=first_col):
            if in_str:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_str = False
            elif char == '"':
                in_str = True
            elif char == "(":
                depth += 1
            elif char == ")" and depth > 0:
                depth -= 1
                if depth == 0:
                    return Position(row, col)
    error(f"did not find end for start at {start}")
    if not lines:
        return Position(0, 0)
    return Position(len(lines) - 1, len(lines[-1]))
@dataclass(frozen=True)
class Environment:
    """Immutable record of a scope, identified by the range it spans."""

    # Extent of the scope inside the document.
    range: Range
@dataclass(frozen=True)
class LetEnvironment(Environment):
    """Environment that additionally carries a set of local names."""

    # Names local to this environment; every instance gets its own empty set.
    locals: set[str] = field(default_factory=set)
#
# @dataclass(frozen=True)
# class ProcEnvironment(Environment):
# name: str
# args: tuple[DocumentSymbol, ...]
# kwargs: tuple[DocumentSymbol, ...]
# rest: DocumentSymbol | None = None
#
# @property
# def locals(self) -> tuple[DocumentSymbol, ...]:
# ret = [*self.args, *self.kwargs]
# if self.rest:
# ret.append(self.rest)
#
# return tuple(ret)
class SkillLanguageServer(LanguageServer):
    """Language server that extracts symbols, diagnostics and inlay hints
    from SKILL sources using line-based regular expressions.

    ``parse`` rebuilds ``lets``, ``procs`` and ``globals`` for one document;
    ``diagnose`` publishes parenthesis and call-style diagnostics; ``hint``
    turns the parsed symbols into inlay hints.

    NOTE(review): ``_parse_proc`` publishes its own diagnostics for the same
    URI that ``diagnose`` publishes to, so whichever call runs last determines
    what is sent for that document.
    """

    # Symbols found by the most recent ``parse`` call.
    lets: list[DocumentSymbol]
    procs: list[DocumentSymbol]
    globals: list[DocumentSymbol]
    # Not populated by any method of this class.
    defs: list[DocumentSymbol]

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to ``LanguageServer`` and start with empty,
        per-instance symbol lists (instead of lists shared via the class)."""
        super().__init__(*args, **kwargs)
        self.lets = []
        self.procs = []
        self.defs = []
        self.globals = []

    @property
    def envs(self) -> tuple[DocumentSymbol, ...]:
        """All scope-introducing symbols: procedures first, then lets."""
        return (
            *self.procs,
            *self.lets,
        )

    @staticmethod
    def _span(row: int, start_col: int, end_col: int) -> Range:
        """Build a range on line ``row`` from ``start_col`` to ``end_col``."""
        return Range(Position(row, start_col), Position(row, end_col))

    def _diagnose_parens(self, doc: TextDocument) -> Generator[Diagnostic, None, None]:
        """Yield a diagnostic for every unbalanced parenthesis in ``doc``.

        Parentheses inside double-quoted strings are ignored. A ``)`` without
        a matching ``(`` is reported where it occurs; every ``(`` still open
        at the end of the document is reported afterwards.
        """
        unclosed: list[tuple[int, int]] = []
        in_str = False
        escaped = False  # True directly after an unescaped backslash in a string
        for row, line in enumerate(doc.lines):
            for col, char in enumerate(line):
                if in_str:
                    if escaped:
                        escaped = False
                    elif char == "\\":
                        escaped = True
                    elif char == '"':
                        in_str = False
                elif char == '"':
                    in_str = True
                elif char == "(":
                    unclosed.append((row, col))
                elif char == ")":
                    if unclosed:
                        unclosed.pop()
                    else:
                        yield Diagnostic(
                            self._span(row, col, col + 1),
                            "unopened ) encountered",
                        )
        for row, col in unclosed:
            yield Diagnostic(
                self._span(row, col, col + 1),
                "unclosed ( encountered",
            )

    def _diagnose_cisms(self, doc: TextDocument) -> Generator[Diagnostic, None, None]:
        """Yield a hint for every C-style call ``name(`` in ``doc``.

        Matches that are preceded by ``procedure`` or that sit behind a ``;``
        on the same line are captured by the ``proc`` group and skipped.
        """
        pattern = r"(?P<proc>procedure\s+|;.*)?([a-zA-Z_][a-zA-Z_0-9]*)\("
        for row, line in enumerate(doc.lines):
            for m in finditer(pattern, line):
                if m.group("proc"):
                    continue
                name = m.group(2)
                yield Diagnostic(
                    self._span(row, m.start(), m.end()),
                    f"change `{name}(` to `( {name}`",
                    DiagnosticSeverity.Hint,
                )

    def diagnose(self, doc: TextDocument) -> None:
        """Publish the parenthesis and call-style diagnostics for ``doc``."""
        self.publish_diagnostics(
            doc.uri,
            [*self._diagnose_parens(doc), *self._diagnose_cisms(doc)],
        )

    def parse(self, doc: TextDocument) -> None:
        """Rebuild ``lets``, ``procs`` and ``globals`` from ``doc``.

        Globals are parsed last because an assignment only counts as global
        when it is not a local of an enclosing let or procedure.
        """
        self.lets = []
        self._parse_let(doc.lines)
        self.procs = []
        self._parse_proc(doc.lines, doc.uri)
        self.globals = []
        self._parse_assigns(doc.lines)

    def _parse_assigns(self, lines: list[str]) -> None:
        """Record every assigned name that is not a local of an enclosing
        let or procedure as a global variable symbol."""
        pattern = r"\b([a-zA-Z_][a-zA-Z0-9_]*)((-|~)>[a-zA-Z_][a-zA-Z0-9_]*)?\s*=\s+"
        for row, line in enumerate(lines):
            for found in finditer(pattern, line):
                token = found.group(1)
                token_range = self._span(row, found.start(1), found.end(1))
                shadowed = any(
                    in_range(token_range.start, ns.range)
                    and any(child.name == token for child in ns.children or [])
                    for ns in chain(self.lets, self.procs)
                )
                if not shadowed:
                    self.globals.append(
                        DocumentSymbol(
                            token, SymbolKind.Variable, token_range, token_range
                        )
                    )

    def _parse_let(self, lines: list[str]) -> None:
        """Collect every ``let`` form and its local variables.

        Only bindings written on the same line as the ``let`` keyword are
        recognised. Each local is either a bare name or ``(name value)``.
        """
        local_pattern = r"([a-zA-Z_][a-zA-Z0-9_]*|\([a-zA-Z_][a-zA-Z0-9_]*\s+.+\))"
        for row, line in enumerate(lines):
            for found in finditer(r"(\(\s*let\s+|\blet\(\s+)\((.*)\)", line):
                start = Position(row, found.start())
                end = find_end(start, lines)
                children: list[DocumentSymbol] = []
                self.lets.append(
                    DocumentSymbol(
                        "let",
                        SymbolKind.Namespace,
                        Range(start, end),
                        Range(start, end),
                        children=children,
                    )
                )
                # Column at which the binding list text (group 2) begins; all
                # matches below are relative to that text.
                offset = found.start(2)
                for local_var in finditer(local_pattern, found.group(2)):
                    text = local_var.group(1)
                    var_start = offset + local_var.start()
                    var_end = offset + local_var.end()
                    if text.startswith("("):
                        m = fullmatch(r"\(([a-zA-Z_][a-zA-Z0-9_]*)\s+.+\)", text)
                        if not m:
                            continue
                        name = m.group(1)
                        # The name begins right after the opening parenthesis.
                        name_start = var_start + 1
                        name_end = name_start + len(name)
                    else:
                        name = text
                        name_start, name_end = var_start, var_end
                    children.append(
                        DocumentSymbol(
                            name,
                            SymbolKind.Variable,
                            self._span(row, var_start, var_end),
                            self._span(row, name_start, name_end),
                        )
                    )

    def _parse_proc(self, lines: list[str], uri: str) -> None:
        """Collect every ``procedure`` definition and its parameters.

        Positional parameters, ``@optional``/``@key`` parameters written as
        ``(name default)`` and the ``@rest`` parameter become children of the
        procedure symbol. A trailing type string such as ``"xt"`` sets the
        ``detail`` of each parameter. Parsing stops, after publishing an error
        diagnostic, when a definition mixes ``@optional`` with ``@key`` or when
        the type string length differs from the number of parameters.
        """
        type_chars = "".join(dt.value for dt in SkillDataType)
        type_string = rf'"[{type_chars}]+"'
        param_pattern = (
            r"(@(option(?:al)?|key)(\s\(\w+\s+.+\))+"
            r"|@rest \w+"
            rf"|{type_string}"
            r"|(\w+\s*))"
        )
        for row, line in enumerate(lines):
            for found in finditer(
                r"(\(\s*procedure|\bprocedure\()(\s+)([a-zA-Z_][a-zA-Z0-9_]*)\((.*)\)",
                line,
            ):
                start = Position(row, found.start())
                end = find_end(start, lines)
                params = found.group(4)
                # "@option" is a prefix of "@optional", so both spellings hit.
                if "@option" in params and "@key" in params:
                    self.publish_diagnostics(
                        uri,
                        [
                            Diagnostic(
                                Range(start, Position(row, len(line))),
                                "`@key` and `@optional` used in same definition",
                                severity=DiagnosticSeverity.Error,
                            )
                        ],
                    )
                    return
                args: list[DocumentSymbol] = []
                kwargs: list[DocumentSymbol] = []
                rest: list[DocumentSymbol] = []
                # Column at which the parameter text (group 4) begins.
                params_start = found.start(4)
                debug("procedure parameters: %s", params)
                for part in finditer(param_pattern, params):
                    text = part.group(1)
                    part_start = params_start + part.start()
                    debug("procedure parameter part: %s", text)
                    if text.startswith("@rest"):
                        rest_name = text.split()[1]
                        # The match ends exactly where the name ends.
                        rest_end = params_start + part.end()
                        rest_range = self._span(
                            row, rest_end - len(rest_name), rest_end
                        )
                        rest.append(
                            DocumentSymbol(
                                rest_name,
                                kind=SymbolKind.Variable,
                                range=rest_range,
                                selection_range=rest_range,
                            )
                        )
                    elif text.startswith("@"):
                        for kwarg in finditer(r"(\((\w+)\s+[^\)]+\))", text):
                            kwargs.append(
                                DocumentSymbol(
                                    kwarg.group(2),
                                    kind=SymbolKind.Variable,
                                    # Whole "(name default)" group.
                                    range=self._span(
                                        row,
                                        part_start + kwarg.start(),
                                        part_start + kwarg.end(),
                                    ),
                                    # Just the name inside the group.
                                    selection_range=self._span(
                                        row,
                                        part_start + kwarg.start(2),
                                        part_start + kwarg.end(2),
                                    ),
                                )
                            )
                    elif fullmatch(type_string, text):
                        # One type character per parameter, quotes excluded.
                        if len(args) + len(kwargs) + len(rest) != len(text) - 2:
                            self.publish_diagnostics(
                                uri,
                                [
                                    Diagnostic(
                                        Range(start, Position(row, len(line))),
                                        "type info length mismatches number of arguments",
                                        severity=DiagnosticSeverity.Error,
                                    )
                                ],
                            )
                            return
                        for char, param in zip(text[1:-1], chain(args, rest, kwargs)):
                            param.detail = f"{SkillDataType(char).value}_"
                        # Nothing after the type string is treated as a parameter.
                        break
                    else:
                        for arg in finditer(r"(\w+)", text):
                            arg_range = self._span(
                                row,
                                part_start + arg.start(),
                                part_start + arg.end(),
                            )
                            args.append(
                                DocumentSymbol(
                                    arg.group(1),
                                    kind=SymbolKind.Variable,
                                    range=arg_range,
                                    selection_range=arg_range,
                                )
                            )
                self.procs.append(
                    DocumentSymbol(
                        found.group(3),
                        kind=SymbolKind.Function,
                        range=Range(start, end),
                        # Select the procedure name, which lies inside ``range``.
                        selection_range=self._span(
                            row, found.start(3), found.end(3)
                        ),
                        children=args + rest + kwargs,
                    )
                )

    def _hint_let(self) -> Generator[InlayHint, None, None]:
        """Yield a ``|l`` hint after every let-local variable name."""
        for let in self.lets:
            for child in let.children or []:
                yield InlayHint(child.selection_range.end, "|l")

    def _hint_proc(self) -> Generator[InlayHint, None, None]:
        """Yield a ``|l`` hint after every procedure parameter name and, when
        the parameter has a ``detail``, that detail in front of the name."""
        for proc in self.procs:
            debug("%s", proc)
            for child in proc.children or []:
                yield InlayHint(child.selection_range.end, "|l")
                if child.detail:
                    yield InlayHint(child.selection_range.start, child.detail)

    def _hint_globals(self) -> Generator[InlayHint, None, None]:
        """Yield a ``|g`` hint after every global variable name."""
        for glbl in self.globals:
            yield InlayHint(glbl.selection_range.end, "|g")

    def hint(self, doc: TextDocument, area: Range) -> list[InlayHint]:
        """Return the inlay hints of the parsed symbols that lie in ``area``.

        Args:
            doc: The document the hints are requested for (not consulted; the
                hints come from the symbols stored by the last ``parse``).
            area: Range the client asked hints for; both ends are inclusive.

        Returns:
            Procedure, let and global hints, in that order.
        """
        # Compared as (line, character) tuples so that only attribute access
        # on the position objects is needed.
        lower = (area.start.line, area.start.character)
        upper = (area.end.line, area.end.character)
        return [
            found
            for found in chain(
                self._hint_proc(), self._hint_let(), self._hint_globals()
            )
            if lower <= (found.position.line, found.position.character) <= upper
        ]
# The single server instance; the feature handlers below are registered on it
# and ``main`` starts it. It is created with name "skillls" and version "v0.3".
server = SkillLanguageServer("skillls", "v0.3")
@server.feature(TEXT_DOCUMENT_DID_SAVE)
@server.feature(TEXT_DOCUMENT_DID_OPEN)
@server.feature(TEXT_DOCUMENT_DID_CHANGE)
def on_open(
    ls: SkillLanguageServer,
    params: (
        DidOpenTextDocumentParams
        | DidChangeTextDocumentParams
        | DidSaveTextDocumentParams
    ),
) -> None:
    """Re-diagnose and re-parse a document after it was opened, changed or
    saved, then ask the client to refresh its inlay hints.

    Args:
        ls: The server instance handling the notification.
        params: Notification parameters; only ``text_document.uri`` is read.
    """
    doc = ls.workspace.get_text_document(params.text_document.uri)
    # ``diagnose`` returns nothing, so parsing always follows it.
    ls.diagnose(doc)
    ls.parse(doc)
    ls.lsp.send_request_async(WORKSPACE_INLAY_HINT_REFRESH)
@server.feature(TEXT_DOCUMENT_INLAY_HINT)
def inlay_hints(ls: SkillLanguageServer, params: InlayHintParams) -> list[InlayHint]:
    """Answer an inlay-hint request with the hints of the requested document.

    The document is looked up by URI in the server workspace and handed to
    ``ls.hint`` together with the requested range.
    """
    uri = params.text_document.uri
    return ls.hint(server.workspace.get_text_document(uri), params.range)
@server.feature(TEXT_DOCUMENT_DOCUMENT_SYMBOL)
def doc_symbols(
    ls: SkillLanguageServer,
    params: DocumentSymbolParams,
) -> list[DocumentSymbol]:
    """Answer a document-symbol request with every symbol the server holds.

    ``params`` is not consulted. The result is a new list containing the
    procedures, lets, defs and globals, in that order.
    """
    return [*ls.procs, *ls.lets, *ls.defs, *ls.globals]
def main() -> None:
    """Entry point: run the module-level ``server`` via ``start_io``."""
    server.start_io()