Skip to content

Commit

Permalink
Record issue datetime for keys (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
ianlini authored Dec 3, 2023
1 parent b36ab3f commit ab5bfb2
Show file tree
Hide file tree
Showing 14 changed files with 455 additions and 110 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Follow the instruction to create the config. Example output:
Interface name for WireGuard [wg0]:
Interface.ListenPort of the relay server [51820]:
Interface.Address of the relay server [192.168.10.1/24]:
The default endpoint in clients' Peer.Endpoint configs (e.g., example.com:51280): example.com
The default endpoint in clients' Peer.Endpoint configs (e.g., example.com:51820): example.com
If you want to allow the clients to access the internet via the relay server, you must provide the interface name you want to forward the internet traffic to. It's usually eth0 or wlan0. You can check it by executing `ip addr`. If you provide an interface name {interface}, the following rules will be added:
- iptables -A FORWARD -i %i -o {interface} -j ACCEPT
- iptables -A FORWARD -i {interface} -o %i -j ACCEPT
Expand Down
2 changes: 2 additions & 0 deletions cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
"python"
],
"words": [
"addoption",
"bistiming",
"capsys",
"distro",
"genkey",
"genpsk",
Expand Down
190 changes: 143 additions & 47 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ classifiers = [
python = "^3.10"
click = "^8.1.3"
"ruamel.yaml" = "^0.17.21"
pydantic = "^1.9.2"
pydantic = "^2.0.0"
qrcode = "^7.3.1"

[tool.poetry.group.linter.dependencies]
Expand Down
4 changes: 2 additions & 2 deletions src/wg_wizard/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def main():
@option(
"--default-endpoint",
"-e",
prompt="The default endpoint in clients' Peer.Endpoint configs (e.g., example.com:51280)",
prompt="The default endpoint in clients' Peer.Endpoint configs (e.g., example.com:51820)",
help="""
The default endpoint in clients' Peer.Endpoint configs (e.g., example.com:51280).
The default endpoint in clients' Peer.Endpoint configs (e.g., example.com:51820).
If the port is not provided, it will be added automatically according to the listen_port.
The endpoint can also be overridden in peer configs.
""",
Expand Down
81 changes: 53 additions & 28 deletions src/wg_wizard/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
from ipaddress import ip_interface
from pathlib import Path
from typing import Literal, Optional
from typing import Annotated, Literal, Optional
import json
import logging

Expand All @@ -9,9 +10,10 @@
Field,
IPvAnyInterface,
IPvAnyAddress,
constr,
StringConstraints,
SecretStr,
PrivateAttr,
field_serializer,
)
from ruamel.yaml import YAML

Expand All @@ -30,40 +32,45 @@


class WgWizardPeerConfig(StrictModel):
listen_port: Optional[int]
fw_mark: Optional[Literal["off"] | int]
addresses: list[IPvAnyInterface] = Field(min_items=1)
listen_port: Optional[int] = None
fw_mark: Optional[Literal["off"] | int] = None
addresses: list[IPvAnyInterface] = Field(min_length=1)
dns_addresses: list[IPvAnyAddress] = Field(default_factory=list)
mtu: Optional[int]
table: Optional[str]
mtu: Optional[int] = None
table: Optional[str] = None
pre_up: list[str] = Field(default_factory=list)
post_up: list[str] = Field(default_factory=list)
pre_down: list[str] = Field(default_factory=list)
post_down: list[str] = Field(default_factory=list)
server_allowed_ips: list[IPvAnyInterface] = Field(min_items=1)
server_endpoint: Optional[constr(regex=r".+:\d+")] # noqa: F722
server_persistent_keepalive: Optional[Literal["off"] | int]
client_allowed_ips: list[IPvAnyInterface] = Field(min_items=1)
client_endpoint: Optional[constr(regex=r".+:\d+")] # noqa: F722
client_persistent_keepalive: Optional[Literal["off"] | int]
server_allowed_ips: list[IPvAnyInterface] = Field(min_length=1)
server_endpoint: Optional[
Annotated[str, StringConstraints(pattern=r".+:\d+")]
] = None
server_persistent_keepalive: Optional[Literal["off"] | int] = None
client_allowed_ips: list[IPvAnyInterface] = Field(min_length=1)
client_endpoint: Optional[
Annotated[str, StringConstraints(pattern=r".+:\d+")]
] = None
client_persistent_keepalive: Optional[Literal["off"] | int] = None


class WgWizardConfig(StrictModel):
name: constr(regex=r"[a-zA-Z0-9_=+.-]{1,15}") # noqa: F722
name: Annotated[str, StringConstraints(pattern=r"[a-zA-Z0-9_=+.-]{1,15}")]
listen_port: int
fw_mark: Optional[Literal["off"] | int]
addresses: list[IPvAnyInterface] = Field(min_items=1)
fw_mark: Optional[Literal["off"] | int] = None
addresses: list[IPvAnyInterface] = Field(min_length=1)
# TODO: test DNS DHCP
dns_addresses: list[IPvAnyAddress] = Field(default_factory=list)
mtu: Optional[int]
table: Optional[str]
mtu: Optional[int] = None
table: Optional[str] = None
pre_up: list[str] = Field(default_factory=list)
post_up: list[str] = Field(default_factory=list)
pre_down: list[str] = Field(default_factory=list)
post_down: list[str] = Field(default_factory=list)
default_endpoint: constr(regex=r".+:\d+") # noqa: F722
default_endpoint: Annotated[str, StringConstraints(pattern=r".+:\d+")]
peers: dict[
constr(regex=r"[a-zA-Z0-9_=+.-]+"), WgWizardPeerConfig # noqa: F722
Annotated[str, StringConstraints(pattern=r"[a-zA-Z0-9_=+.-]+")],
WgWizardPeerConfig,
] = Field(default_factory=dict)
_yaml: dict = PrivateAttr(default=None)

Expand All @@ -82,7 +89,7 @@ def dump(self, path: Path, overwrite=False):
yaml = YAML()
yaml.indent(mapping=2, sequence=4, offset=2)
if self._yaml is None:
yaml.dump(json.loads(self.json(exclude_unset=True)), path)
yaml.dump(json.loads(self.model_dump_json(exclude_unset=True)), path)
else:
yaml.dump(self._yaml, path)

Expand All @@ -104,13 +111,16 @@ def add_peer(self, name: str, peer_config: WgWizardPeerConfig):
if self._yaml is not None:
if "peers" not in self._yaml:
self._yaml["peers"] = {}
self._yaml["peers"][name] = json.loads(peer_config.json(exclude_unset=True))
self._yaml["peers"][name] = json.loads(
peer_config.model_dump_json(exclude_unset=True)
)


class WgWizardPeerSecret(StrictModel):
private_key: SecretStr
public_key: constr(min_length=1)
preshared_key: Optional[SecretStr]
public_key: str = Field(min_length=1)
preshared_key: Optional[SecretStr] = None
issued_on: Optional[datetime.datetime] = None

@classmethod
def generate(cls) -> "WgWizardPeerSecret":
Expand All @@ -120,6 +130,7 @@ def generate(cls) -> "WgWizardPeerSecret":
private_key=private_key,
public_key=public_key,
preshared_key=preshared_key,
issued_on=datetime.datetime.now(datetime.timezone.utc),
)

def check(self, name: str):
Expand All @@ -131,30 +142,40 @@ def check(self, name: str):
self.preshared_key.get_secret_value(), f"Peer {name} preshared_key"
)

