Browse Source

initial commit

boyska 10 months ago
commit
9a613ed898
3 changed files with 367 additions and 0 deletions
  1. 82 0
      decoder.py
  2. 172 0
      read.py
  3. 113 0
      webserver.py

+ 82 - 0
decoder.py

@@ -0,0 +1,82 @@
+"""
+This module
+"""
+
+from typing import Optional
+from dataclasses import dataclass, asdict
+
+
+@dataclass
+class Message:
+    orig_line: str
+
+    def asdict(self):
+        return asdict(self)
+
+
+@dataclass
+class LogMessage(Message):
+    level: int
+    message: str
+
+
+@dataclass
+class DumpMessage(Message):
+    variables: dict[str, int]
+
+
+class Decoder:
+    """
+    >>> d = Decoder()
+    >>> sorted(d.decode(b'DMP V=1 G=200 FREQ=1234567').variables.items())
+    [('FREQ', 1234567), ('G', 200), ('V', 1)]
+    >>> d.decode(b'LOG D ciao').level
+    1
+    >>> d.decode(b'LOG D ciao').message
+    'ciao'
+    """
+
+    log_levels = {
+        "D": 1,
+        "I": 2,
+        "W": 3,
+        "E": 4,
+        "C": 5,
+    }
+
+    def __init__(self):
+        pass
+
+    def decode_level(self, description: str) -> int:
+        return self.log_levels[description]
+
+    def decode_value(self, key: str, value: str) -> Optional[dict]:
+        return int(value, base=10)
+
+    def decode_log(self, line: str) -> LogMessage:
+        level, message = line.split(" ", 1)
+        level = self.decode_level(level)
+        return LogMessage(level=level, message=message, orig_line=line)
+
+    def decode_dump(self, line: str) -> DumpMessage:
+        variables = {}
+        variables_settings = line.split()
+        for varset in variables_settings:
+            key, val = varset.split("=", 1)
+            val = self.decode_value(key, val)
+            variables[key] = val
+        return DumpMessage(variables=variables, orig_line=line)
+
+    def decode(self, line: bytes) -> Optional[Message]:
+        """
+        Returns None if the line is not meant to be handled by this decoder.
+
+        Raise if meant for us but invalid.
+        """
+        line = line.decode("ascii")
+
+        if line.startswith("LOG "):
+            return self.decode_log(line.removeprefix("LOG "))
+        if line.startswith("DMP "):
+            return self.decode_dump(line.removeprefix("DMP "))
+        return None

+ 172 - 0
read.py

@@ -0,0 +1,172 @@
+"""
+This module connects to serial port and exposes the results in stdout.
+"""
+
+import decoder
+import json
+from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
+from typing import Optional
+import atexit
+import time
+import logging
+import sys
+import multiprocessing
+
+import serial
+import requests
+
+log = logging.getLogger("inaria")
+
+
+def read_auth(buf) -> tuple[str, str]:
+    username = buf.readline()
+    password = buf.readline()
+    buf.close()
+    return (username.rstrip(), password.rstrip("\n"))
+
+
+class MessageForwarder:
+    def __init__(self):
+        self.base_url: Optional[str] = None
+        self.auth: Optional[tuple[str, str]] = None
+
+    def initialize_from_args(self, args):
+        pass
+
+    @property
+    def request_params(self) -> dict:
+        r = {}
+        if self.auth is not None:
+            r["auth"] = self.auth
+        return r
+
+    def send_log(self, message: decoder.LogMessage):
+        requests.post(
+            f"{self.base_url}/messages", json=message.asdict(), **self.request_params
+        )
+
+    def send_dump(self, message: decoder.DumpMessage):
+        requests.post(
+            f"{self.base_url}/variables", json=message.asdict(), **self.request_params
+        )
+
+    def send_message(self, message: decoder.Message):
+        if not self.base_url:
+            return
+
+        if isinstance(message, decoder.LogMessage):
+            self.send_log(message)
+        elif isinstance(message, decoder.DumpMessage):
+            self.send_dump(message)
+
+
+def get_next_message(serial) -> Optional[bytes]:
+    """
+    >>> from io import BytesIO
+    >>> msg = 'foo\\x01LOG D ciao\\n'
+    >>> get_next_message(BytesIO(msg.encode('ascii'))).rstrip().decode('ascii')
+    'LOG D ciao'
+    >>> msg = 'foo\\nasd\\x01LOG D ciao\\n'
+    >>> get_next_message(BytesIO(msg.encode('ascii'))).rstrip().decode('ascii')
+    'LOG D ciao'
+    """
+    while True:
+        c = serial.read(1)
+        log.info("%r", c)
+        if not c:
+            return None
+        if ord(c) == 1:
+            break
+    return serial.readline()  # read a '\n' terminated line
+
+
+def loop(serial, forwarder: MessageForwarder, args):
+    dec = decoder.Decoder()
+    while True:
+        line = get_next_message(serial)
+        try:
+            message = dec.decode(line)
+        except Exception:
+            continue
+
+        if message is None:
+            continue
+
+        obj = (str(type(message)), message.asdict())
+        print(json.dumps(obj))
+        multiprocessing.Process(target=forwarder.send_message, args=(message,)).start()
+
+
+def close_all(serial):
+    serial.close()
+
+
+def get_parser() -> ArgumentParser:
+    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+    serial_parser = parser.add_argument_group("serial options")
+    serial_parser.add_argument(
+        "--device",
+        default="hwgrep://ttyUSB",
+        help="Device path, or URL as in https://pyserial_parser.readthedocs.io/en/latest/url_handlers.html",
+    )
+    serial_parser.add_argument("--baudrate", type=int, default=115200)
+    serial_parser.add_argument(
+        "--wait",
+        action="store_true",
+        default=False,
+        help="Wait until serial is found, and retries upon failures",
+    )
+
+    http_parser = parser.add_argument_group("http options")
+    http_parser.add_argument(
+        "--http-endpoint", metavar="URL", help="sth like http://127.0.0.1:8000/"
+    )
+    http_parser.add_argument(
+        "--http-auth-file",
+        type=open,
+        metavar="FILE",
+        help="Path to a file with two lines: first is username, second is password",
+    )
+
+    parser.add_argument("--verbose", "-v", action="store_true", default=False)
+
+    return parser
+
+
+def main():
+    args = get_parser().parse_args()
+
+    logging.basicConfig(level=logging.INFO)
+    log.setLevel(logging.INFO if args.verbose else logging.WARN)
+
+    log.info("Connecting...")
+    forwarder = MessageForwarder()
+    if args.http_endpoint:
+        forwarder.base_url = args.http_endpoint
+        if args.http_auth_file is not None:
+            forwarder.auth = read_auth(args.http_auth_file)
+    while True:
+        try:
+            s = serial.serial_for_url(args.device, do_not_open=True)
+            s.baudrate = args.baudrate
+            s.open()
+        except Exception as exc:
+            if not args.wait:
+                log.info("Cannot connect: %s", exc)
+                sys.exit(1)
+            log.info("Cannot connect, will retry...")
+            time.sleep(1)
+            continue
+
+        log.info("Connected!")
+
+        atexit.register(close_all, s)
+        try:
+            loop(s, forwarder, args)
+        except serial.serialutil.SerialException:
+            if not args.wait:
+                sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()

