Skip to content

Commit

Permalink
Auth username password (#34)
Browse files Browse the repository at this point in the history
* Auth username password
  • Loading branch information
aamalev authored Nov 12, 2023
1 parent f272a78 commit c76d635
Show file tree
Hide file tree
Showing 11 changed files with 193 additions and 72 deletions.
6 changes: 6 additions & 0 deletions redis_rs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def create_client(
*args: str,
max_size: Optional[int] = None,
cluster: Optional[bool] = None,
username: Optional[str] = None,
password: Optional[str] = None,
db: Optional[int] = None,
client_id: Optional[str] = None,
features: Optional[List[str]] = None,
) -> Client:
Expand All @@ -27,6 +30,9 @@ def create_client(
*args,
max_size=max_size,
cluster=cluster,
username=username,
password=password,
db=db,
client_id=client_id,
features=features,
)
3 changes: 3 additions & 0 deletions redis_rs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def create_client(
*args: str,
max_size: Optional[int] = None,
cluster: Optional[bool] = None,
username: Optional[str] = None,
password: Optional[str] = None,
db: Optional[int] = None,
client_id: Optional[str] = None,
features: Optional[List[str]] = None,
) -> Client: ...
5 changes: 5 additions & 0 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ impl Client {
let mut status = self.cr.status()?;
let is_closed = status.remove("closed");
let is_cluster = status.remove("cluster");
let is_auth = status.remove("auth");
let result = PyDict::new(py);
for (k, v) in status.into_iter() {
let value = types::to_object(py, v, "utf-8");
Expand All @@ -43,6 +44,10 @@ impl Client {
let is_closed = c == 1;
result.set_item("closed", is_closed.to_object(py))?;
}
if let Some(redis::Value::Int(c)) = is_auth {
let is_auth = c == 1;
result.set_item("auth", is_auth.to_object(py))?;
}
Ok(result.to_object(py))
}

Expand Down
9 changes: 6 additions & 3 deletions src/cluster_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
pool::{Connection, Pool},
};
use async_trait::async_trait;
use redis::{aio::ConnectionLike, cluster::ClusterClient, Cmd};
use redis::{aio::ConnectionLike, cluster::ClusterClient, Cmd, IntoConnectionInfo};
use tokio::sync::Semaphore;

pub struct Cluster {
Expand All @@ -14,8 +14,11 @@ pub struct Cluster {
}

impl Cluster {
pub async fn new(initial_nodes: Vec<String>, max_size: u32) -> Result<Self, error::RedisError> {
let client = ClusterClient::new(initial_nodes).unwrap();
pub async fn new<T>(initial_nodes: Vec<T>, max_size: u32) -> Result<Self, error::RedisError>
where
T: IntoConnectionInfo,
{
let client = ClusterClient::new(initial_nodes)?;
let semaphore = Semaphore::new(max_size as usize);
let connection = client.get_async_connection().await?;
Ok(Self {
Expand Down
53 changes: 48 additions & 5 deletions src/cluster_bb8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,59 @@ use crate::{
pool::{Connection, Pool},
};
use async_trait::async_trait;
use bb8_redis_cluster::RedisConnectionManager;
use redis::{aio::ConnectionLike, Cmd};
use redis::{
aio::ConnectionLike, cluster::ClusterClient, cluster_async::ClusterConnection, Cmd, ErrorKind,
IntoConnectionInfo, RedisError,
};

pub struct ClusterManager {
pub(crate) client: ClusterClient,
}

impl ClusterManager {
pub fn new<T>(initial_nodes: Vec<T>) -> Result<Self, RedisError>
where
T: IntoConnectionInfo,
{
let client = ClusterClient::new(initial_nodes)?;
Ok(Self { client })
}
}

#[async_trait]
impl bb8::ManageConnection for ClusterManager {
type Connection = ClusterConnection;
type Error = RedisError;

async fn connect(&self) -> Result<Self::Connection, Self::Error> {
self.client.get_async_connection().await
}

async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
let pong: String = redis::cmd("PING").query_async(conn).await?;
match pong.as_str() {
"PONG" => Ok(()),
_ => Err((ErrorKind::ResponseError, "ping request").into()),
}
}

fn has_broken(&self, _: &mut Self::Connection) -> bool {
false
}
}

type Manager = ClusterManager;

pub struct BB8Cluster {
pool: bb8::Pool<RedisConnectionManager>,
pool: bb8::Pool<Manager>,
}

impl BB8Cluster {
pub async fn new(initial_nodes: Vec<String>, max_size: u32) -> Self {
let manager = RedisConnectionManager::new(initial_nodes).unwrap();
pub async fn new<T>(initial_nodes: Vec<T>, max_size: u32) -> Self
where
T: IntoConnectionInfo,
{
let manager = Manager::new(initial_nodes).unwrap();
let pool = bb8::Pool::builder()
.max_size(max_size)
.build(manager)
Expand Down
16 changes: 12 additions & 4 deletions src/cluster_deadpool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
};
use async_trait::async_trait;
use deadpool_redis_cluster::{Config, PoolError, Runtime};
use redis::{aio::ConnectionLike, Cmd};
use redis::{aio::ConnectionLike, Cmd, IntoConnectionInfo};
use std::collections::HashMap;

pub struct DeadPoolCluster {
Expand All @@ -21,13 +21,21 @@ impl From<PoolError> for error::RedisError {
}

impl DeadPoolCluster {
pub fn new(initial_nodes: Vec<String>, max_size: u32) -> Self {
let cfg = Config::from_urls(initial_nodes);
pub fn new<T>(initial_nodes: Vec<T>, max_size: u32) -> Result<Self, error::RedisError>
where
T: IntoConnectionInfo,
{
let mut urls = vec![];
for i in initial_nodes.into_iter() {
let url = i.into_connection_info()?;
urls.push(url.addr.to_string());
}
let cfg = Config::from_urls(urls);
let pool = cfg
.create_pool(Some(Runtime::Tokio1))
.expect("Error with redis pool");
pool.resize(max_size as usize);
Self { pool }
Ok(Self { pool })
}
}

Expand Down
51 changes: 37 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use pyo3::prelude::*;
use redis::IntoConnectionInfo;
mod client;
mod client_result;
mod client_result_async;
Expand All @@ -17,27 +18,49 @@ mod single_node;
mod types;

#[pyfunction]
#[pyo3(signature = (*initial_nodes, max_size=None, cluster=None, client_id=None, features=None))]
#[pyo3(signature = (
*initial_nodes,
max_size=None,
cluster=None,
username = None,
password = None,
db = None,
client_id=None,
features=None,
))]
#[allow(clippy::too_many_arguments)]
fn create_client(
initial_nodes: Vec<String>,
max_size: Option<u32>,
cluster: Option<bool>,
username: Option<String>,
password: Option<String>,
db: Option<i64>,
client_id: Option<String>,
features: Option<Vec<String>>,
) -> PyResult<client::Client> {
let is_cluster = match cluster {
None => initial_nodes.len() > 1,
Some(c) => c,
};
let mut cm = if is_cluster {
pool_manager::PoolManager::new_cluster(initial_nodes)
} else {
let addr = initial_nodes
.get(0)
.map(String::as_str)
.unwrap_or("redis://localhost:6379");
pool_manager::PoolManager::new(addr)
};
let mut nodes = initial_nodes.clone();
if nodes.is_empty() {
nodes.push("redis://localhost:6379".to_string());
}
let mut infos = vec![];
for i in nodes.into_iter() {
let mut info = i.into_connection_info().map_err(error::RedisError::from)?;
if password.is_some() {
info.redis.password = password.clone();
}
if username.is_some() {
info.redis.username = username.clone();
}
if let Some(db) = db {
info.redis.db = db;
}
infos.push(info);
}

let mut cm = pool_manager::PoolManager::new(infos)?;
cm.is_cluster = cluster;

if let Some(features) = features {
cm.features = features
.into_iter()
Expand Down
83 changes: 43 additions & 40 deletions src/pool_manager.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{collections::HashMap, sync::Arc};

use redis::{Cmd, FromRedisValue, IntoConnectionInfo};
use redis::{Cmd, ConnectionInfo, FromRedisValue};

use crate::{
client::Client,
Expand Down Expand Up @@ -31,60 +31,55 @@ impl From<PoolManager> for Client {
}

pub struct PoolManager {
pub(crate) is_cluster: bool,
pub(crate) initial_nodes: Vec<String>,
pub(crate) is_cluster: Option<bool>,
pub(crate) initial_nodes: Vec<ConnectionInfo>,
pub(crate) max_size: u32,
pub(crate) pool: Box<dyn Pool + Send + Sync>,
pub(crate) client_id: String,
pub(crate) features: Vec<types::Feature>,
}

impl PoolManager {
pub fn new_cluster(initial_nodes: Vec<String>) -> Self {
Self {
pub fn new(initial_nodes: Vec<ConnectionInfo>) -> Result<Self, error::RedisError> {
Ok(Self {
initial_nodes,
is_cluster: true,
max_size: 10,
pool: Box::new(ClosedPool),
client_id: String::default(),
features: vec![],
}
}

pub fn new(addr: &str) -> Self {
Self {
initial_nodes: vec![addr.to_string()],
is_cluster: false,
is_cluster: Some(false),
max_size: 10,
pool: Box::new(ClosedPool),
client_id: String::default(),
features: vec![],
}
})
}

pub async fn init(&mut self) -> Result<&Self, error::RedisError> {
let nodes = self.initial_nodes.clone();
let mut nodes = self.initial_nodes.clone();
let ms = self.max_size;
let is_cluster = self.is_cluster;
if is_cluster {
self.pool = match self.features.as_slice() {
[types::Feature::BB8, ..] => Box::new(BB8Cluster::new(nodes, ms).await),
[types::Feature::DeadPool, ..] => Box::new(DeadPoolCluster::new(nodes, ms)),
[types::Feature::Shards, ..] => {
Box::new(AsyncShards::new(nodes, ms, Some(true)).await?)
}
_ => Box::new(Cluster::new(nodes, ms).await?),
};
} else {
let info = nodes.clone().remove(0).into_connection_info()?;
self.pool = match self.features.as_slice() {
[types::Feature::BB8, ..] => Box::new(BB8Pool::new(info, ms).await?),
[types::Feature::DeadPool, ..] => Box::new(DeadPool::new(info, ms).await?),
[types::Feature::Shards, ..] => {
Box::new(AsyncShards::new(nodes, ms, Some(false)).await?)
}
_ => Box::new(Node::new(info, ms).await?),
};
match self.is_cluster {
None => {
self.pool = Box::new(AsyncShards::new(nodes, ms, self.is_cluster).await?);
}
Some(true) => {
self.pool = match self.features.as_slice() {
[types::Feature::BB8, ..] => Box::new(BB8Cluster::new(nodes, ms).await),
[types::Feature::DeadPool, ..] => Box::new(DeadPoolCluster::new(nodes, ms)?),
[types::Feature::Shards, ..] => {
Box::new(AsyncShards::new(nodes, ms, Some(true)).await?)
}
_ => Box::new(Cluster::new(nodes, ms).await?),
};
}
Some(false) => {
self.pool = match self.features.as_slice() {
[types::Feature::BB8, ..] => Box::new(BB8Pool::new(nodes.remove(0), ms).await?),
[types::Feature::DeadPool, ..] => {
Box::new(DeadPool::new(nodes.remove(0), ms).await?)
}
[types::Feature::Shards, ..] => {
Box::new(AsyncShards::new(nodes, ms, Some(false)).await?)
}
_ => Box::new(Node::new(nodes.remove(0), ms).await?),
};
}
};
Ok(self)
}
Expand All @@ -99,7 +94,15 @@ impl PoolManager {
let initial_nodes = self
.initial_nodes
.iter()
.map(|s| redis::Value::Data(s.as_bytes().to_vec()))
.map(|s| {
if let Some(username) = s.redis.username.clone() {
result.insert("username", redis::Value::Data(username.as_bytes().to_vec()));
}
if s.redis.password.is_some() {
result.insert("auth", redis::Value::Int(1));
}
redis::Value::Data(s.addr.to_string().as_bytes().to_vec())
})
.collect();
result.insert("initial_nodes", redis::Value::Bulk(initial_nodes));
result.insert("max_size", redis::Value::Int(self.max_size as i64));
Expand Down
9 changes: 6 additions & 3 deletions src/shards_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ pub struct AsyncShards {
}

impl AsyncShards {
pub async fn new(
nodes: Vec<String>,
pub async fn new<T>(
nodes: Vec<T>,
max_size: u32,
is_cluster: Option<bool>,
) -> RedisResult<AsyncShards> {
) -> RedisResult<AsyncShards>
where
T: IntoConnectionInfo,
{
let mut result = Self {
max_size,
..Default::default()
Expand Down
12 changes: 9 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,17 @@ async def get_redis_version(nodes: list) -> str:


@pytest.fixture
async def async_client():
async with redis_rs.create_client(
def client_factory():
return lambda **kwargs: redis_rs.create_client(
*NODES,
cluster=IS_CLUSTER,
) as c:
**kwargs,
)


@pytest.fixture
async def async_client(client_factory):
async with client_factory() as c:
yield c


Expand Down
Loading

0 comments on commit c76d635

Please sign in to comment.