Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OAuth 2.0 Device Authorization Grant (RFC 8628) #15

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions proto/mairu.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ service Agent {
rpc InitiateOauthCode(InitiateOAuthCodeRequest) returns (InitiateOAuthCodeResponse);
rpc CompleteOauthCode(CompleteOAuthCodeRequest) returns (CompleteOAuthCodeResponse);

rpc InitiateOauthDeviceCode(InitiateOAuthDeviceCodeRequest) returns (InitiateOAuthDeviceCodeResponse);
rpc CompleteOauthDeviceCode(CompleteOAuthDeviceCodeRequest) returns (CompleteOAuthDeviceCodeResponse);

rpc RefreshAwsSsoClientRegistration(RefreshAwsSsoClientRegistrationRequest) returns (RefreshAwsSsoClientRegistrationResponse);
rpc InitiateAwsSsoDevice(InitiateAwsSsoDeviceRequest) returns (InitiateAwsSsoDeviceResponse);
rpc CompleteAwsSsoDevice(CompleteAwsSsoDeviceRequest) returns (CompleteAwsSsoDeviceResponse);
Expand Down Expand Up @@ -86,6 +89,24 @@ message CompleteOAuthCodeRequest {
message CompleteOAuthCodeResponse {
}

message InitiateOAuthDeviceCodeRequest {
string server_id = 1;
}
message InitiateOAuthDeviceCodeResponse {
string handle = 1;
string user_code = 2;
string verification_uri = 3;
string verification_uri_complete = 4;
google.protobuf.Timestamp expires_at = 5;
int32 interval = 6;
}

message CompleteOAuthDeviceCodeRequest {
string handle = 1;
}
message CompleteOAuthDeviceCodeResponse {
}

message RefreshAwsSsoClientRegistrationRequest {
string server_id = 1;
}
Expand Down
70 changes: 69 additions & 1 deletion src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,68 @@ impl crate::proto::agent_server::Agent for Agent {
Ok(tonic::Response::new(CompleteOAuthCodeResponse {}))
}

#[tracing::instrument(skip_all)]
async fn initiate_oauth_device_code(
&self,
request: tonic::Request<InitiateOAuthDeviceCodeRequest>,
) -> Result<tonic::Response<InitiateOAuthDeviceCodeResponse>, tonic::Status> {
let req = request.get_ref();

let server = match crate::config::Server::find_from_fs(&req.server_id).await {
Ok(server) => server,
Err(crate::Error::ConfigError(e)) => return Err(tonic::Status::internal(e)),
Err(crate::Error::UserError(e)) => return Err(tonic::Status::not_found(e)),
Err(e) => return Err(tonic::Status::internal(e.to_string())),
};

server.validate().map_err(|e| {
tonic::Status::failed_precondition(format!(
"Server '{}' has invalid configuration; {:}",
server.id(),
e,
))
})?;

let flow = crate::oauth_device_code::OAuthDeviceCodeFlow::initiate(&server)
.await
.map_err(|e| {
tracing::error!(err = ?e, "OAuthDeviceCodeFlow initiate failure");
tonic::Status::internal(e.to_string())
})?;

let response = (&flow).into();

tracing::debug!(flow = ?flow, "Initiated OAuth 2.0 Device Code flow");
self.auth_flow_manager
.store(crate::auth_flow_manager::AuthFlow::OAuthDeviceCode(flow));

return Ok(tonic::Response::new(response));
}

#[tracing::instrument(skip_all)]
async fn complete_oauth_device_code(
&self,
request: tonic::Request<CompleteOAuthDeviceCodeRequest>,
) -> Result<tonic::Response<CompleteOAuthDeviceCodeResponse>, tonic::Status> {
let req = request.get_ref();
let Some(flow0) = self.auth_flow_manager.retrieve(&req.handle) else {
return Err(tonic::Status::not_found("flow handle not found"));
};
let completion = {
let crate::auth_flow_manager::AuthFlow::OAuthDeviceCode(flow) = flow0.as_ref() else {
return Err(tonic::Status::invalid_argument(
"flow handle is not for the grant type",
));
};
tracing::trace!(flow = ?flow0.as_ref(), "Completing OAuth 2.0 Device Code Grant flow...");
flow.complete().await
};

self.accept_completed_auth_flow(flow0, completion)?;

Ok(tonic::Response::new(CompleteOAuthDeviceCodeResponse {}))
}

#[tracing::instrument(skip_all)]
async fn refresh_aws_sso_client_registration(
&self,
Expand Down Expand Up @@ -325,7 +387,13 @@ impl Agent {
) -> tonic::Result<()> {
let token = match completion {
Ok(t) => t,
Err(crate::Error::AuthNotReadyError) => {
Err(crate::Error::AuthNotReadyError { slow_down: true }) => {
tracing::debug!(flow = ?flow.as_ref(), "not yet ready, slow down");
return Err(tonic::Status::resource_exhausted(
"not yet ready, slow down".to_string(),
));
}
Err(crate::Error::AuthNotReadyError { slow_down: false }) => {
tracing::debug!(flow = ?flow.as_ref(), "not yet ready");
return Err(tonic::Status::failed_precondition(
"not yet ready".to_string(),
Expand Down
2 changes: 2 additions & 0 deletions src/auth_flow_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub const MAX_ITEMS: usize = 15;
pub enum AuthFlow {
Nop,
OAuthCode(crate::oauth_code::OAuthCodeFlow),
OAuthDeviceCode(crate::oauth_device_code::OAuthDeviceCodeFlow),
AwsSsoDevice(crate::oauth_awssso::AwsSsoDeviceFlow),
}

Expand All @@ -13,6 +14,7 @@ impl AuthFlow {
match self {
AuthFlow::Nop => "",
AuthFlow::OAuthCode(f) => &f.handle,
AuthFlow::OAuthDeviceCode(f) => &f.handle,
AuthFlow::AwsSsoDevice(f) => &f.handle,
}
}
Expand Down
68 changes: 65 additions & 3 deletions src/cmd/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@ pub async fn login(
server.validate()?;

let oauth = server.oauth.as_ref().unwrap();
let oauth_grant_type = args
.oauth_grant_type
.unwrap_or_else(|| oauth.default_grant_type());
let oauth_grant_type = match args.oauth_grant_type {
Some(x) => Ok(x),
None => oauth.default_grant_type(),
}?;

tracing::debug!(oauth_grant_type = ?oauth_grant_type, server = ?server, "Using OAuth");

match oauth_grant_type {
crate::config::OAuthGrantType::Code => do_oauth_code(agent, server).await,
crate::config::OAuthGrantType::DeviceCode => do_oauth_device_code(agent, server).await,
crate::config::OAuthGrantType::AwsSso => do_awssso(agent, server).await,
}
}
Expand Down Expand Up @@ -89,6 +91,66 @@ pub async fn do_oauth_code(
Ok(())
}

pub async fn do_oauth_device_code(
agent: &mut crate::agent::AgentConn,
server: crate::config::Server,
) -> Result<(), anyhow::Error> {
server.try_oauth_device_code_grant()?;

let session = agent
.initiate_oauth_device_code(crate::proto::InitiateOAuthDeviceCodeRequest {
server_id: server.id().to_owned(),
})
.await?
.into_inner();
tracing::debug!(session = ?session, "Initiated flow");

let product = env!("CARGO_PKG_NAME");
let server_id = server.id();
let server_url = &server.url;
let user_code = &session.user_code;
let mut authorize_url = &session.verification_uri_complete;
if authorize_url.is_empty() {
authorize_url = &session.verification_uri;
}

crate::terminal::send(&indoc::formatdoc! {"
:: {product} :: Login to {server_id} ({server_url}) ::::::::
:: {product} ::
:: {product} :: Your Verification Code: {user_code}
:: {product} :: To authorize, visit: {authorize_url}
:: {product} ::
"})
.await;

let mut interval = session.interval as u64;
loop {
tokio::time::sleep(std::time::Duration::from_secs(interval)).await;
let completion = agent
.complete_oauth_device_code(crate::proto::CompleteOAuthDeviceCodeRequest {
handle: session.handle.clone(),
})
.await;

match completion {
Ok(_) => break,
Err(e) if e.code() == tonic::Code::ResourceExhausted => {
interval += 5;
tracing::debug!(interval = ?interval, "Received slow_down request");
}
Err(e) if e.code() == tonic::Code::FailedPrecondition => {
// continue
}
Err(e) => {
anyhow::bail!(e);
}
}
}

tracing::info!("Logged in");
Ok(())
}

pub async fn do_awssso(
agent: &mut crate::agent::AgentConn,
server: crate::config::Server,
Expand Down
54 changes: 40 additions & 14 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,24 @@ impl Server {
Ok((oauth, code_grant))
}

pub fn try_oauth_device_code_grant(
&self,
) -> crate::Result<(&ServerOAuth, &ServerDeviceCodeGrant)> {
let Some(oauth) = self.oauth.as_ref() else {
return Err(crate::Error::ConfigError(format!(
"Server '{}' is missing OAuth 2.0 client configuration",
self.id()
)));
};
match oauth.device_code_grant {
Some(ref grant) => Ok((oauth, grant)),
None => Err(crate::Error::ConfigError(format!(
"Server '{}' is missing OAuth 2.0 Device Code Grant configuration",
self.id()
))),
}
}

pub fn try_oauth_awssso(&self) -> crate::Result<&ServerOAuth> {
if self.aws_sso.is_none() {
return Err(crate::Error::ConfigError(format!(
Expand Down Expand Up @@ -307,6 +325,7 @@ impl TryFrom<crate::proto::GetServerResponse> for Server {
#[serde(rename_all = "snake_case")]
pub enum OAuthGrantType {
Code,
DeviceCode,
AwsSso,
}

Expand All @@ -315,6 +334,7 @@ impl std::str::FromStr for OAuthGrantType {
fn from_str(s: &str) -> Result<OAuthGrantType, crate::Error> {
match s {
"code" => Ok(OAuthGrantType::Code),
"device_code" => Ok(OAuthGrantType::DeviceCode),
"aws_sso" => Ok(OAuthGrantType::AwsSso),
_ => Err(crate::Error::UserError(
"unknown oauth_grant_type".to_owned(),
Expand All @@ -333,7 +353,7 @@ pub struct ServerOAuth {
pub scope: Vec<String>,
default_grant_type: Option<OAuthGrantType>,
pub code_grant: Option<ServerCodeGrant>,
pub device_grant: Option<ServerDeviceGrant>,
pub device_code_grant: Option<ServerDeviceCodeGrant>,

/// Expiration time of dynamically registered OAuth2 client registration, such as AWS SSO
/// clients.
Expand All @@ -346,14 +366,15 @@ fn default_oauth_scope() -> Vec<String> {

impl ServerOAuth {
pub fn validate(&self) -> Result<(), crate::Error> {
if self.code_grant.is_none() && self.device_grant.is_none() {
if self.code_grant.is_none() && self.device_code_grant.is_none() {
return Err(crate::Error::ConfigError(
"Either oauth.code_grant or oauth.device_grant must be provided, but absent"
"Either oauth.code_grant or oauth.device_code_grant must be provided, but absent"
.to_owned(),
));
}
if match self.default_grant_type {
None => false,
Some(OAuthGrantType::DeviceCode) => self.device_code_grant.is_none(),
Some(OAuthGrantType::Code) => self.code_grant.is_none(),
Some(OAuthGrantType::AwsSso) => false,
} {
Expand All @@ -364,16 +385,21 @@ impl ServerOAuth {
Ok(())
}

pub fn default_grant_type(&self) -> OAuthGrantType {
self.default_grant_type.unwrap_or_else(|| {
if self.code_grant.is_some() {
return OAuthGrantType::Code;
}
if self.device_grant.is_some() {
// TODO: implement
pub fn default_grant_type(&self) -> Result<OAuthGrantType, crate::Error> {
match self.default_grant_type {
Some(x) => Ok(x),
None => {
if self.code_grant.is_some() {
return Ok(OAuthGrantType::Code);
}
if self.device_code_grant.is_some() {
return Ok(OAuthGrantType::DeviceCode);
}
Err(crate::Error::ConfigError(
"cannot determine default grant_type".to_string(),
))
}
unreachable!();
})
}
}
}

Expand All @@ -384,7 +410,7 @@ pub struct ServerCodeGrant {
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct ServerDeviceGrant {
pub struct ServerDeviceCodeGrant {
pub device_authorization_endpoint: Option<url::Url>,
}

Expand Down Expand Up @@ -435,7 +461,7 @@ impl AwsSsoClientRegistrationCache {
token_endpoint: None,
scope: sso.scope.clone(),
code_grant: None,
device_grant: None,
device_code_grant: None,
client_expires_at: Some(self.expires_at),
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub enum Error {
UrlParseError(#[from] url::ParseError),

#[error("AuthNotReadyError: flow not yet ready")]
AuthNotReadyError,
AuthNotReadyError { slow_down: bool },

#[error(transparent)]
OAuth2RequestTokenError(
Expand Down
Loading