# skill-ls/skillls/main.py
#
# 154 lines
# 4.0 KiB
# Python

from collections.abc import Callable, Generator, Sequence
from dataclasses import dataclass, field
from difflib import Differ
from itertools import chain
from logging import DEBUG, INFO, basicConfig, debug, error, getLogger, info, warning
from re import findall, finditer, fullmatch, match as rematch
import re
from time import time
from cattrs import Converter
from lsprotocol.types import (
INLAY_HINT_RESOLVE,
TEXT_DOCUMENT_DID_CHANGE,
TEXT_DOCUMENT_DID_OPEN,
TEXT_DOCUMENT_DID_SAVE,
TEXT_DOCUMENT_DOCUMENT_SYMBOL,
TEXT_DOCUMENT_HOVER,
TEXT_DOCUMENT_INLAY_HINT,
WORKSPACE_INLAY_HINT_REFRESH,
WORKSPACE_SEMANTIC_TOKENS_REFRESH,
CompletionItem,
Diagnostic,
DiagnosticSeverity,
DidChangeTextDocumentParams,
DidOpenTextDocumentParams,
DidSaveTextDocumentParams,
DocumentSymbol,
DocumentSymbolParams,
Hover,
HoverParams,
InlayHint,
InlayHintKind,
InlayHintParams,
MessageType,
NotebookDocumentSyncOptions,
Position,
Range,
SymbolKind,
TextDocumentContentChangeEvent,
TextDocumentContentChangeEvent_Type1,
TextDocumentSyncKind,
)
from pygls.protocol import LanguageServerProtocol, default_converter
from tree_sitter_skill import language as skill_lang
from tree_sitter import Language, Node, Parser, Query, Tree
from pygls.server import LanguageServer
from pygls.workspace import TextDocument
from .cache import Cache
# tree-sitter language object built from the grammar handle that the
# tree_sitter_skill package exposes through its language() function.
SKILL_LANG = Language(skill_lang())
# Parser constructed once at import time with that language.
SKILL_PARSER = Parser(SKILL_LANG)
# Alias used in annotations for document URIs; a plain str at runtime.
URI = str
# Configure the root logger at import time: records go to "skillls.log" in
# the current working directory, the file is opened with mode "w" (truncated
# on every start), and everything from DEBUG upwards is written using the
# "<time> [<level>]: <message>" layout given by `format`.
basicConfig(
filename="skillls.log",
filemode="w",
level=DEBUG,
format="%(asctime)s [%(levelname)s]: %(message)s",
)
# getLogger() without a name returns the root logger configured above.
logger = getLogger()
# Module-wide cache annotated as mapping str keys to CompletionItem values.
# NOTE(review): eviction/lookup semantics are defined by Cache in .cache,
# which is not visible here.
cache: Cache[str, CompletionItem] = Cache()
def in_range(what: Position, area: Range) -> bool:
    """Return whether *what* lies inside *area*, with both ends inclusive.

    Positions are ordered as ``(line, character)`` pairs: by line first and
    by character within the same line.

    The comparison is written out on tuples instead of applying ``>=`` and
    ``<=`` to the position objects, so it does not depend on the position
    type implementing those operators.
    NOTE(review): lsprotocol's ``Position`` appears to define only ``==`` and
    ``>``, in which case ``>=``/``<=`` raise ``TypeError`` - confirm against
    the pinned lsprotocol version.

    Args:
        what: The position to test; needs ``line`` and ``character``.
        area: The range to test against; needs ``start`` and ``end``
            positions.

    Returns:
        ``True`` if ``area.start <= what <= area.end``, else ``False``.
    """
    point = (what.line, what.character)
    lower = (area.start.line, area.start.character)
    upper = (area.end.line, area.end.character)
    return lower <= point <= upper
