"""Locate scope-introducing forms (see ``NodeKind``) and their locals in a document."""

from bisect import bisect_right
from collections import Counter
from copy import copy
from dataclasses import dataclass
from logging import DEBUG, getLogger
from pathlib import Path
from pprint import pformat
from re import MULTILINE, compile as recompile, finditer

from lsprotocol.types import DocumentSymbol, Position, Range, SymbolKind
from pygls.workspace import TextDocument

from skillls.checker import check_content_for_errors
from skillls.types import URI, Node, NodeKind


logger = getLogger(__name__)


@dataclass
class ParserCleanerState:
    """Mutable scanner state used by :func:`clean_content`."""

    # True from a ``;`` (outside a string) up to the next ``"\n"``.
    in_comment: bool = False
    # True between an opening ``"`` and the next unescaped ``"``.
    in_string: bool = False


NODE_KIND_OPTIONS = "|".join(k.value for k in NodeKind)

# Start of a scope-introducing form, in either call style:
#   ``(let ...``  -> group ``ctyp``; the match starts at the ``(``
#   ``let(...``   -> group ``typ``;  the match ends just after the ``(``
NAMESPACE_STARTERS = recompile(
    (rf"(\(\s*(?P<ctyp>{NODE_KIND_OPTIONS})\b|\b(?P<typ>{NODE_KIND_OPTIONS})\()"),
    MULTILINE,
)

# One entry of a scope's local/argument list: ``name``, ``(name default)`` or an
# ``@rest`` / ``@key`` / ``@optional``-style marker.  Entries must follow each
# other separated only by whitespace; parsing stops at the first gap.
_LOCAL_DECLARATION = recompile(
    r"(?P<leading>\s*)"
    r"(?P<local>(?P<marker>@\w+)|\w+|\((?P<inner>\w+)\b[^)]*\))"
    r"(?P<trailing>\s*)"
)

# Line terminators recognised when mapping offsets to line/character positions.
_LINE_BREAK = r"\r\n|\r|\n"


def _is_escaped(content: str, index: int) -> bool:
    """Return True if ``content[index]`` is preceded by an odd run of backslashes."""
    backslashes = 0
    while index - backslashes > 0 and content[index - backslashes - 1] == "\\":
        backslashes += 1
    return backslashes % 2 == 1


def clean_content(content: str) -> str:
    """Blank out comments and string contents so scope regexes only see code.

    * ``;`` comments are dropped up to (not including) the end of their line.
    * Characters inside ``"..."`` become spaces, except line breaks, which are
      kept so line numbers in the cleaned text match the original document.
      The quote characters themselves are kept.
    * A ``"`` preceded by an odd number of backslashes does not open or close
      a string.
    """
    state = ParserCleanerState()
    cleaned: list[str] = []
    for index, char in enumerate(content):
        if state.in_comment:
            # Everything up to the newline is discarded; the newline is kept.
            if char == "\n":
                state.in_comment = False
                cleaned.append(char)
        elif char == ";" and not state.in_string:
            state.in_comment = True
        elif char == '"':
            if not _is_escaped(content, index):
                state.in_string = not state.in_string
            cleaned.append(char)
        elif state.in_string:
            # Preserve line structure; hide everything else in the string.
            cleaned.append(char if char in "\r\n" else " ")
        else:
            cleaned.append(char)
    return "".join(cleaned)


def build_node_hierarchy(nodes: list[Node]) -> list[Node]:
    """Nest *nodes* under the first already-placed root that should contain them.

    Nodes are processed in the given order.  Each one is handed to
    ``add_child`` of the first root for which ``should_contain`` is true;
    otherwise it becomes a new root.  The input list is not modified.

    Returns:
        The list of root nodes.
    """
    roots: list[Node] = []
    for node in nodes:
        parent = next((root for root in roots if root.should_contain(node)), None)
        if parent is None:
            roots.append(node)
        else:
            parent.add_child(node)
    return roots


def _line_starts(content: str) -> list[int]:
    """Return the offset at which each line of *content* begins."""
    return [0, *(match.end() for match in finditer(_LINE_BREAK, content))]


