Source code for lvmnps.tools

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2023-11-22
# @Filename: tools.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)

from __future__ import annotations

import asyncio

from typing import TYPE_CHECKING

import httpx
from pydantic import SecretStr
from pydantic.dataclasses import dataclass
from typing_extensions import Literal

from lvmnps import log


if TYPE_CHECKING:
    from lvmnps.nps.core import OutletModel


__all__ = [
    "APIClient",
    "normalise_outlet_name",
    "get_outlet_by_name",
    "get_outlet_by_id",
]


@dataclass
class APIClient:
    """A wrapper around ``httpx.AsyncClient`` to yield a new client.

    Meant to be used as an async context manager. Entering acquires an
    internal `asyncio.Lock` and creates a fresh ``httpx.AsyncClient``
    configured with ``base_url``, ``timeout`` and the credentials for
    ``auth_method``. Exiting closes that client and only then releases the
    lock. Because the lock is held for the whole ``async with`` body,
    concurrent ``async with`` blocks on the same instance run one at a time.

    Parameters
    ----------
    base_url
        Base URL passed to ``httpx.AsyncClient``.
    user
        User name used for authentication.
    password
        Password used for authentication.
    timeout
        Timeout passed to ``httpx.AsyncClient``.
    auth_method
        ``"digest"`` to use ``httpx.DigestAuth``, or ``"basic"`` to pass a
        ``(user, password)`` tuple, which httpx treats as basic auth.

    """

    base_url: str
    user: str
    password: SecretStr
    timeout: float = 3
    auth_method: Literal["digest", "basic"] = "digest"

    def __post_init__(self):
        # The client currently handed out by __aenter__ (None until first use).
        self.client: httpx.AsyncClient | None = None
        # Serialises use of this instance across concurrent ``async with``.
        self.lock = asyncio.Lock()

    def _get_auth(self):
        """Returns the ``httpx`` auth value matching ``auth_method``.

        Raises
        ------
        ValueError
            If ``auth_method`` is neither ``"digest"`` nor ``"basic"``.

        """

        if self.auth_method == "digest":
            return httpx.DigestAuth(self.user, self.password.get_secret_value())
        elif self.auth_method == "basic":
            return (self.user, self.password.get_secret_value())

        # Defensive: the field is typed as a Literal, but the attribute can
        # still be reassigned to something else after construction.
        raise ValueError(f"Invalid authentication method {self.auth_method!r}.")

    async def __aenter__(self):
        """Acquires the lock and yields a new client.

        Raises
        ------
        ValueError
            If ``auth_method`` is invalid. The lock is not held in that case.

        """

        # Resolve the auth before taking the lock. If __aenter__ raises,
        # __aexit__ is never called, so a lock acquired before the failure
        # would never be released and every later ``async with`` would hang.
        auth = self._get_auth()

        await self.lock.acquire()
        try:
            log.debug(
                f"Creating async client to {self.base_url!r} "
                f"with {self.auth_method} authentication."
            )
            self.client = httpx.AsyncClient(
                auth=auth,
                base_url=self.base_url,
                headers={},
                timeout=self.timeout,
            )
        except BaseException:
            # Same reasoning as above: release before propagating.
            self.lock.release()
            raise

        return self.client

    async def __aexit__(self, exc_type, exc, tb):
        """Closes the client and then releases the lock."""

        try:
            if self.client is not None and not self.client.is_closed:
                log.debug("Closing async client.")
                await self.client.aclose()
        finally:
            # Release only after the client is closed, so the client's whole
            # lifetime happens while the lock is held.
            self.lock.release()
def normalise_outlet_name(name: str):
    """Returns a normalised name for an outlet.

    The name is lower-cased and every space character is replaced by an
    underscore. Other whitespace (e.g. tabs) is left untouched.

    Parameters
    ----------
    name
        The outlet name to normalise.

    Returns
    -------
    normalised
        The normalised outlet name.

    """

    lowered = name.lower()
    return lowered.translate({ord(" "): "_"})
def get_outlet_by_name(outlet_data: dict[str, OutletModel], name: str):
    """Looks up an outlet by name in a mapping of outlets.

    ``name`` is passed through `.normalise_outlet_name` (lower-cased, spaces
    replaced by underscores) and the result is used as the mapping key.
    ``"Outlet 1"`` and ``"outlet_1"`` therefore resolve to the same entry.

    Parameters
    ----------
    outlet_data
        The mapping of outlet name to outlet model data. Only the normalised
        form of ``name`` is looked up, so keys are presumably stored already
        normalised (confirm against the caller that builds the mapping).
    name
        The name of the outlet to retrieve.

    Returns
    -------
    outlet
        The outlet matching the input name.

    Raises
    ------
    ValueError
        If the outlet cannot be found.

    """

    key = normalise_outlet_name(name)

    if key not in outlet_data:
        raise ValueError(f"Cannot find outlet with name {name!r}.")

    return outlet_data[key]
def get_outlet_by_id(outlet_data: dict[str, OutletModel], id: int):
    """Looks up an outlet by its ``id`` attribute.

    The values of ``outlet_data`` are scanned in iteration order. The first
    outlet whose ``id`` compares equal to ``id`` is returned.

    Parameters
    ----------
    outlet_data
        The mapping of outlet name to outlet model data.
    id
        The id of the outlet to retrieve.

    Returns
    -------
    outlet
        The first outlet whose ``id`` matches.

    Raises
    ------
    ValueError
        If no outlet has a matching id (including when the mapping is empty).

    """

    for candidate in outlet_data.values():
        if candidate.id == id:
            break
    else:
        raise ValueError(f"Cannot find outlet with id {id!r}.")

    return candidate