skill-ls/skillls/helpers.py

193 lines
5.9 KiB
Python

from copy import copy
from dataclasses import dataclass
from logging import getLogger
from pathlib import Path
from pprint import pformat
from lsprotocol.types import DocumentSymbol, Position, Range, SymbolKind
from re import MULTILINE, compile as recompile, finditer
from pygls.workspace import TextDocument
from skillls.checker import check_content_for_errors
from skillls.types import URI, Node, NodeKind
logger = getLogger(__name__)
@dataclass
class ParserCleanerState:
in_comment: bool = False
in_string: bool = False
# Alternation of every NodeKind value, spliced into the scope-opening pattern
# below.
NODE_KIND_OPTIONS = "|".join(member.value for member in NodeKind)

# Recognises the start of a scope in either call style:
#   group ``typ``:  "(" + optional whitespace + kind keyword, e.g. "(<kind>"
#   group ``ctyp``: kind keyword immediately followed by "(", e.g. "<kind>("
# NOTE(review): the kind values are inserted unescaped, so they are assumed to
# contain no regex metacharacters — confirm against NodeKind.
NAMESPACE_STARTERS = recompile(
    rf"(\(\s*(?P<typ>{NODE_KIND_OPTIONS})\b|\b(?P<ctyp>{NODE_KIND_OPTIONS})\()",
    flags=MULTILINE,
)
def clean_content(content: str) -> str:
    """Strip comments and blank out string contents, keeping code positions.

    The result is intended for structural scanning (bracket matching, keyword
    search): text that could contain misleading brackets or keywords is
    neutralised, while line numbers and the columns of all remaining code stay
    the same as in ``content``.

    Rules applied while scanning left to right:

    * A ``;`` outside a string starts a comment. The comment text is dropped;
      the newline that ends it is kept.
    * An unescaped ``"`` outside a comment opens or closes a string. The
      delimiting quotes are kept, every character between them is replaced
      by a space, except newlines, which are kept so that line numbers
      computed on the result still match the original text.
    * A ``"`` preceded by an odd number of consecutive backslashes is escaped
      and neither opens nor closes a string. Inside a string it is blanked
      like any other string character; outside a string it is copied as-is.

    Args:
        content: Raw source text.

    Returns:
        The cleaned text. An empty input yields an empty string.
    """
    cleaned: list[str] = []
    in_comment = False
    in_string = False
    # True when the previous character was a backslash that was not itself
    # escaped, i.e. the current character is escaped.
    escaped = False

    for char in content:
        if in_comment:
            # Comments run to the end of the line; only the newline survives.
            if char == "\n":
                in_comment = False
                cleaned.append(char)
            continue

        # Tracking the escape state explicitly (instead of peeking at the
        # previous character) handles "\\\\" correctly and cannot wrap around
        # to the end of the text for the very first character.
        is_escaped, escaped = escaped, (char == "\\" and not escaped)

        if char == '"' and not is_escaped:
            in_string = not in_string
            cleaned.append(char)
        elif in_string:
            cleaned.append("\n" if char == "\n" else " ")
        elif char == ";":
            in_comment = True
        else:
            cleaned.append(char)

    return "".join(cleaned)
def build_node_hierarchy(nodes: list[Node]) -> list[Node]:
    """Nest a flat list of nodes under the top-level nodes that contain them.

    Nodes are processed in the order given. Each node is offered to the
    top-level nodes collected so far, in order; the first one whose
    ``should_contain`` accepts it receives it via ``add_child``. A node that
    no existing top-level node accepts becomes a new top-level node.

    Only top-level nodes are consulted here, so any deeper nesting has to be
    handled by ``add_child`` itself (NOTE(review): confirm in ``Node``).

    Args:
        nodes: Flat list of nodes; the list itself is not modified. Containing
            nodes must come before the nodes they contain.

    Returns:
        The top-level nodes, in first-seen order.
    """
    roots: list[Node] = []
    for node in nodes:
        parent = next((root for root in roots if root.should_contain(node)), None)
        if parent is None:
            roots.append(node)
        else:
            parent.add_child(node)
    return roots
# One entry of a scope's local-variable list: a bare name or a parenthesised
# ``(name default...)`` form, with any surrounding whitespace.
_SCOPED_LOCAL = recompile(r"\s*(?:(?P<bare>\w+)|\((?P<defaulted>\w+)\b[^)]*\))\s*")