def _position_at(line_starts: list[int], offset: int) -> Position:
    """Translate an absolute *offset* into a zero-based line/character Position.

    NOTE: characters are counted as Python code points; LSP defaults to UTF-16
    code units, so positions after non-BMP characters may differ.
    """
    line = bisect_right(line_starts, offset) - 1
    return Position(line=line, character=offset - line_starts[line])


def _span(line_starts: list[int], start: int, end: int) -> Range:
    """Build a Range for the half-open offset interval ``[start, end)``."""
    return Range(
        start=_position_at(line_starts, start),
        end=_position_at(line_starts, end),
    )


def _scope_end(content: str, start: int) -> int:
    """Return the offset just past the ``)`` that closes a scope.

    Scanning begins at *start* with one bracket already open.  If the brackets
    never balance, ``len(content)`` is returned.
    """
    depth = 1
    for offset in range(start, len(content)):
        char = content[offset]
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return offset + 1
    return len(content)


def _collect_locals(
    content: str,
    line_starts: list[int],
    node: Node,
    search_from: int,
    scope_end: int,
) -> None:
    """Register the names of the first parenthesised list in a scope on *node*.

    Supported shapes:
        function(pos1 pos2)
        function(pos1 (pos2 default))
        function(pos1 @rest args)
        function(pos1 @key (kwarg1 default1) (kwarg2 default2))
    ``@``-markers are skipped.  Parsing stops at the first token that is not a
    whitespace-separated entry (normally the list's closing ``)``).
    """
    list_open = content.find("(", search_from, scope_end)
    if list_open == -1:
        return  # no local list inside this scope

    expected = list_open + 1
    for entry in _LOCAL_DECLARATION.finditer(content, expected, scope_end):
        if entry.start() != expected:
            logger.debug(
                "found (%s), but last (%s) != (%s)", entry, expected, entry.start()
            )
            break
        expected = entry.end()
        if entry.group("marker"):
            continue

        # For ``(name default)`` only ``name`` is the declared local.
        name_group = "inner" if entry.group("inner") else "local"
        name = entry.group(name_group)
        name_start = entry.start(name_group)
        name_end = name_start + len(name)
        node.symbols[name] = DocumentSymbol(
            name=name,
            kind=SymbolKind.Variable,
            range=_span(line_starts, name_start, name_end),
            selection_range=_span(line_starts, name_start, name_end),
        )


def find_scopes(content_cleaned: str, scope_prefix: str = "") -> list[Node]:
    """Find every scope form in *content_cleaned* and nest them hierarchically.

    Args:
        content_cleaned: Document text already passed through ``clean_content``.
        scope_prefix: Prefix for generated node names, which have the form
            ``f"{scope_prefix}.{kind.value}_{n}"`` where ``n`` counts earlier
            scopes of the same kind.

    Returns:
        The root nodes produced by ``build_node_hierarchy``.  Each node's
        location spans from the start of its opening form to just past the
        matching ``)``.  The node's ``symbols`` contain its declared locals.
    """
    line_starts = _line_starts(content_cleaned)
    kind_counts: Counter[str] = Counter()
    nodes: list[Node] = []

    for found in NAMESPACE_STARTERS.finditer(content_cleaned):
        scope_end = _scope_end(content_cleaned, found.end())
        kind = NodeKind(found.group("typ") or found.group("ctyp"))
        node = Node(
            node=f"{scope_prefix}.{kind.value}_{kind_counts[kind.value]}",
            kind=kind,
            location=_span(line_starts, found.start(), scope_end),
        )
        kind_counts[kind.value] += 1
        nodes.append(node)

        _collect_locals(content_cleaned, line_starts, node, found.end(), scope_end)

        if logger.isEnabledFor(DEBUG):
            logger.debug(pformat(node))

    return build_node_hierarchy(nodes)


def parse_file(file: TextDocument) -> list[Node]:
    """Clean *file*'s source, run the error checker on it, and return its scopes.

    Node names are prefixed with the file's stem.
    """
    content_cleaned = clean_content(file.source)
    check_content_for_errors(content_cleaned)
    return find_scopes(content_cleaned, scope_prefix=Path(file.path).stem)