[gemma4] refactor using treesitter
This commit is contained in:
@@ -1,72 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from lsprotocol.types import Position, Range
|
||||
|
||||
|
||||
|
||||
class SyntaxError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ParenMismatchErrorKind(Enum):
|
||||
TooManyClosed = "Found too many closing parens"
|
||||
TooManyOpened = "Found too many open parens"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParenMismatchError(SyntaxError):
|
||||
kind: ParenMismatchErrorKind
|
||||
loc: Range
|
||||
|
||||
|
||||
def _check_for_matching_parens(content: str) -> list[Exception]:
|
||||
excs: list[Exception] = []
|
||||
|
||||
opened = 0
|
||||
line = 0
|
||||
col = 0
|
||||
last_open: Position = Position(0, 0)
|
||||
for char in content:
|
||||
match char:
|
||||
case "(":
|
||||
opened += 1
|
||||
last_open = Position(line, col)
|
||||
|
||||
case ")":
|
||||
opened -= 1
|
||||
if opened < 0:
|
||||
excs.append(
|
||||
ParenMismatchError(
|
||||
ParenMismatchErrorKind.TooManyClosed,
|
||||
Range(Position(line, col), Position(line, col + 1)),
|
||||
)
|
||||
)
|
||||
opened = 0
|
||||
case "\n":
|
||||
line += 1
|
||||
col = -1
|
||||
|
||||
case _:
|
||||
pass
|
||||
|
||||
col += 1
|
||||
|
||||
if opened > 0:
|
||||
excs.append(
|
||||
ParenMismatchError(
|
||||
ParenMismatchErrorKind.TooManyOpened,
|
||||
Range(last_open, Position(last_open.line, last_open.character + 1)),
|
||||
)
|
||||
)
|
||||
|
||||
return excs
|
||||
|
||||
|
||||
def check_content_for_errors(clean_content: str) -> None:
|
||||
excs: list[Exception] = []
|
||||
|
||||
excs.extend(_check_for_matching_parens(clean_content))
|
||||
|
||||
if excs:
|
||||
raise ExceptionGroup("", excs)
|
||||
@@ -1,192 +0,0 @@
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from lsprotocol.types import DocumentSymbol, Position, Range, SymbolKind
|
||||
from re import MULTILINE, compile as recompile, finditer
|
||||
|
||||
from pygls.workspace import TextDocument
|
||||
|
||||
from skillls.checker import check_content_for_errors
|
||||
from skillls.types import Node, NodeKind
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParserCleanerState:
|
||||
in_comment: bool = False
|
||||
in_string: bool = False
|
||||
|
||||
|
||||
NODE_KIND_OPTIONS = "|".join(k.value for k in NodeKind)
|
||||
NAMESPACE_STARTERS = recompile(
|
||||
(rf"(\(\s*(?P<typ>{NODE_KIND_OPTIONS})\b|\b(?P<ctyp>{NODE_KIND_OPTIONS})\()"),
|
||||
MULTILINE,
|
||||
)
|
||||
|
||||
|
||||
def clean_content(content: str) -> str:
|
||||
content_cleaned = ""
|
||||
state = ParserCleanerState()
|
||||
|
||||
for cix, char in enumerate(content):
|
||||
match (content[cix], state):
|
||||
case ";", ParserCleanerState(in_comment=False, in_string=False):
|
||||
state.in_comment = True
|
||||
case '"', ParserCleanerState(in_comment=False):
|
||||
if content[cix - 1] != "\\":
|
||||
state.in_string = not state.in_string
|
||||
content_cleaned += char
|
||||
case "\n", ParserCleanerState(in_comment=True):
|
||||
state.in_comment = False
|
||||
content_cleaned += char
|
||||
case _, ParserCleanerState(in_comment=False, in_string=False):
|
||||
content_cleaned += char
|
||||
|
||||
case _, ParserCleanerState(in_comment=False, in_string=True):
|
||||
content_cleaned += " "
|
||||
case _:
|
||||
pass
|
||||
|
||||
return content_cleaned
|
||||
|
||||
|
||||
def build_node_hierarchy(nodes: list[Node]) -> list[Node]:
|
||||
to_be_sorted = copy(nodes)
|
||||
sorted: list[Node] = []
|
||||
|
||||
while to_be_sorted:
|
||||
node_to_sort = to_be_sorted.pop(0)
|
||||
|
||||
for sorted_node in sorted:
|
||||
if sorted_node.should_contain(node_to_sort):
|
||||
sorted_node.add_child(node_to_sort)
|
||||
break
|
||||
|
||||
else:
|
||||
sorted.append(node_to_sort)
|
||||
|
||||
return sorted
|
||||
|
||||
|
||||
def find_scopes(content_cleaned: str, scope_prefix: str = "") -> list[Node]:
|
||||
ret: list[Node] = []
|
||||
|
||||
for found in NAMESPACE_STARTERS.finditer(content_cleaned):
|
||||
partial = content_cleaned[found.end() :]
|
||||
open_brackets = 1
|
||||
offset = 0
|
||||
for offset, char in enumerate(partial):
|
||||
match char:
|
||||
case "(":
|
||||
open_brackets += 1
|
||||
case ")":
|
||||
open_brackets -= 1
|
||||
|
||||
if open_brackets == 0:
|
||||
break
|
||||
|
||||
case _:
|
||||
pass
|
||||
|
||||
pre_lines = content_cleaned[: found.start()].splitlines()
|
||||
start_line = len(pre_lines) - (
|
||||
1 if pre_lines[-1] != "" and pre_lines[-1].strip() == "" else 0
|
||||
)
|
||||
start_char = len(pre_lines[-1])
|
||||
|
||||
inner_lines = content_cleaned[
|
||||
found.start() : found.end() + offset + 1
|
||||
].splitlines()
|
||||
end_line = start_line + len(inner_lines) - 1
|
||||
end_char = len(inner_lines[-1])
|
||||
|
||||
kind = NodeKind(found.group("typ") or found.group("ctyp"))
|
||||
loc = Range(Position(start_line, start_char), Position(end_line, end_char))
|
||||
|
||||
node = Node(
|
||||
node=f"{scope_prefix}.{kind.value}_{len([n for n in ret if n.kind == kind])}",
|
||||
kind=kind,
|
||||
location=loc,
|
||||
)
|
||||
ret.append(node)
|
||||
|
||||
next = found.end()
|
||||
|
||||
# allowed scoped locals syntax
|
||||
# function(pos1 pos2)
|
||||
# function(pos1 (pos2 default))
|
||||
# function(pos1 @rest args)
|
||||
# function(pos1 @key (kwarg1 default1) (kwarg2 default2))
|
||||
|
||||
while content_cleaned[next] != "(":
|
||||
if content_cleaned[next] == "\n":
|
||||
start_line += 1
|
||||
start_char = 0
|
||||
next += 1
|
||||
start_char += 1
|
||||
|
||||
next += 1
|
||||
last = 0
|
||||
|
||||
for positional in finditer(
|
||||
r"(?P<leading>\s*)(?P<local>\w+|\(\w+\b[^)]*\))(?P<trailing>\s*)",
|
||||
content_cleaned[next:],
|
||||
):
|
||||
if positional.start() != last:
|
||||
logger.debug(
|
||||
f"found ({positional}), but last ({last}) != ({positional.start()})"
|
||||
)
|
||||
break
|
||||
|
||||
last = positional.end()
|
||||
|
||||
leading_nls = positional.group("leading").count("\n")
|
||||
inner_nls = positional.group("local").count("\n")
|
||||
trailing_nls = positional.group("trailing").count("\n")
|
||||
|
||||
local_name = positional.group("local").split()[0]
|
||||
local = DocumentSymbol(
|
||||
name=local_name,
|
||||
kind=SymbolKind.Variable,
|
||||
range=Range(
|
||||
Position(
|
||||
start_line + leading_nls,
|
||||
len(positional.group("leading")) + start_char,
|
||||
),
|
||||
Position(
|
||||
start_line + leading_nls,
|
||||
len(positional.group("leading")) + start_char + len(local_name),
|
||||
),
|
||||
),
|
||||
selection_range=Range(
|
||||
Position(
|
||||
start_line + leading_nls,
|
||||
len(positional.group("leading")) + start_char,
|
||||
),
|
||||
Position(
|
||||
start_line + leading_nls,
|
||||
len(positional.group("leading")) + start_char + len(local_name),
|
||||
),
|
||||
),
|
||||
)
|
||||
node.symbols[local_name] = local
|
||||
|
||||
start_line += leading_nls + inner_nls + trailing_nls
|
||||
start_char += len(positional.group(0))
|
||||
|
||||
# other cases
|
||||
|
||||
logger.debug(pformat(node))
|
||||
return build_node_hierarchy(ret)
|
||||
|
||||
|
||||
def parse_file(file: TextDocument) -> list[Node]:
|
||||
content = file.source
|
||||
content_cleaned = clean_content(content)
|
||||
|
||||
check_content_for_errors(content_cleaned)
|
||||
|
||||
return find_scopes(content_cleaned, scope_prefix=Path(file.path).stem)
|
||||
+22
-34
@@ -27,8 +27,7 @@ from lsprotocol.types import (
|
||||
from pygls.lsp.server import LanguageServer
|
||||
|
||||
|
||||
from skillls.checker import ParenMismatchError
|
||||
from skillls.helpers import parse_file
|
||||
from skillls.parser import SkillParser
|
||||
from skillls.types import URI, Node
|
||||
|
||||
basicConfig(
|
||||
@@ -44,7 +43,7 @@ class SkillLanguageServer(LanguageServer):
|
||||
ws_files: set[URI]
|
||||
opened_files: set[URI]
|
||||
scopes: dict[URI, list[Node]]
|
||||
errs: dict[URI, ExceptionGroup]
|
||||
diagnostics: dict[URI, list[Diagnostic]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -56,25 +55,14 @@ class SkillLanguageServer(LanguageServer):
|
||||
super().__init__(name, version, text_document_sync_kind, notebook_document_sync)
|
||||
self.ws_files = set()
|
||||
self.opened_files = set()
|
||||
self.scopes = {}
|
||||
self.errs = {}
|
||||
self.scopes: dict[URI, list[DocumentSymbol]] = {}
|
||||
self.diagnostics: dict[URI, list[Diagnostic]] = {}
|
||||
self.parser = SkillParser()
|
||||
|
||||
def update_diagnostics(self) -> None:
|
||||
for uri in self.opened_files:
|
||||
diags: list[Diagnostic] = []
|
||||
if eg := self.errs.get(uri):
|
||||
for exc in eg.exceptions:
|
||||
match exc:
|
||||
case ParenMismatchError():
|
||||
diags.append(
|
||||
Diagnostic(
|
||||
message=f"[skill_ls] {Path.from_uri(uri).name}:{exc.loc.start.line} {exc.kind.value}",
|
||||
severity=DiagnosticSeverity.Error,
|
||||
range=exc.loc,
|
||||
)
|
||||
)
|
||||
diags = self.diagnostics.get(uri, [])
|
||||
|
||||
# if diags:
|
||||
self.text_document_publish_diagnostics(
|
||||
PublishDiagnosticsParams(
|
||||
uri=uri,
|
||||
@@ -105,11 +93,12 @@ def lsp_initialize(server: SkillLanguageServer, params: InitializeParams) -> Non
|
||||
|
||||
server.ws_files.add(uri)
|
||||
try:
|
||||
server.scopes[uri] = parse_file(server.workspace.get_text_document(uri))
|
||||
if server.errs.get(uri):
|
||||
del server.errs[uri]
|
||||
except ExceptionGroup as eg:
|
||||
server.errs[uri] = eg
|
||||
text_doc = server.workspace.get_text_document(uri)
|
||||
symbols, diagnostics = server.parser.parse_document(text_doc)
|
||||
server.scopes[uri] = symbols
|
||||
server.diagnostics[uri] = diagnostics
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing file {uri}: {e}")
|
||||
|
||||
|
||||
@server.feature(TEXT_DOCUMENT_DID_OPEN)
|
||||
@@ -128,13 +117,12 @@ def on_close(server: SkillLanguageServer, params: DidCloseTextDocumentParams) ->
|
||||
@server.feature(TEXT_DOCUMENT_DID_SAVE)
|
||||
def on_change(server: SkillLanguageServer, params: DidChangeTextDocumentParams) -> None:
|
||||
try:
|
||||
server.scopes[params.text_document.uri] = parse_file(
|
||||
server.workspace.get_text_document(params.text_document.uri)
|
||||
)
|
||||
if server.errs.get(params.text_document.uri):
|
||||
del server.errs[params.text_document.uri]
|
||||
except ExceptionGroup as eg:
|
||||
server.errs[params.text_document.uri] = eg
|
||||
text_doc = server.workspace.get_text_document(params.text_document.uri)
|
||||
symbols, diagnostics = server.parser.parse_document(text_doc)
|
||||
server.scopes[params.text_document.uri] = symbols
|
||||
server.diagnostics[params.text_document.uri] = diagnostics
|
||||
except Exception as e:
|
||||
logger.error(f"Error changing file {params.text_document.uri}: {e}")
|
||||
|
||||
server.update_diagnostics()
|
||||
|
||||
@@ -143,13 +131,13 @@ def on_change(server: SkillLanguageServer, params: DidChangeTextDocumentParams)
|
||||
def on_inlay(server: SkillLanguageServer, params: InlayHintParams) -> list[InlayHint]:
|
||||
hints: list[InlayHint] = []
|
||||
uri = params.text_document.uri
|
||||
for node in server.scopes.get(uri, []):
|
||||
for symbol in server.scopes.get(uri, []):
|
||||
hints.append(
|
||||
InlayHint(
|
||||
label=node.node,
|
||||
label=symbol.name,
|
||||
kind=InlayHintKind.Type,
|
||||
padding_left=True,
|
||||
position=node.location.end,
|
||||
position=symbol.range.end,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -160,7 +148,7 @@ def on_inlay(server: SkillLanguageServer, params: InlayHintParams) -> list[Inlay
|
||||
def on_symbols(
|
||||
server: SkillLanguageServer, params: DocumentSymbolParams
|
||||
) -> list[DocumentSymbol] | None:
|
||||
return [node.as_doc_symbol() for node in server.scopes[params.text_document.uri]]
|
||||
return server.scopes[params.text_document.uri]
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
+2
-1
@@ -34,7 +34,7 @@ class SkillParser:
|
||||
# Tree-sitter parsing
|
||||
tree = self.parser.parse(bytes(content, "utf8"))
|
||||
|
||||
diagnostics: list[Diagnostic] = []
|
||||
diagnostics: list[Diagnostic] = []
|
||||
symbols: list[DocumentSymbol] = []
|
||||
|
||||
# Traverse the root node to collect errors and symbols
|
||||
@@ -75,6 +75,7 @@ class SkillParser:
|
||||
if self._is_symbol_node(node):
|
||||
symbol = self._create_document_symbol(node, content)
|
||||
if symbol:
|
||||
|
||||
symbols.append(symbol)
|
||||
|
||||
# 3. Continue traversal - push children in reverse order to maintain original DFS order
|
||||
|
||||
Reference in New Issue
Block a user