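"""Source parsing for skillls.

Cleans Lisp-style source (``;`` comments, double-quoted strings), finds
scope-opening forms for every ``NodeKind``, records the locals declared in
their parameter lists as LSP ``DocumentSymbol`` entries, and nests the
resulting ``Node`` objects into a hierarchy.
"""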
from copy import copy
from dataclasses import dataclass
from logging import getLogger
from pathlib import Path
from pprint import pformat
from re import MULTILINE, compile as recompile, finditer

from lsprotocol.types import DocumentSymbol, Position, Range, SymbolKind
from pygls.workspace import TextDocument

from skillls.checker import check_content_for_errors
from skillls.types import URI, Node, NodeKind

logger = getLogger(__name__)


@dataclass
class ParserCleanerState:
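    """Tracks whether the cleaner is currently inside a comment or a string literal."""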
    in_comment: bool = False
    in_string: bool = False


# A scope opener is written either as "(<kind> ..." (captured as "typ") or as
# "<kind>(..." (captured as "ctyp"), where <kind> is any NodeKind value.
NODE_KIND_OPTIONS = "|".join(k.value for k in NodeKind)
NAMESPACE_STARTERS = recompile(
    rf"(\(\s*(?P<typ>{NODE_KIND_OPTIONS})\b|\b(?P<ctyp>{NODE_KIND_OPTIONS})\()",
    MULTILINE,
)


def clean_content(content: str) -> str:
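    """Return ``content`` with comments and string contents blanked out.

    Characters inside string literals are replaced by spaces (keeping column
    offsets stable), comment text from ``;`` to the end of the line is dropped,
    and the newline that terminates a comment is preserved.
    """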
    content_cleaned = ""
    state = ParserCleanerState()

    for cix, char in enumerate(content):
        match (char, state):
            case ";", ParserCleanerState(in_comment=False, in_string=False):
                state.in_comment = True
            case '"', ParserCleanerState(in_comment=False):
                # An unescaped quote toggles string mode; the quote itself is kept.
                if content[cix - 1] != "\\":
                    state.in_string = not state.in_string
                content_cleaned += char
            case "\n", ParserCleanerState(in_comment=True):
                state.in_comment = False
                content_cleaned += char
            case _, ParserCleanerState(in_comment=False, in_string=False):
                content_cleaned += char
            case _, ParserCleanerState(in_comment=False, in_string=True):
                # Blank out string contents but keep their width.
                content_cleaned += " "
            case _:
                # Characters inside comments are dropped.
                pass

    return content_cleaned


def build_node_hierarchy(nodes: list[Node]) -> list[Node]:
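    """Nest each node under the first already-placed node that claims it.

    A node that no placed node claims (via ``Node.should_contain``) becomes a
    top-level entry; otherwise it is attached with ``Node.add_child``.
    """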
    to_be_sorted = copy(nodes)
    sorted: list[Node] = []

    while to_be_sorted:
        node_to_sort = to_be_sorted.pop(0)

        for sorted_node in sorted:
            if sorted_node.should_contain(node_to_sort):
                sorted_node.add_child(node_to_sort)
                break
        else:
            # No placed node contains this one: keep it at the top level.
            sorted.append(node_to_sort)

    return sorted


def find_scopes(content_cleaned: str, scope_prefix: str = "") -> list[Node]:
    ret: list[Node] = []

    for found in NAMESPACE_STARTERS.finditer(content_cleaned):
        # Find the closing bracket of this scope by counting parentheses.
        partial = content_cleaned[found.end() :]
        open_brackets = 1
        offset = 0
        for offset, char in enumerate(partial):
            match char:
                case "(":
                    open_brackets += 1
                case ")":
                    open_brackets -= 1

                    if open_brackets == 0:
                        break

                case _:
                    pass

        # Translate character offsets into (line, column) positions for the Range.
        pre_lines = content_cleaned[: found.start()].splitlines()
        start_line = len(pre_lines) - (
            1 if pre_lines[-1] != "" and pre_lines[-1].strip() == "" else 0
        )
        start_char = len(pre_lines[-1])

        inner_lines = content_cleaned[
            found.start() : found.end() + offset + 1
        ].splitlines()
        end_line = start_line + len(inner_lines) - 1
        end_char = len(inner_lines[-1])

        kind = NodeKind(found.group("typ") or found.group("ctyp"))
        loc = Range(Position(start_line, start_char), Position(end_line, end_char))

        node = Node(
            node=f"{scope_prefix}.{kind.value}_{len([n for n in ret if n.kind == kind])}",
            kind=kind,
            location=loc,
        )
        ret.append(node)

        next = found.end()

        # allowed scoped locals syntax
        # function(pos1 pos2)
        # function(pos1 (pos2 default))
        # function(pos1 @rest args)
        # function(pos1 @key (kwarg1 default1) (kwarg2 default2))

        # Advance to the opening bracket of the parameter list, tracking position.
        while content_cleaned[next] != "(":
            if content_cleaned[next] == "\n":
                start_line += 1
                start_char = 0
            next += 1
            start_char += 1

        next += 1
        last = 0

        # Consume locals left to right; stop at the first stretch the pattern
        # cannot account for (such as a @rest or @key marker).
        for positional in finditer(
            r"(?P<leading>\s*)(?P<local>\w+|\(\w+\b[^)]*\))(?P<trailing>\s*)",
            content_cleaned[next:],
        ):
            if positional.start() != last:
                logger.debug(
                    f"found ({positional}), but last ({last}) != ({positional.start()})"
                )
                break

            last = positional.end()

            leading_nls = positional.group("leading").count("\n")
            inner_nls = positional.group("local").count("\n")
            trailing_nls = positional.group("trailing").count("\n")

            local_name = positional.group("local").split()[0]
            local_range = Range(
                Position(
                    start_line + leading_nls,
                    len(positional.group("leading")) + start_char,
                ),
                Position(
                    start_line + leading_nls,
                    len(positional.group("leading")) + start_char + len(local_name),
                ),
            )
            local = DocumentSymbol(
                name=local_name,
                kind=SymbolKind.Variable,
                range=local_range,
                selection_range=local_range,
            )
            node.symbols[local_name] = local

            start_line += leading_nls + inner_nls + trailing_nls
            start_char += len(positional.group(0))

        # other cases

        logger.debug(pformat(node))

    return build_node_hierarchy(ret)


def parse_file(file: TextDocument) -> list[Node]:
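    """Clean, error-check, and parse ``file``, using the file stem as the scope prefix."""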
    content = file.source
    content_cleaned = clean_content(content)

    check_content_for_errors(content_cleaned)

    return find_scopes(content_cleaned, scope_prefix=Path(file.path).stem)
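

# A minimal usage sketch (assumptions: pygls' TextDocument can be constructed
# directly from a URI and source string, and "procedure" is one of the NodeKind
# values; adjust both to the real API and enum members):
#
#     doc = TextDocument(uri="file:///tmp/example.il", source="procedure( foo(a b)\n    a\n)\n")
#     for scope in parse_file(doc):
#         logger.debug("%s %s %s", scope.node, scope.kind, scope.location)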