from abc import ABC from dataclasses import dataclass, field from enum import Enum from logging import getLogger import re from pathlib import Path from typing import NamedTuple, Self from lsprotocol.types import ( Diagnostic, DiagnosticSeverity, DocumentSymbol, Position, Range, SymbolKind, ) logger = getLogger(__name__) class Pair(NamedTuple): start: str end: str class SyntaxPair(Enum): Paren = Pair("(", ")") Square = Pair("[", "]") @classmethod def by_start_elem(cls, start: str) -> Self: for option in cls: if option.value[0] == start: return option raise ValueError(f"`{start}` not a valid start character") @classmethod def by_end_elem(cls, end: str) -> Self: for option in cls: if option.value[1] == end: return option raise ValueError(f"`{end}` not a valid end character") def char_range(line: int, char: int) -> Range: return Range(Position(line, char), Position(line, char + 1)) def pair_mismatch(line: int, char: int, msg: str) -> Diagnostic: return Diagnostic( char_range(line, char), msg, severity=DiagnosticSeverity.Error, ) class StackElement(NamedTuple): range: Range elem: SyntaxPair WHITESPACE_OR_PAREN = re.compile(r"(\s|\(|\)|\[|\]|\'\()+") TOKEN_REGEX = re.compile(r"\w[a-zA-Z0-9_]*") NUMBER_REGEX = re.compile(r"\d+(\.\d+)?") OPERATORS = re.compile(r"(->|~>|\+|\-|\*|\/|\=|\|\||\&\&)") @dataclass class TreeToken(ABC): content: str range: Range def String(content: str, range: Range) -> DocumentSymbol: return DocumentSymbol( name=content, range=range, kind=SymbolKind.String, selection_range=range, ) def Operator(content: str, range: Range) -> DocumentSymbol: return DocumentSymbol( name=content, range=range, kind=SymbolKind.Operator, selection_range=range, ) def Number(content: str, range: Range) -> DocumentSymbol: return DocumentSymbol( name=content, range=range, kind=SymbolKind.Number, selection_range=range, ) def Token(content: str, range: Range) -> DocumentSymbol: return DocumentSymbol( name=content, range=range, kind=SymbolKind.Variable, selection_range=range, ) RawIndex = int ColIndex = int LineIndex = int @dataclass class TokenParser: _in_string: bool = False _in_comment: bool = False _token_tree: list[DocumentSymbol] = field(default_factory=list) _current: str = "" _line_indices: list[RawIndex] = field(default_factory=list) def _get_line(self, index: RawIndex) -> tuple[LineIndex, RawIndex]: for line, newline_pos in enumerate(self._line_indices): if index < newline_pos: return line, self._line_indices[line - 1] if line > 0 else 0 return len(self._line_indices), self._line_indices[-1] def _get_range(self, start: RawIndex, end: RawIndex) -> Range: start_line, start_line_index = self._get_line(start) start_col = start - start_line_index - 1 end_line, end_line_index = self._get_line(end) end_col = end - end_line_index - 1 return Range(Position(start_line, start_col), Position(end_line, end_col)) def _parse_string(self, raw: str, index: int) -> int: stop = raw.index('"', index + 1) self._token_tree.append( String(raw[index : stop + 1], self._get_range(index, stop)) ) return stop + 1 def _parse_comment(self, raw: str, index: int) -> int: stop = raw.index("\n", index) # self._token_tree.append(Comment(raw[index:stop], self._get_range(index, stop))) return stop + 1 def _parse_whitespace(self, raw: str, index: int) -> int: if m := WHITESPACE_OR_PAREN.search(raw, index): stop = m.end() else: stop = index + 1 # self._token_tree.append(Whitespace(raw[index:stop])) return stop def _parse_operator(self, raw: str, index: int) -> int: if m := OPERATORS.search(raw, index): stop = m.end() else: stop = index + 1 self._token_tree.append( Operator(raw[index:stop], self._get_range(index, stop - 1)) ) return stop + 1 def _parse_token(self, raw: str, index: int) -> int: if m := TOKEN_REGEX.search(raw, index): stop = m.end() else: stop = index + 1 self._token_tree.append( Token(raw[index:stop], self._get_range(index, stop - 1)) ) return stop def _parse_number(self, raw: str, index: int) -> int: if m := NUMBER_REGEX.search(raw, index): stop = m.end() else: stop = index + 1 self._token_tree.append( Number(raw[index:stop], self._get_range(index, stop - 1)) ) return stop def prepare_content(self, raw: str) -> None: self._line_indices = [i for i, char in enumerate(raw) if char == "\n"] max_index = len(raw) index = 0 while index < max_index: if raw[index] == '"': index = self._parse_string(raw, index) elif raw[index] == ";": index = self._parse_comment(raw, index) elif WHITESPACE_OR_PAREN.match(raw[index : index + 2]): index = self._parse_whitespace(raw, index) elif OPERATORS.match(raw[index]): index = self._parse_operator(raw, index) elif NUMBER_REGEX.match(raw[index]): index = self._parse_number(raw, index) else: index = self._parse_token(raw, index) @dataclass() class IterativeParser: _stack: list[StackElement] = field(default_factory=list) def peek(self) -> StackElement: return self._stack[-1] def pop(self) -> StackElement: return self._stack.pop() def push(self, pair: StackElement) -> None: return self._stack.append(pair) def __call__(self, raw: list[str]) -> list[Diagnostic]: in_string = False errs = [] for line, raw_line in enumerate(raw): for char, raw_char in enumerate(raw_line): match raw_char: case ";": if not in_string: break case '"': in_string = not in_string case "(" | "[": if not in_string: self.push( StackElement( char_range(line, char), SyntaxPair.by_start_elem(raw_char), ) ) case "]" | ")": if not in_string: if not self._stack: errs.append( pair_mismatch( line, char, f"one {raw_char} too much" ) ) continue expected = SyntaxPair.by_end_elem(raw_char) elem = self._stack.pop() if elem.elem == expected: continue if self._stack and self._stack[-1].elem == expected: errs.append( pair_mismatch( line, char, f"unclosed {elem.elem.value.start}" ) ) self._stack.pop() self._stack.append(elem) else: errs.append( pair_mismatch( line, char, f"one {raw_char} too much" ) ) self._stack.append(elem) for rest in self._stack: errs.append( Diagnostic( rest.range, f"unclosed {rest.elem.value.start}", severity=DiagnosticSeverity.Error, ) ) self._stack = [] return errs if __name__ == "__main__": example = Path(__file__).parent.parent.parent / "examples" / "example.il" t = TokenParser() t.prepare_content(example.read_text()) print(t._token_tree)