def _position_at(text: str, index: int) -> Position:
    """Convert an absolute character index in ``text`` to a zero-based Position."""
    line_start = text.rfind("\n", 0, index) + 1
    return Position(text.count("\n", 0, index), index - line_start)


def _span_of(text: str, start: int, length: int) -> Range:
    """Build the Range covering ``length`` characters of ``text`` from ``start``."""
    return Range(_position_at(text, start), _position_at(text, start + length))


def _scope_end(text: str, start: int) -> int:
    """Return the index just past the ``)`` closing a scope whose body starts at ``start``.

    One bracket is assumed to be open already when scanning begins. If the
    brackets never balance, the scope is taken to run to the end of ``text``.
    """
    depth = 1
    for index in range(start, len(text)):
        char = text[index]
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return index + 1
    return len(text)


def find_scopes(content_cleaned: str, scope_prefix: str = "") -> list[Node]:
    """Locate every scope in cleaned source and collect its declared locals.

    A scope starts wherever ``NAMESPACE_STARTERS`` matches and ends at the
    bracket that balances its opening one (or at the end of the text when the
    brackets do not balance). Each scope becomes a ``Node`` named
    ``"<scope_prefix>.<kind>_<n>"``, where ``n`` counts earlier scopes of the
    same kind. The resulting flat list is nested by ``build_node_hierarchy``.

    All positions are derived from absolute character indices, so they stay
    correct for scopes at offset 0, scopes that follow other code on the same
    line, and locals spread over several lines.

    Args:
        content_cleaned: Source text with comments and string contents already
            neutralised, so brackets inside them cannot disturb the scan.
        scope_prefix: Prefix for the generated node names.

    Returns:
        The top-level scope nodes as returned by ``build_node_hierarchy``.
    """
    scopes: list[Node] = []
    for found in NAMESPACE_STARTERS.finditer(content_cleaned):
        scope_end = _scope_end(content_cleaned, found.end())
        kind = NodeKind(found.group("typ") or found.group("ctyp"))
        ordinal = sum(1 for known in scopes if known.kind == kind)
        node = Node(
            node=f"{scope_prefix}.{kind.value}_{ordinal}",
            kind=kind,
            location=Range(
                _position_at(content_cleaned, found.start()),
                _position_at(content_cleaned, scope_end),
            ),
        )
        scopes.append(node)

        # allowed scoped locals syntax
        # function(pos1 pos2)
        # function(pos1 (pos2 default))
        # function(pos1 @rest args)
        # function(pos1 @key (kwarg1 default1) (kwarg2 default2))
        #
        # The locals list is the first bracket opened inside the scope. The
        # search is bounded by the scope end so that a scope without such a
        # list (e.g. while the user is still typing) neither raises nor picks
        # up brackets belonging to later code.
        list_open = content_cleaned.find("(", found.end(), scope_end)
        if list_open != -1:
            cursor = list_open + 1
            # Consume entries back to back; the first character that starts
            # no entry (")" or an "@rest"/"@key" marker) ends the collection.
            while entry := _SCOPED_LOCAL.match(content_cleaned, cursor, scope_end):
                group = "bare" if entry.group("bare") is not None else "defaulted"
                local_name = entry.group(group)
                name_start = entry.start(group)
                node.symbols[local_name] = DocumentSymbol(
                    name=local_name,
                    kind=SymbolKind.Variable,
                    range=_span_of(content_cleaned, name_start, len(local_name)),
                    selection_range=_span_of(
                        content_cleaned, name_start, len(local_name)
                    ),
                )
                cursor = entry.end()
        # other cases (locals following @rest / @key) are not collected yet
        logger.debug(pformat(node))
    return build_node_hierarchy(scopes)
def parse_file(file: TextDocument) -> list[Node]:
    """Parse a workspace document into its scope nodes.

    The document source is cleaned with ``clean_content``, handed to
    ``check_content_for_errors``, and then scanned by ``find_scopes``. The
    file name without its extension is used as the scope prefix.

    Args:
        file: Document whose ``source`` and ``path`` are read.

    Returns:
        The scope nodes produced by ``find_scopes``.
    """
    cleaned = clean_content(file.source)
    # The checker's return value is not used here.
    # NOTE(review): presumably it reports problems by raising — confirm in
    # skillls.checker.
    check_content_for_errors(cleaned)
    module_name = Path(file.path).stem
    return find_scopes(cleaned, scope_prefix=module_name)