@field_serializer("private_key", "preshared_key", when_used="json-unless-none")
def dump_secret(self, v):
return v.get_secret_value()


class WgWizardSecret(StrictModel):
private_key: SecretStr
public_key: constr(min_length=1)
public_key: str = Field(min_length=1)
issued_on: Optional[datetime.datetime] = None
peers: dict[str, WgWizardPeerSecret] = Field(default_factory=dict)

@classmethod
def generate(cls) -> "WgWizardSecret":
private_key, public_key = gen_key_pair()
return cls(private_key=private_key, public_key=public_key)
return cls(
private_key=private_key,
public_key=public_key,
issued_on=datetime.datetime.now(datetime.timezone.utc),
)

def regenerate_server_secret(self):
self.private_key, self.public_key = gen_key_pair()
self.issued_on = datetime.datetime.now(datetime.timezone.utc)

@classmethod
def from_file(cls, path: Path):
check_file_mode(path)
return cls.parse_file(path)
return cls.model_validate_json(path.read_text())

def dump(self, path: Path, overwrite=False):
path = path.resolve()
logger.info("Writing secret to %s", path)
ensure_file(path, mode=0o600, overwrite=overwrite)
path.write_text(self.json(indent=2))
path.write_text(self.model_dump_json(indent=2))

def generate_peer_secret(self, name: str) -> WgWizardPeerSecret:
peer_secret = WgWizardPeerSecret.generate()
Expand All @@ -166,6 +187,10 @@ def check(self):
for peer_name, peer_secret in self.peers.items():
peer_secret.check(peer_name)