def find_end(start: Position, lines: list[str]) -> Position:
    """Find the parenthesis that closes the expression beginning at *start*.

    Scanning starts at ``lines[start.line][start.character]`` and continues
    through the following lines.  Parentheses are counted only outside
    double-quoted strings; inside a string a backslash escapes the next
    character, so ``\\"`` does not end the string while ``\\\\`` followed by
    ``"`` does.

    Args:
        start: Position at which scanning begins (zero-based line and
            character indices into *lines*).
        lines: The document text split into lines.

    Returns:
        The position of the ``)`` that brings the nesting depth back to
        zero, expressed in document coordinates.  A ``)`` met while the depth
        is already zero is returned immediately.  If no such parenthesis
        exists an error is logged and the end of the last line is returned
        (``0:0`` for an empty document).
    """
    depth = 0
    in_str = False
    # True while the previous in-string character was an unescaped backslash.
    escaped = False
    for row, line in enumerate(lines[start.line :], start=start.line):
        # Only the first scanned line is entered part-way through; keeping
        # the offset in the enumeration makes `col` a real document column.
        offset = start.character if row == start.line else 0
        for col, char in enumerate(line[offset:], start=offset):
            if in_str:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_str = False
            elif char == '"':
                in_str = True
            elif char == "(":
                depth += 1
            elif char == ")":
                if depth > 0:
                    depth -= 1
                if depth == 0:
                    return Position(row, col)
    error(f"did not find end for start at {start}")
    if not lines:
        return Position(0, 0)
    # Fall back to the end of the last existing line.
    return Position(len(lines) - 1, len(lines[-1]))
@dataclass(frozen=True)
class Environment:
    """Immutable record that ties an environment to a span of the document.

    The dataclass is frozen, so attributes cannot be reassigned after
    construction.  ``LetEnvironment`` in this module derives from it.
    NOTE(review): presumably represents a lexical scope of the parsed
    source - confirm against the code that creates instances.
    """

    # Document range associated with this environment.
    range: Range
@dataclass(frozen=True)
class LetEnvironment(Environment):
    """Environment that additionally carries a set of local names.

    NOTE(review): presumably models a ``let`` form and the variables it
    binds - confirm against the code that builds these objects.
    """

    # Names local to this environment.  Each instance gets its own empty set
    # by default.  The dataclass being frozen only prevents rebinding the
    # attribute; the set object itself can still be mutated.
    locals: set[str] = field(default_factory=set)
def offset_range(range: Range, lines: int, cols: int = 0) -> Range:
    """Return a new range equal to *range* shifted by the given amounts.

    Args:
        range: The range to shift; it is not modified.
        lines: Amount added to the line of both the start and the end.
        cols: Amount added to the character of both the start and the end.

    Returns:
        A freshly built ``Range`` made of two new ``Position`` objects.
    """
    begin, finish = range.start, range.end
    moved_start = Position(begin.line + lines, begin.character + cols)
    moved_end = Position(finish.line + lines, finish.character + cols)
    return Range(moved_start, moved_end)
class SkillLanguageServer(LanguageServer):
    """``LanguageServer`` subclass used by this module.

    The constructor adds no state of its own: it accepts the parameters
    listed below and hands every one of them, unchanged, to the base class.
    """

    def __init__(
        self,
        name: str,
        version: str,
        loop=None,
        protocol_cls: type[LanguageServerProtocol] = LanguageServerProtocol,
        converter_factory: Callable[[], Converter] = default_converter,
        text_document_sync_kind: TextDocumentSyncKind = TextDocumentSyncKind.Incremental,
        notebook_document_sync: NotebookDocumentSyncOptions | None = None,
        max_workers: int = 2,
    ):
        """Initialise the server by delegating to ``LanguageServer.__init__``."""
        # The base constructor receives these positionally, so their order
        # here has to stay in step with the base-class signature.
        base_args = (
            name,
            version,
            loop,
            protocol_cls,
            converter_factory,
            text_document_sync_kind,
            notebook_document_sync,
            max_workers,
        )
        super().__init__(*base_args)
def main() -> None:
    """Entry point: run the module-level ``server`` via its ``start_io()``.

    NOTE(review): ``start_io`` is presumably pygls' stdin/stdout transport,
    which blocks until the client disconnects - confirm.
    """
    # NOTE(review): `server` is not defined in the visible part of this
    # module; confirm a module-level instance exists before this is called.
    server.start_io()