+ 113 - 0
webserver.py

@@ -0,0 +1,113 @@
+import secrets
+from uuid import UUID, uuid4
+import time
+from typing import Optional, Callable, Any
+
+from pydantic import BaseModel, BaseSettings, Field
+from fastapi import FastAPI, Depends, HTTPException, status
+from fastapi.security import HTTPBasic, HTTPBasicCredentials
+from fastapi.middleware.cors import CORSMiddleware
+
+
+class VariableModel(BaseModel):
+    key: str
+    value: int
+
+
+class AllVariablesModel(BaseModel):
+    variables: dict[str, int]
+
+
+class MessageModel(BaseModel):
+    message: str
+    level: int = 0
+    id: UUID = Field(default_factory=uuid4)
+    timestamp: float = Field(default_factory=time.time)
+
+
+class Settings(BaseSettings):
+    app_name: str = "Squeow"
+    serial_password: str = "hackme"
+    variables: dict[str, int] = {}
+    messages_length: int = 10
+    messages: list[MessageModel] = []
+
+    class Config:
+        env_file = "pizzicore.env"
+
+    def push_message(self, message):
+        self.messages.append(message)
+        if len(self.messages) > self.messages_length:
+            self.messages.pop(0)
+
+
+app = FastAPI()
+settings = Settings()
+security = HTTPBasic()
+
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
+    correct_username = secrets.compare_digest(credentials.username, "serial")
+    correct_password = secrets.compare_digest(
+        credentials.password, settings.serial_password
+    )
+    if not (correct_username and correct_password):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Incorrect username or password",
+            headers={"WWW-Authenticate": "Basic"},
+        )
+    return "serial"
+
+
+@app.post("/variables")
+async def update_all_variables(
+    variables: AllVariablesModel, role: str = Depends(get_current_role)
+) -> None:
+    if role != "serial":
+        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
+    settings.variables.update(variables.variables)
+    return
+
+
+@app.post("/messages")
+async def push_message(
+    message: MessageModel, role: str = Depends(get_current_role)
+) -> None:
+    if role != "serial":
+        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
+    settings.push_message(message)
+    return
+
+
+@app.get("/variables")
+async def get_all_variables() -> AllVariablesModel:
+    return settings.variables
+
+
+def first_matching(lst: list, condition: Callable[[Any], bool]) -> int:
+    """return the index of the first item that matches condition"""
+    for i, elem in enumerate(lst):
+        if condition(elem):
+            return i
+    return None
+
+
+@app.get("/messages")
+async def get_all_messages(from_id: Optional[UUID] = None) -> list[MessageModel]:
+    messages = settings.messages
+    if from_id is not None:
+        match = first_matching(messages, lambda x: x.id == from_id)
+        # if match is not found, we assume that the referred id is very old, so all messages are relevant
+        if match is not None:
+            messages = messages[match:]
+    return messages