diff --git a/proto/mairu.proto b/proto/mairu.proto index eaa033a..e7a69e1 100644 --- a/proto/mairu.proto +++ b/proto/mairu.proto @@ -28,6 +28,9 @@ service Agent { rpc InitiateOauthCode(InitiateOAuthCodeRequest) returns (InitiateOAuthCodeResponse); rpc CompleteOauthCode(CompleteOAuthCodeRequest) returns (CompleteOAuthCodeResponse); + rpc InitiateOauthDeviceCode(InitiateOAuthDeviceCodeRequest) returns (InitiateOAuthDeviceCodeResponse); + rpc CompleteOauthDeviceCode(CompleteOAuthDeviceCodeRequest) returns (CompleteOAuthDeviceCodeResponse); + rpc RefreshAwsSsoClientRegistration(RefreshAwsSsoClientRegistrationRequest) returns (RefreshAwsSsoClientRegistrationResponse); rpc InitiateAwsSsoDevice(InitiateAwsSsoDeviceRequest) returns (InitiateAwsSsoDeviceResponse); rpc CompleteAwsSsoDevice(CompleteAwsSsoDeviceRequest) returns (CompleteAwsSsoDeviceResponse); @@ -86,6 +89,24 @@ message CompleteOAuthCodeRequest { message CompleteOAuthCodeResponse { } +message InitiateOAuthDeviceCodeRequest { + string server_id = 1; +} +message InitiateOAuthDeviceCodeResponse { + string handle = 1; + string user_code = 2; + string verification_uri = 3; + string verification_uri_complete = 4; + google.protobuf.Timestamp expires_at = 5; + int32 interval = 6; +} + +message CompleteOAuthDeviceCodeRequest { + string handle = 1; +} +message CompleteOAuthDeviceCodeResponse { +} + message RefreshAwsSsoClientRegistrationRequest { string server_id = 1; } diff --git a/src/agent.rs b/src/agent.rs index ff00d98..da46bfb 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -210,6 +210,68 @@ impl crate::proto::agent_server::Agent for Agent { Ok(tonic::Response::new(CompleteOAuthCodeResponse {})) } + #[tracing::instrument(skip_all)] + async fn initiate_oauth_device_code( + &self, + request: tonic::Request, + ) -> Result, tonic::Status> { + let req = request.get_ref(); + + let server = match crate::config::Server::find_from_fs(&req.server_id).await { + Ok(server) => server, + Err(crate::Error::ConfigError(e)) => return Err(tonic::Status::internal(e)), + Err(crate::Error::UserError(e)) => return Err(tonic::Status::not_found(e)), + Err(e) => return Err(tonic::Status::internal(e.to_string())), + }; + + server.validate().map_err(|e| { + tonic::Status::failed_precondition(format!( + "Server '{}' has invalid configuration; {:}", + server.id(), + e, + )) + })?; + + let flow = crate::oauth_device_code::OAuthDeviceCodeFlow::initiate(&server) + .await + .map_err(|e| { + tracing::error!(err = ?e, "OAuthDeviceCodeFlow initiate failure"); + tonic::Status::internal(e.to_string()) + })?; + + let response = (&flow).into(); + + tracing::debug!(flow = ?flow, "Initiated OAuth 2.0 Device Code flow"); + self.auth_flow_manager + .store(crate::auth_flow_manager::AuthFlow::OAuthDeviceCode(flow)); + + return Ok(tonic::Response::new(response)); + } + + #[tracing::instrument(skip_all)] + async fn complete_oauth_device_code( + &self, + request: tonic::Request, + ) -> Result, tonic::Status> { + let req = request.get_ref(); + let Some(flow0) = self.auth_flow_manager.retrieve(&req.handle) else { + return Err(tonic::Status::not_found("flow handle not found")); + }; + let completion = { + let crate::auth_flow_manager::AuthFlow::OAuthDeviceCode(flow) = flow0.as_ref() else { + return Err(tonic::Status::invalid_argument( + "flow handle is not for the grant type", + )); + }; + tracing::trace!(flow = ?flow0.as_ref(), "Completing OAuth 2.0 Device Code Grant flow..."); + flow.complete().await + }; + + self.accept_completed_auth_flow(flow0, completion)?; + + Ok(tonic::Response::new(CompleteOAuthDeviceCodeResponse {})) + } + #[tracing::instrument(skip_all)] async fn refresh_aws_sso_client_registration( &self, @@ -325,7 +387,13 @@ impl Agent { ) -> tonic::Result<()> { let token = match completion { Ok(t) => t, - Err(crate::Error::AuthNotReadyError) => { + Err(crate::Error::AuthNotReadyError { slow_down: true }) => { + tracing::debug!(flow = ?flow.as_ref(), "not yet ready, slow down"); + return Err(tonic::Status::resource_exhausted( + "not yet ready, slow down".to_string(), + )); + } + Err(crate::Error::AuthNotReadyError { slow_down: false }) => { tracing::debug!(flow = ?flow.as_ref(), "not yet ready"); return Err(tonic::Status::failed_precondition( "not yet ready".to_string(), diff --git a/src/auth_flow_manager.rs b/src/auth_flow_manager.rs index 42073e6..145424f 100644 --- a/src/auth_flow_manager.rs +++ b/src/auth_flow_manager.rs @@ -5,6 +5,7 @@ pub const MAX_ITEMS: usize = 15; pub enum AuthFlow { Nop, OAuthCode(crate::oauth_code::OAuthCodeFlow), + OAuthDeviceCode(crate::oauth_device_code::OAuthDeviceCodeFlow), AwsSsoDevice(crate::oauth_awssso::AwsSsoDeviceFlow), } @@ -13,6 +14,7 @@ impl AuthFlow { match self { AuthFlow::Nop => "", AuthFlow::OAuthCode(f) => &f.handle, + AuthFlow::OAuthDeviceCode(f) => &f.handle, AuthFlow::AwsSsoDevice(f) => &f.handle, } } diff --git a/src/cmd/login.rs b/src/cmd/login.rs index ea982e7..3b014ea 100644 --- a/src/cmd/login.rs +++ b/src/cmd/login.rs @@ -32,14 +32,16 @@ pub async fn login( server.validate()?; let oauth = server.oauth.as_ref().unwrap(); - let oauth_grant_type = args - .oauth_grant_type - .unwrap_or_else(|| oauth.default_grant_type()); + let oauth_grant_type = match args.oauth_grant_type { + Some(x) => Ok(x), + None => oauth.default_grant_type(), + }?; tracing::debug!(oauth_grant_type = ?oauth_grant_type, server = ?server, "Using OAuth"); match oauth_grant_type { crate::config::OAuthGrantType::Code => do_oauth_code(agent, server).await, + crate::config::OAuthGrantType::DeviceCode => do_oauth_device_code(agent, server).await, crate::config::OAuthGrantType::AwsSso => do_awssso(agent, server).await, } } @@ -89,6 +91,66 @@ pub async fn do_oauth_code( Ok(()) } +pub async fn do_oauth_device_code( + agent: &mut crate::agent::AgentConn, + server: crate::config::Server, +) -> Result<(), anyhow::Error> { + server.try_oauth_device_code_grant()?; + + let session = agent + .initiate_oauth_device_code(crate::proto::InitiateOAuthDeviceCodeRequest { + server_id: server.id().to_owned(), + }) + .await? + .into_inner(); + tracing::debug!(session = ?session, "Initiated flow"); + + let product = env!("CARGO_PKG_NAME"); + let server_id = server.id(); + let server_url = &server.url; + let user_code = &session.user_code; + let mut authorize_url = &session.verification_uri_complete; + if authorize_url.is_empty() { + authorize_url = &session.verification_uri; + } + + crate::terminal::send(&indoc::formatdoc! {" + :: {product} :: Login to {server_id} ({server_url}) :::::::: + :: {product} :: + :: {product} :: Your Verification Code: {user_code} + :: {product} :: To authorize, visit: {authorize_url} + :: {product} :: + "}) + .await; + + let mut interval = session.interval as u64; + loop { + tokio::time::sleep(std::time::Duration::from_secs(interval)).await; + let completion = agent + .complete_oauth_device_code(crate::proto::CompleteOAuthDeviceCodeRequest { + handle: session.handle.clone(), + }) + .await; + + match completion { + Ok(_) => break, + Err(e) if e.code() == tonic::Code::ResourceExhausted => { + interval += 5; + tracing::debug!(interval = ?interval, "Received slow_down request"); + } + Err(e) if e.code() == tonic::Code::FailedPrecondition => { + // continue + } + Err(e) => { + anyhow::bail!(e); + } + } + } + + tracing::info!("Logged in"); + Ok(()) +} + pub async fn do_awssso( agent: &mut crate::agent::AgentConn, server: crate::config::Server, diff --git a/src/config.rs b/src/config.rs index 812e12b..de16ed1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -217,6 +217,24 @@ impl Server { Ok((oauth, code_grant)) } + pub fn try_oauth_device_code_grant( + &self, + ) -> crate::Result<(&ServerOAuth, &ServerDeviceCodeGrant)> { + let Some(oauth) = self.oauth.as_ref() else { + return Err(crate::Error::ConfigError(format!( + "Server '{}' is missing OAuth 2.0 client configuration", + self.id() + ))); + }; + match oauth.device_code_grant { + Some(ref grant) => Ok((oauth, grant)), + None => Err(crate::Error::ConfigError(format!( + "Server '{}' is missing OAuth 2.0 Device Code Grant configuration", + self.id() + ))), + } + } + pub fn try_oauth_awssso(&self) -> crate::Result<&ServerOAuth> { if self.aws_sso.is_none() { return Err(crate::Error::ConfigError(format!( @@ -307,6 +325,7 @@ impl TryFrom for Server { #[serde(rename_all = "snake_case")] pub enum OAuthGrantType { Code, + DeviceCode, AwsSso, } @@ -315,6 +334,7 @@ impl std::str::FromStr for OAuthGrantType { fn from_str(s: &str) -> Result { match s { "code" => Ok(OAuthGrantType::Code), + "device_code" => Ok(OAuthGrantType::DeviceCode), "aws_sso" => Ok(OAuthGrantType::AwsSso), _ => Err(crate::Error::UserError( "unknown oauth_grant_type".to_owned(), @@ -333,7 +353,7 @@ pub struct ServerOAuth { pub scope: Vec, default_grant_type: Option, pub code_grant: Option, - pub device_grant: Option, + pub device_code_grant: Option, /// Expiration time of dynamically registered OAuth2 client registration, such as AWS SSO /// clients. @@ -346,14 +366,15 @@ fn default_oauth_scope() -> Vec { impl ServerOAuth { pub fn validate(&self) -> Result<(), crate::Error> { - if self.code_grant.is_none() && self.device_grant.is_none() { + if self.code_grant.is_none() && self.device_code_grant.is_none() { return Err(crate::Error::ConfigError( - "Either oauth.code_grant or oauth.device_grant must be provided, but absent" + "Either oauth.code_grant or oauth.device_code_grant must be provided, but absent" .to_owned(), )); } if match self.default_grant_type { None => false, + Some(OAuthGrantType::DeviceCode) => self.device_code_grant.is_none(), Some(OAuthGrantType::Code) => self.code_grant.is_none(), Some(OAuthGrantType::AwsSso) => false, } { @@ -364,16 +385,21 @@ impl ServerOAuth { Ok(()) } - pub fn default_grant_type(&self) -> OAuthGrantType { - self.default_grant_type.unwrap_or_else(|| { - if self.code_grant.is_some() { - return OAuthGrantType::Code; - } - if self.device_grant.is_some() { - // TODO: implement + pub fn default_grant_type(&self) -> Result { + match self.default_grant_type { + Some(x) => Ok(x), + None => { + if self.code_grant.is_some() { + return Ok(OAuthGrantType::Code); + } + if self.device_code_grant.is_some() { + return Ok(OAuthGrantType::DeviceCode); + } + Err(crate::Error::ConfigError( + "cannot determine default grant_type".to_string(), + )) } - unreachable!(); - }) + } } } @@ -384,7 +410,7 @@ pub struct ServerCodeGrant { } #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] -pub struct ServerDeviceGrant { +pub struct ServerDeviceCodeGrant { pub device_authorization_endpoint: Option, } @@ -435,7 +461,7 @@ impl AwsSsoClientRegistrationCache { token_endpoint: None, scope: sso.scope.clone(), code_grant: None, - device_grant: None, + device_code_grant: None, client_expires_at: Some(self.expires_at), } } diff --git a/src/error.rs b/src/error.rs index ab78165..5e2cc37 100644 --- a/src/error.rs +++ b/src/error.rs @@ -22,7 +22,7 @@ pub enum Error { UrlParseError(#[from] url::ParseError), #[error("AuthNotReadyError: flow not yet ready")] - AuthNotReadyError, + AuthNotReadyError { slow_down: bool }, #[error(transparent)] OAuth2RequestTokenError( diff --git a/src/ext_oauth2.rs b/src/ext_oauth2.rs index 6f80169..9d50e87 100644 --- a/src/ext_oauth2.rs +++ b/src/ext_oauth2.rs @@ -17,13 +17,18 @@ pub type SecrecyClient< HasTokenUrl, >; +#[serde_with::serde_as] #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] pub struct SecrecyTokenResponse { #[serde(skip_serializing)] pub access_token: secrecy::SecretString, #[serde(deserialize_with = "oauth2::helpers::deserialize_untagged_enum_case_insensitive")] pub token_type: oauth2::basic::BasicTokenType, + + // Non-compliant servers (e.g. Microsoft) may return this in string + #[serde_as(deserialize_as = "Option>")] pub expires_in: Option, + #[serde(skip_serializing)] pub refresh_token: Option, @@ -57,3 +62,105 @@ impl oauth2::TokenResponse for SecrecyTokenResponse { None } } + +#[serde_with::serde_as] +#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] +pub struct CustomDeviceAuthorizationResponse { + #[serde(skip_serializing)] + pub device_code: secrecy::SecretString, + #[serde(skip_serializing)] + pub user_code: secrecy::SecretString, + + #[serde(alias = "verification_url")] + pub verification_uri: String, + + #[serde(skip_serializing)] + pub verification_uri_complete: Option, + + // Non-compliant servers (e.g. Microsoft OAuth 2 v1 endpoint) may return in string + #[serde_as(deserialize_as = "serde_with::PickFirst<(_, serde_with::DisplayFromStr)>")] + pub expires_in: u64, + + #[serde_as( + deserialize_as = "DeviceCodeAuthMinimum>" + )] + #[serde(default = "default_device_auth_interval")] + pub interval: i32, +} + +// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2 +pub static DEVICE_CODE_AUTH_INTERVAL_MIN: i32 = 5; + +fn default_device_auth_interval() -> i32 { + DEVICE_CODE_AUTH_INTERVAL_MIN +} + +struct DeviceCodeAuthMinimum(std::marker::PhantomData); + +impl<'de, TAs> serde_with::DeserializeAs<'de, i32> for DeviceCodeAuthMinimum +where + TAs: serde_with::DeserializeAs<'de, i32>, +{ + fn deserialize_as(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::Deserialize; + let content = + >::deserialize(deserializer)?.into_inner(); + Ok(content.max(DEVICE_CODE_AUTH_INTERVAL_MIN)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + mod custom_device_authorization_response { + use super::*; + + #[test] + fn test_parse_minimal() { + let j = r#"{"device_code": "DEVICE", "user_code": "USER", "verification_uri": "http://test.invalid", "expires_in": 1234}"#; + let d: CustomDeviceAuthorizationResponse = serde_json::from_str(j).unwrap(); + assert_eq!(d.expires_in, 1234); + assert_eq!(d.interval, 5); + } + #[test] + fn test_parse_expires_in_str() { + let j = r#"{"device_code": "DEVICE", "user_code": "USER", "verification_uri": "http://test.invalid", "expires_in": "2345"}"#; + let d: CustomDeviceAuthorizationResponse = serde_json::from_str(j).unwrap(); + assert_eq!(d.expires_in, 2345); + } + #[test] + fn test_parse_interval() { + let j = r#"{"device_code": "DEVICE", "user_code": "USER", "verification_uri": "http://test.invalid", "expires_in": 1234, "interval": 10}"#; + let d: CustomDeviceAuthorizationResponse = serde_json::from_str(j).unwrap(); + assert_eq!(d.interval, 10); + } + #[test] + fn test_parse_interval_str() { + let j = r#"{"device_code": "DEVICE", "user_code": "USER", "verification_uri": "http://test.invalid", "expires_in": 1234, "interval": "20"}"#; + let d: CustomDeviceAuthorizationResponse = serde_json::from_str(j).unwrap(); + assert_eq!(d.interval, 20); + } + #[test] + fn test_parse_interval_null() { + let j = r#"{"device_code": "DEVICE", "user_code": "USER", "verification_uri": "http://test.invalid", "expires_in": 1234, "interval": null}"#; + let d: CustomDeviceAuthorizationResponse = serde_json::from_str(j).unwrap(); + assert_eq!(d.interval, 5); + } + #[test] + fn test_parse_interval_zero() { + let j = r#"{"device_code": "DEVICE", "user_code": "USER", "verification_uri": "http://test.invalid", "expires_in": 1234, "interval": 0}"#; + let d: CustomDeviceAuthorizationResponse = serde_json::from_str(j).unwrap(); + assert_eq!(d.interval, 5); + } + #[test] + fn test_parse_interval_negative() { + let j = r#"{"device_code": "DEVICE", "user_code": "USER", "verification_uri": "http://test.invalid", "expires_in": 1234, "interval": -10}"#; + let d: CustomDeviceAuthorizationResponse = serde_json::from_str(j).unwrap(); + assert_eq!(d.interval, 5); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 16f786c..7526b10 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ pub mod utils; pub mod oauth_awssso; pub mod oauth_code; +pub mod oauth_device_code; pub mod api_client; pub mod awssso_client; diff --git a/src/oauth_awssso.rs b/src/oauth_awssso.rs index 6be9b2e..7497a8b 100644 --- a/src/oauth_awssso.rs +++ b/src/oauth_awssso.rs @@ -192,7 +192,7 @@ impl AwsSsoDeviceFlow { Err(aws_sdk_ssooidc::error::SdkError::ServiceError(e)) if e.err().is_authorization_pending_exception() => { - Err(crate::Error::AuthNotReadyError) + Err(crate::Error::AuthNotReadyError { slow_down: false }) } Err(e) => Err(Box::new(aws_sdk_ssooidc::Error::from(e)).into()), } diff --git a/src/oauth_device_code.rs b/src/oauth_device_code.rs new file mode 100644 index 0000000..4ca7ba0 --- /dev/null +++ b/src/oauth_device_code.rs @@ -0,0 +1,147 @@ +#[derive(Debug)] +pub struct OAuthDeviceCodeFlow { + pub handle: String, + server: crate::config::Server, + pub user_code: secrecy::SecretString, + device_code: secrecy::SecretString, + pub expires_at: chrono::DateTime, + pub verification_uri: String, + pub verification_uri_complete: Option, + pub interval: i32, + //response: oauth2::StandardDeviceAuthorizationResponse, +} + +impl From<&OAuthDeviceCodeFlow> for crate::proto::InitiateOAuthDeviceCodeResponse { + fn from(flow: &OAuthDeviceCodeFlow) -> crate::proto::InitiateOAuthDeviceCodeResponse { + use secrecy::ExposeSecret; + crate::proto::InitiateOAuthDeviceCodeResponse { + handle: flow.handle.clone(), + user_code: flow.user_code.expose_secret().to_owned(), + verification_uri: flow.verification_uri.clone(), + verification_uri_complete: flow + .verification_uri_complete + .as_ref() + .map(|x| x.expose_secret().to_owned()) + .unwrap_or_default(), + interval: flow.interval, + expires_at: Some(std::time::SystemTime::from(flow.expires_at).into()), + } + } +} + +impl OAuthDeviceCodeFlow { + pub async fn initiate(server: &crate::config::Server) -> crate::Result { + let (oauth, grant) = server.try_oauth_device_code_grant()?; + let handle = crate::utils::generate_flow_handle(); + tracing::info!(server = ?server, handle = ?handle, "Initiating OAuth 2.0 Device Code flow"); + + // oauth2 crate doesn't allow non-standard response type + let scopes = oauth.scope.join(" "); + let resp = crate::client::http() + .post(grant.device_authorization_endpoint.clone().ok_or_else(|| { + crate::Error::ConfigError(format!( + "{} is missing device_authorization_endpoint", + server.id() + )) + })?) + .header(reqwest::header::ACCEPT, "application/json") + .basic_auth(&oauth.client_id, oauth.client_secret.as_ref()) + .form(&[ + ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), + ("scope", &scopes), + ]) + .send() + .await?; + + let status = resp.status(); + if status != reqwest::StatusCode::OK { + let body = resp.bytes().await?; + if body.is_empty() { + return Err(crate::Error::AuthError(format!( + "server returned empty error; {status}" + ))); + } else { + let er: oauth2::basic::BasicErrorResponse = serde_json::from_slice(&body)?; + tracing::debug!(response = ?er, "DeviceCodeErrorResponse"); + let be = er.error(); + return Err(crate::Error::AuthError(format!( + "oauth2 error {be}: {er:?}" + ))); + } + } + let body: crate::ext_oauth2::CustomDeviceAuthorizationResponse = resp.json().await?; + + Ok(Self { + handle, + server: server.to_owned(), + user_code: body.user_code, + device_code: body.device_code, + verification_uri: body.verification_uri, + verification_uri_complete: body.verification_uri_complete, + expires_at: chrono::Utc::now() + chrono::TimeDelta::seconds(body.expires_in as i64), + interval: body + .interval + .max(crate::ext_oauth2::DEVICE_CODE_AUTH_INTERVAL_MIN), + }) + } + + pub async fn complete(&self) -> crate::Result { + use secrecy::ExposeSecret; + let (oauth, _) = self.server.try_oauth_device_code_grant()?; + tracing::info!(flow = ?self, "Completing OAuth 2.0 Device Code flow"); + + // oauth2 crate doesn't allow sending request only once + + let req = crate::client::http() + .post(oauth.token_endpoint.clone().ok_or_else(|| { + crate::Error::ConfigError(format!("{} is missing token_endpoint", self.server.id())) + })?) + .header(reqwest::header::ACCEPT, "application/json") + .basic_auth(&oauth.client_id, oauth.client_secret.as_ref()) + .form(&[ + ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), + ("device_code", self.device_code.expose_secret()), + ]) + .build()?; + tracing::debug!(req = ?req, "complete req"); + let resp = crate::client::http().execute(req).await?; + + let status = resp.status(); + if status != reqwest::StatusCode::OK { + let body = resp.bytes().await?; + if body.is_empty() { + return Err(crate::Error::AuthError(format!( + "server returned empty error; {status}" + ))); + } else { + let er: oauth2::DeviceCodeErrorResponse = serde_json::from_slice(&body)?; + tracing::debug!(response = ?er, "DeviceCodeErrorResponse"); + return match er.error() { + oauth2::DeviceCodeErrorResponseType::AuthorizationPending => { + Err(crate::Error::AuthNotReadyError { slow_down: false }) + } + oauth2::DeviceCodeErrorResponseType::SlowDown => { + Err(crate::Error::AuthNotReadyError { slow_down: true }) + } + oauth2::DeviceCodeErrorResponseType::AccessDenied => { + Err(crate::Error::AuthError(format!("access_denied: {er:?}"))) + } + oauth2::DeviceCodeErrorResponseType::ExpiredToken => { + Err(crate::Error::AuthError(format!( + "authorization timed out (device token expired): {er:?}" + ))) + } + oauth2::DeviceCodeErrorResponseType::Basic(be) => Err(crate::Error::AuthError( + format!("oauth2 error {be}: {er:?}"), + )), + }; + } + } + let body: crate::ext_oauth2::SecrecyTokenResponse = resp.json().await?; + + Ok(crate::token::ServerToken::from_token_response( + self.server.clone(), + body, + )) + } +}