[gemma4] refactor using treesitter

This commit is contained in:
2026-06-20 10:56:52 +02:00
parent 49f0f23a54
commit d600c0a8ca
10 changed files with 141 additions and 434 deletions
-72
View File
@@ -1,72 +0,0 @@
from dataclasses import dataclass
from enum import Enum
from lsprotocol.types import Position, Range
class SyntaxError(Exception):
pass
class ParenMismatchErrorKind(Enum):
TooManyClosed = "Found too many closing parens"
TooManyOpened = "Found too many open parens"
@dataclass
class ParenMismatchError(SyntaxError):
kind: ParenMismatchErrorKind
loc: Range
def _check_for_matching_parens(content: str) -> list[Exception]:
excs: list[Exception] = []
opened = 0
line = 0
col = 0
last_open: Position = Position(0, 0)
for char in content:
match char:
case "(":
opened += 1
last_open = Position(line, col)
case ")":
opened -= 1
if opened < 0:
excs.append(
ParenMismatchError(
ParenMismatchErrorKind.TooManyClosed,
Range(Position(line, col), Position(line, col + 1)),
)
)
opened = 0
case "\n":
line += 1
col = -1
case _:
pass
col += 1
if opened > 0:
excs.append(
ParenMismatchError(
ParenMismatchErrorKind.TooManyOpened,
Range(last_open, Position(last_open.line, last_open.character + 1)),
)
)
return excs
def check_content_for_errors(clean_content: str) -> None:
excs: list[Exception] = []
excs.extend(_check_for_matching_parens(clean_content))
if excs:
raise ExceptionGroup("", excs)
-192
View File
@@ -1,192 +0,0 @@
from copy import copy
from dataclasses import dataclass
from logging import getLogger
from pathlib import Path
from pprint import pformat
from lsprotocol.types import DocumentSymbol, Position, Range, SymbolKind
from re import MULTILINE, compile as recompile, finditer
from pygls.workspace import TextDocument
from skillls.checker import check_content_for_errors
from skillls.types import Node, NodeKind
logger = getLogger(__name__)
@dataclass
class ParserCleanerState:
in_comment: bool = False
in_string: bool = False
NODE_KIND_OPTIONS = "|".join(k.value for k in NodeKind)
NAMESPACE_STARTERS = recompile(
(rf"(\(\s*(?P<typ>{NODE_KIND_OPTIONS})\b|\b(?P<ctyp>{NODE_KIND_OPTIONS})\()"),
MULTILINE,
)
def clean_content(content: str) -> str:
content_cleaned = ""
state = ParserCleanerState()
for cix, char in enumerate(content):
match (content[cix], state):
case ";", ParserCleanerState(in_comment=False, in_string=False):
state.in_comment = True
case '"', ParserCleanerState(in_comment=False):
if content[cix - 1] != "\\":
state.in_string = not state.in_string
content_cleaned += char
case "\n", ParserCleanerState(in_comment=True):
state.in_comment = False
content_cleaned += char
case _, ParserCleanerState(in_comment=False, in_string=False):
content_cleaned += char
case _, ParserCleanerState(in_comment=False, in_string=True):
content_cleaned += " "
case _:
pass
return content_cleaned
def build_node_hierarchy(nodes: list[Node]) -> list[Node]:
to_be_sorted = copy(nodes)
sorted: list[Node] = []
while to_be_sorted:
node_to_sort = to_be_sorted.pop(0)
for sorted_node in sorted:
if sorted_node.should_contain(node_to_sort):
sorted_node.add_child(node_to_sort)
break
else:
sorted.append(node_to_sort)
return sorted
def find_scopes(content_cleaned: str, scope_prefix: str = "") -> list[Node]:
ret: list[Node] = []
for found in NAMESPACE_STARTERS.finditer(content_cleaned):
partial = content_cleaned[found.end() :]
open_brackets = 1
offset = 0
for offset, char in enumerate(partial):
match char:
case "(":
open_brackets += 1
case ")":
open_brackets -= 1
if open_brackets == 0:
break
case _:
pass
pre_lines = content_cleaned[: found.start()].splitlines()
start_line = len(pre_lines) - (
1 if pre_lines[-1] != "" and pre_lines[-1].strip() == "" else 0
)
start_char = len(pre_lines[-1])
inner_lines = content_cleaned[
found.start() : found.end() + offset + 1
].splitlines()
end_line = start_line + len(inner_lines) - 1
end_char = len(inner_lines[-1])
kind = NodeKind(found.group("typ") or found.group("ctyp"))
loc = Range(Position(start_line, start_char), Position(end_line, end_char))
node = Node(
node=f"{scope_prefix}.{kind.value}_{len([n for n in ret if n.kind == kind])}",
kind=kind,
location=loc,
)
ret.append(node)
next = found.end()
# allowed scoped locals syntax
# function(pos1 pos2)
# function(pos1 (pos2 default))
# function(pos1 @rest args)
# function(pos1 @key (kwarg1 default1) (kwarg2 default2))
while content_cleaned[next] != "(":
if content_cleaned[next] == "\n":
start_line += 1
start_char = 0
next += 1
start_char += 1
next += 1
last = 0
for positional in finditer(
r"(?P<leading>\s*)(?P<local>\w+|\(\w+\b[^)]*\))(?P<trailing>\s*)",
content_cleaned[next:],
):
if positional.start() != last:
logger.debug(
f"found ({positional}), but last ({last}) != ({positional.start()})"
)
break
last = positional.end()
leading_nls = positional.group("leading").count("\n")
inner_nls = positional.group("local").count("\n")
trailing_nls = positional.group("trailing").count("\n")
local_name = positional.group("local").split()[0]
local = DocumentSymbol(
name=local_name,
kind=SymbolKind.Variable,
range=Range(
Position(
start_line + leading_nls,
len(positional.group("leading")) + start_char,
),
Position(
start_line + leading_nls,
len(positional.group("leading")) + start_char + len(local_name),
),
),
selection_range=Range(
Position(
start_line + leading_nls,
len(positional.group("leading")) + start_char,
),
Position(
start_line + leading_nls,
len(positional.group("leading")) + start_char + len(local_name),
),
),
)
node.symbols[local_name] = local
start_line += leading_nls + inner_nls + trailing_nls
start_char += len(positional.group(0))
# other cases
logger.debug(pformat(node))
return build_node_hierarchy(ret)
def parse_file(file: TextDocument) -> list[Node]:
content = file.source
content_cleaned = clean_content(content)
check_content_for_errors(content_cleaned)
return find_scopes(content_cleaned, scope_prefix=Path(file.path).stem)
+22 -34
View File
@@ -27,8 +27,7 @@ from lsprotocol.types import (
from pygls.lsp.server import LanguageServer
from skillls.checker import ParenMismatchError
from skillls.helpers import parse_file
from skillls.parser import SkillParser
from skillls.types import URI, Node
basicConfig(
@@ -44,7 +43,7 @@ class SkillLanguageServer(LanguageServer):
ws_files: set[URI]
opened_files: set[URI]
scopes: dict[URI, list[Node]]
errs: dict[URI, ExceptionGroup]
diagnostics: dict[URI, list[Diagnostic]]
def __init__(
self,
@@ -56,25 +55,14 @@ class SkillLanguageServer(LanguageServer):
super().__init__(name, version, text_document_sync_kind, notebook_document_sync)
self.ws_files = set()
self.opened_files = set()
self.scopes = {}
self.errs = {}
self.scopes: dict[URI, list[DocumentSymbol]] = {}
self.diagnostics: dict[URI, list[Diagnostic]] = {}
self.parser = SkillParser()
def update_diagnostics(self) -> None:
for uri in self.opened_files:
diags: list[Diagnostic] = []
if eg := self.errs.get(uri):
for exc in eg.exceptions:
match exc:
case ParenMismatchError():
diags.append(
Diagnostic(
message=f"[skill_ls] {Path.from_uri(uri).name}:{exc.loc.start.line} {exc.kind.value}",
severity=DiagnosticSeverity.Error,
range=exc.loc,
)
)
diags = self.diagnostics.get(uri, [])
# if diags:
self.text_document_publish_diagnostics(
PublishDiagnosticsParams(
uri=uri,
@@ -105,11 +93,12 @@ def lsp_initialize(server: SkillLanguageServer, params: InitializeParams) -> Non
server.ws_files.add(uri)
try:
server.scopes[uri] = parse_file(server.workspace.get_text_document(uri))
if server.errs.get(uri):
del server.errs[uri]
except ExceptionGroup as eg:
server.errs[uri] = eg
text_doc = server.workspace.get_text_document(uri)
symbols, diagnostics = server.parser.parse_document(text_doc)
server.scopes[uri] = symbols
server.diagnostics[uri] = diagnostics
except Exception as e:
logger.error(f"Error initializing file {uri}: {e}")
@server.feature(TEXT_DOCUMENT_DID_OPEN)
@@ -128,13 +117,12 @@ def on_close(server: SkillLanguageServer, params: DidCloseTextDocumentParams) ->
@server.feature(TEXT_DOCUMENT_DID_SAVE)
def on_change(server: SkillLanguageServer, params: DidChangeTextDocumentParams) -> None:
try:
server.scopes[params.text_document.uri] = parse_file(
server.workspace.get_text_document(params.text_document.uri)
)
if server.errs.get(params.text_document.uri):
del server.errs[params.text_document.uri]
except ExceptionGroup as eg:
server.errs[params.text_document.uri] = eg
text_doc = server.workspace.get_text_document(params.text_document.uri)
symbols, diagnostics = server.parser.parse_document(text_doc)
server.scopes[params.text_document.uri] = symbols
server.diagnostics[params.text_document.uri] = diagnostics
except Exception as e:
logger.error(f"Error changing file {params.text_document.uri}: {e}")
server.update_diagnostics()
@@ -143,13 +131,13 @@ def on_change(server: SkillLanguageServer, params: DidChangeTextDocumentParams)
def on_inlay(server: SkillLanguageServer, params: InlayHintParams) -> list[InlayHint]:
hints: list[InlayHint] = []
uri = params.text_document.uri
for node in server.scopes.get(uri, []):
for symbol in server.scopes.get(uri, []):
hints.append(
InlayHint(
label=node.node,
label=symbol.name,
kind=InlayHintKind.Type,
padding_left=True,
position=node.location.end,
position=symbol.range.end,
)
)
@@ -160,7 +148,7 @@ def on_inlay(server: SkillLanguageServer, params: InlayHintParams) -> list[Inlay
def on_symbols(
server: SkillLanguageServer, params: DocumentSymbolParams
) -> list[DocumentSymbol] | None:
return [node.as_doc_symbol() for node in server.scopes[params.text_document.uri]]
return server.scopes[params.text_document.uri]
def main():
+2 -1
View File
@@ -34,7 +34,7 @@ class SkillParser:
# Tree-sitter parsing
tree = self.parser.parse(bytes(content, "utf8"))
diagnostics: list[Diagnostic] = []
diagnostics: list[Diagnostic] = []
symbols: list[DocumentSymbol] = []
# Traverse the root node to collect errors and symbols
@@ -75,6 +75,7 @@ class SkillParser:
if self._is_symbol_node(node):
symbol = self._create_document_symbol(node, content)
if symbol:
symbols.append(symbol)
# 3. Continue traversal - push children in reverse order to maintain original DFS order