@field_serializer("private_key", when_used="json")
def dump_secret(self, v):
return v.get_secret_value()


class WgWizard(StrictModel):
config: WgWizardConfig
Expand Down
23 changes: 11 additions & 12 deletions src/wg_wizard/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import stat

import click
from pydantic import BaseModel, SecretStr, Extra
from pydantic import BaseModel, ConfigDict, SecretStr

from .wg import pubkey

Expand All @@ -17,19 +17,18 @@ def to_camel(string: str) -> str:


class StrictModel(BaseModel):
class Config:
extra = Extra.forbid
validate_assignment = True
validate_all = True
json_encoders = {
SecretStr: lambda v: v.get_secret_value(),
}
model_config = ConfigDict(
extra="forbid",
validate_assignment=True,
validate_default=True,
)


class StrictCamelModel(StrictModel):
class Config:
alias_generator = to_camel
allow_population_by_field_name = True
model_config = ConfigDict(
alias_generator=to_camel,
populate_by_name=True,
)


def ensure_file(path: Path, mode: int, overwrite=False):
Expand All @@ -55,7 +54,7 @@ def check_file_mode(path: Path):


def format_ini_lines(obj: BaseModel, exclude=None):
for field_name, field_config in obj.__fields__.items():
for field_name, field_config in obj.model_fields.items():
if exclude is not None and field_name in exclude:
continue
field_val = getattr(obj, field_name)
Expand Down
22 changes: 11 additions & 11 deletions src/wg_wizard/wg_quick.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,30 @@

class WgQuickInterfaceConfig(StrictCamelModel):
private_key: SecretStr
listen_port: Optional[int]
fw_mark: Optional[Literal["off"] | int]
address: list[IPvAnyInterface] = Field(min_items=1)
listen_port: Optional[int] = None
fw_mark: Optional[Literal["off"] | int] = None
address: list[IPvAnyInterface] = Field(min_length=1)
dns: list[IPvAnyAddress] = Field(alias="DNS")
mtu: Optional[int] = Field(alias="MTU")
table: Optional[str]
mtu: Optional[int] = Field(None, alias="MTU")
table: Optional[str] = None
pre_up: list[str]
post_up: list[str]
pre_down: list[str]
post_down: list[str]
save_config: Optional[bool]
save_config: Optional[bool] = None

def format_ini_lines(self) -> list[str]:
yield "[Interface]"
yield from format_ini_lines(self)


class WgQuickPeerConfig(StrictCamelModel):
comment: Optional[str]
comment: Optional[str] = None
public_key: str
preshared_key: Optional[SecretStr]
allowed_ips: list[IPvAnyInterface] = Field(alias="AllowedIPs", min_items=1)
endpoint: Optional[str]
persistent_keepalive: Optional[Literal["off"] | int]
preshared_key: Optional[SecretStr] = None
allowed_ips: list[IPvAnyInterface] = Field(alias="AllowedIPs", min_length=1)
endpoint: Optional[str] = None
persistent_keepalive: Optional[Literal["off"] | int] = None

def format_ini_lines(self) -> list[str]:
yield "[Peer]"
Expand Down
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest


def pytest_addoption(parser):
parser.addoption(
"--update-snapshot",
action="store_true",
default=False,
help="Whether to update the test snapshot.",
)


@pytest.fixture
def update_snapshot(request):
return request.config.getoption("--update-snapshot")
Loading

0 comments on commit ab5bfb2

Please sign in to comment.