mirror of
https://github.com/librespot-org/librespot.git
synced 2024-12-18 17:11:53 +00:00
OAuth open URL in default browser, set success message and throw descriptive errors
This commit is contained in:
parent
82076e882f
commit
814d30bd49
6 changed files with 188 additions and 19 deletions
|
@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
## [Unreleased] - YYYY-MM-DD
|
## [Unreleased] - YYYY-MM-DD
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
- [oauth] Open authorization URL in default browser
|
||||||
|
- [oauth] Allow optionally passing success message to display on browser return page
|
||||||
|
- [oauth] Throw specific errors on failure states
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
||||||
|
|
37
Cargo.lock
generated
37
Cargo.lock
generated
|
@ -1681,6 +1681,25 @@ version = "2.10.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708"
|
checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "is-docker"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3"
|
||||||
|
dependencies = [
|
||||||
|
"once_cell",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "is-wsl"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5"
|
||||||
|
dependencies = [
|
||||||
|
"is-docker",
|
||||||
|
"once_cell",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "is_terminal_polyfill"
|
name = "is_terminal_polyfill"
|
||||||
version = "1.70.1"
|
version = "1.70.1"
|
||||||
|
@ -2061,6 +2080,7 @@ dependencies = [
|
||||||
"env_logger",
|
"env_logger",
|
||||||
"log",
|
"log",
|
||||||
"oauth2",
|
"oauth2",
|
||||||
|
"open",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"url",
|
"url",
|
||||||
]
|
]
|
||||||
|
@ -2471,6 +2491,17 @@ version = "1.20.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
|
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "open"
|
||||||
|
version = "5.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3ecd52f0b8d15c40ce4820aa251ed5de032e5d91fab27f7db2f40d42a8bdf69c"
|
||||||
|
dependencies = [
|
||||||
|
"is-wsl",
|
||||||
|
"libc",
|
||||||
|
"pathdiff",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "openssl-probe"
|
name = "openssl-probe"
|
||||||
version = "0.1.5"
|
version = "0.1.5"
|
||||||
|
@ -2534,6 +2565,12 @@ version = "1.0.15"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pathdiff"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d61c5ce1153ab5b689d0c074c4e7fc613e942dfb7dd9eea5ab202d2ad91fe361"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pbkdf2"
|
name = "pbkdf2"
|
||||||
version = "0.12.2"
|
version = "0.12.2"
|
||||||
|
|
|
@ -13,6 +13,7 @@ log = "0.4"
|
||||||
oauth2 = "4.4"
|
oauth2 = "4.4"
|
||||||
thiserror = "1.0"
|
thiserror = "1.0"
|
||||||
url = "2.2"
|
url = "2.2"
|
||||||
|
open = "5.3.1"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
env_logger = { version = "0.11.2", default-features = false, features = ["color", "humantime", "auto-color"] }
|
env_logger = { version = "0.11.2", default-features = false, features = ["color", "humantime", "auto-color"] }
|
||||||
|
|
|
@ -25,7 +25,7 @@ fn main() {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
match get_access_token(client_id, redirect_uri, scopes) {
|
match get_access_token(client_id, redirect_uri, scopes, None) {
|
||||||
Ok(token) => println!("Success: {token:#?}"),
|
Ok(token) => println!("Success: {token:#?}"),
|
||||||
Err(e) => println!("Failed: {e}"),
|
Err(e) => println!("Failed: {e}"),
|
||||||
};
|
};
|
||||||
|
|
163
oauth/src/lib.rs
163
oauth/src/lib.rs
|
@ -34,6 +34,9 @@ pub enum OAuthError {
|
||||||
#[error("Auth code param not found in URI {uri}")]
|
#[error("Auth code param not found in URI {uri}")]
|
||||||
AuthCodeNotFound { uri: String },
|
AuthCodeNotFound { uri: String },
|
||||||
|
|
||||||
|
#[error("CSRF token param not found in URI {uri}")]
|
||||||
|
CsrfTokenNotFound { uri: String },
|
||||||
|
|
||||||
#[error("Failed to read redirect URI from stdin")]
|
#[error("Failed to read redirect URI from stdin")]
|
||||||
AuthCodeStdinRead,
|
AuthCodeStdinRead,
|
||||||
|
|
||||||
|
@ -63,6 +66,12 @@ pub enum OAuthError {
|
||||||
|
|
||||||
#[error("Failed to exchange code for access token ({e})")]
|
#[error("Failed to exchange code for access token ({e})")]
|
||||||
ExchangeCode { e: String },
|
ExchangeCode { e: String },
|
||||||
|
|
||||||
|
#[error("Spotify did not provide a refresh token")]
|
||||||
|
NoRefreshToken,
|
||||||
|
|
||||||
|
#[error("Spotify did not return the token scopes")]
|
||||||
|
NoTokenScopes,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -74,20 +83,38 @@ pub struct OAuthToken {
|
||||||
pub scopes: Vec<String>,
|
pub scopes: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return code query-string parameter from the redirect URI.
|
/// Return URL from the redirect URI &str.
|
||||||
fn get_code(redirect_url: &str) -> Result<AuthorizationCode, OAuthError> {
|
fn get_url(redirect_url: &str) -> Result<Url, OAuthError> {
|
||||||
let url = Url::parse(redirect_url).map_err(|e| OAuthError::AuthCodeBadUri {
|
let url = Url::parse(redirect_url).map_err(|e| OAuthError::AuthCodeBadUri {
|
||||||
uri: redirect_url.to_string(),
|
uri: redirect_url.to_string(),
|
||||||
e,
|
e,
|
||||||
})?;
|
})?;
|
||||||
let code = url
|
Ok(url)
|
||||||
.query_pairs()
|
}
|
||||||
.find(|(key, _)| key == "code")
|
|
||||||
.map(|(_, code)| AuthorizationCode::new(code.into_owned()))
|
|
||||||
.ok_or(OAuthError::AuthCodeNotFound {
|
|
||||||
uri: redirect_url.to_string(),
|
|
||||||
})?;
|
|
||||||
|
|
||||||
|
/// Return a query-string parameter from the redirect URI.
|
||||||
|
fn get_query_string_parameter(url: &Url, query_string_parameter_key: &str) -> Option<String> {
|
||||||
|
url.query_pairs()
|
||||||
|
.find(|(key, _)| key == query_string_parameter_key)
|
||||||
|
.map(|(_, query_string_parameter)| query_string_parameter.into_owned())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return state query-string parameter from the redirect URI (CSRF token).
|
||||||
|
fn get_state(url: &Url) -> Result<String, OAuthError> {
|
||||||
|
let state = get_query_string_parameter(url, "state").ok_or(OAuthError::CsrfTokenNotFound {
|
||||||
|
uri: url.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return code query-string parameter from the redirect URI.
|
||||||
|
fn get_code(url: &Url) -> Result<AuthorizationCode, OAuthError> {
|
||||||
|
let code = get_query_string_parameter(url, "code")
|
||||||
|
.map(AuthorizationCode::new)
|
||||||
|
.ok_or(OAuthError::AuthCodeNotFound {
|
||||||
|
uri: url.to_string(),
|
||||||
|
})?;
|
||||||
Ok(code)
|
Ok(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,11 +127,16 @@ fn get_authcode_stdin() -> Result<AuthorizationCode, OAuthError> {
|
||||||
.read_line(&mut buffer)
|
.read_line(&mut buffer)
|
||||||
.map_err(|_| OAuthError::AuthCodeStdinRead)?;
|
.map_err(|_| OAuthError::AuthCodeStdinRead)?;
|
||||||
|
|
||||||
get_code(buffer.trim())
|
let url = get_url(buffer.trim())?;
|
||||||
|
get_code(&url)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Spawn HTTP server at provided socket address to accept OAuth callback and return auth code.
|
/// Spawn HTTP server at provided socket address to accept OAuth callback and return auth code.
|
||||||
fn get_authcode_listener(socket_address: SocketAddr) -> Result<AuthorizationCode, OAuthError> {
|
fn get_authcode_listener(
|
||||||
|
socket_address: SocketAddr,
|
||||||
|
csrf_token: CsrfToken,
|
||||||
|
success_message: Option<String>,
|
||||||
|
) -> Result<AuthorizationCode, OAuthError> {
|
||||||
let listener =
|
let listener =
|
||||||
TcpListener::bind(socket_address).map_err(|e| OAuthError::AuthCodeListenerBind {
|
TcpListener::bind(socket_address).map_err(|e| OAuthError::AuthCodeListenerBind {
|
||||||
addr: socket_address,
|
addr: socket_address,
|
||||||
|
@ -128,19 +160,28 @@ fn get_authcode_listener(socket_address: SocketAddr) -> Result<AuthorizationCode
|
||||||
.split_whitespace()
|
.split_whitespace()
|
||||||
.nth(1)
|
.nth(1)
|
||||||
.ok_or(OAuthError::AuthCodeListenerParse)?;
|
.ok_or(OAuthError::AuthCodeListenerParse)?;
|
||||||
let code = get_code(&("http://localhost".to_string() + redirect_url));
|
|
||||||
|
let url = get_url(&("http://localhost".to_string() + redirect_url))?;
|
||||||
|
|
||||||
|
let token = get_state(&url)?;
|
||||||
|
if !token.eq(csrf_token.secret()) {
|
||||||
|
return Err(OAuthError::CsrfTokenNotFound {
|
||||||
|
uri: redirect_url.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let code = get_code(&url)?;
|
||||||
|
|
||||||
let message = "Go back to your terminal :)";
|
let message = "Go back to your terminal :)";
|
||||||
let response = format!(
|
let response = format!(
|
||||||
"HTTP/1.1 200 OK\r\ncontent-length: {}\r\n\r\n{}",
|
"HTTP/1.1 200 OK\r\ncontent-length: {}\r\n\r\n{}",
|
||||||
message.len(),
|
message.len(),
|
||||||
message
|
success_message.unwrap_or(message.to_owned())
|
||||||
);
|
);
|
||||||
stream
|
stream
|
||||||
.write_all(response.as_bytes())
|
.write_all(response.as_bytes())
|
||||||
.map_err(|_| OAuthError::AuthCodeListenerWrite)?;
|
.map_err(|_| OAuthError::AuthCodeListenerWrite)?;
|
||||||
|
|
||||||
code
|
Ok(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the specified `redirect_uri` is HTTP, loopback, and contains a port,
|
// If the specified `redirect_uri` is HTTP, loopback, and contains a port,
|
||||||
|
@ -168,6 +209,7 @@ pub fn get_access_token(
|
||||||
client_id: &str,
|
client_id: &str,
|
||||||
redirect_uri: &str,
|
redirect_uri: &str,
|
||||||
scopes: Vec<&str>,
|
scopes: Vec<&str>,
|
||||||
|
success_message: Option<String>,
|
||||||
) -> Result<OAuthToken, OAuthError> {
|
) -> Result<OAuthToken, OAuthError> {
|
||||||
let auth_url = AuthUrl::new("https://accounts.spotify.com/authorize".to_string())
|
let auth_url = AuthUrl::new("https://accounts.spotify.com/authorize".to_string())
|
||||||
.map_err(|_| OAuthError::InvalidSpotifyUri)?;
|
.map_err(|_| OAuthError::InvalidSpotifyUri)?;
|
||||||
|
@ -195,16 +237,19 @@ pub fn get_access_token(
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|s| Scope::new(s.into()))
|
.map(|s| Scope::new(s.into()))
|
||||||
.collect();
|
.collect();
|
||||||
let (auth_url, _) = client
|
let (auth_url, csrf_token) = client
|
||||||
.authorize_url(CsrfToken::new_random)
|
.authorize_url(CsrfToken::new_random)
|
||||||
.add_scopes(request_scopes)
|
.add_scopes(request_scopes)
|
||||||
.set_pkce_challenge(pkce_challenge)
|
.set_pkce_challenge(pkce_challenge)
|
||||||
.url();
|
.url();
|
||||||
|
|
||||||
println!("Browse to: {}", auth_url);
|
println!("Browse to: {}", auth_url);
|
||||||
|
if let Err(err) = open::that(auth_url.to_string()) {
|
||||||
|
eprintln!("An error occurred when opening '{}': {}", auth_url, err)
|
||||||
|
}
|
||||||
|
|
||||||
let code = match get_socket_address(redirect_uri) {
|
let code = match get_socket_address(redirect_uri) {
|
||||||
Some(addr) => get_authcode_listener(addr),
|
Some(addr) => get_authcode_listener(addr, csrf_token, success_message),
|
||||||
_ => get_authcode_stdin(),
|
_ => get_authcode_stdin(),
|
||||||
}?;
|
}?;
|
||||||
trace!("Exchange {code:?} for access token");
|
trace!("Exchange {code:?} for access token");
|
||||||
|
@ -226,11 +271,17 @@ pub fn get_access_token(
|
||||||
|
|
||||||
let token_scopes: Vec<String> = match token.scopes() {
|
let token_scopes: Vec<String> = match token.scopes() {
|
||||||
Some(s) => s.iter().map(|s| s.to_string()).collect(),
|
Some(s) => s.iter().map(|s| s.to_string()).collect(),
|
||||||
_ => scopes.into_iter().map(|s| s.to_string()).collect(),
|
None => {
|
||||||
|
error!("Spotify did not return the token scopes.");
|
||||||
|
return Err(OAuthError::NoTokenScopes);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let refresh_token = match token.refresh_token() {
|
let refresh_token = match token.refresh_token() {
|
||||||
Some(t) => t.secret().to_string(),
|
Some(t) => t.secret().to_string(),
|
||||||
_ => "".to_string(), // Spotify always provides a refresh token.
|
None => {
|
||||||
|
error!("Spotify did not provide a refresh token.");
|
||||||
|
return Err(OAuthError::NoRefreshToken);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
Ok(OAuthToken {
|
Ok(OAuthToken {
|
||||||
access_token: token.access_token().secret().to_string(),
|
access_token: token.access_token().secret().to_string(),
|
||||||
|
@ -284,4 +335,80 @@ mod test {
|
||||||
Some(localhost_v6)
|
Some(localhost_v6)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
#[test]
|
||||||
|
fn test_get_url_valid() {
|
||||||
|
let redirect_url = "https://example.com/callback?code=1234&state=abcd";
|
||||||
|
let result = get_url(redirect_url);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let url = result.unwrap();
|
||||||
|
assert_eq!(url.as_str(), redirect_url);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_url_invalid() {
|
||||||
|
let redirect_url = "invalid_url";
|
||||||
|
let result = get_url(redirect_url);
|
||||||
|
assert!(result.is_err());
|
||||||
|
if let Err(OAuthError::AuthCodeBadUri { uri, .. }) = result {
|
||||||
|
assert_eq!(uri, redirect_url.to_string());
|
||||||
|
} else {
|
||||||
|
panic!("Expected OAuthError::AuthCodeBadUri");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_query_string_parameter_found() {
|
||||||
|
let url = Url::parse("https://example.com/callback?code=1234&state=abcd").unwrap();
|
||||||
|
let key = "code";
|
||||||
|
let result = get_query_string_parameter(&url, key);
|
||||||
|
assert_eq!(result, Some("1234".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_query_string_parameter_not_found() {
|
||||||
|
let url = Url::parse("https://example.com/callback?code=1234&state=abcd").unwrap();
|
||||||
|
let key = "missing_key";
|
||||||
|
let result = get_query_string_parameter(&url, key);
|
||||||
|
assert!(result.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_state_valid() {
|
||||||
|
let url = Url::parse("https://example.com/callback?state=abcd").unwrap();
|
||||||
|
let result = get_state(&url);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
assert_eq!(result.unwrap(), "abcd");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_state_missing() {
|
||||||
|
let url = Url::parse("https://example.com/callback").unwrap();
|
||||||
|
let result = get_state(&url);
|
||||||
|
assert!(result.is_err());
|
||||||
|
if let Err(OAuthError::CsrfTokenNotFound { uri }) = result {
|
||||||
|
assert_eq!(uri, url.to_string());
|
||||||
|
} else {
|
||||||
|
panic!("Expected OAuthError::CsrfTokenNotFound");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_code_valid() {
|
||||||
|
let url = Url::parse("https://example.com/callback?code=1234").unwrap();
|
||||||
|
let result = get_code(&url);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
assert_eq!(result.unwrap().secret(), "1234");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_code_missing() {
|
||||||
|
let url = Url::parse("https://example.com/callback").unwrap();
|
||||||
|
let result = get_code(&url);
|
||||||
|
assert!(result.is_err());
|
||||||
|
if let Err(OAuthError::AuthCodeNotFound { uri }) = result {
|
||||||
|
assert_eq!(uri, url.to_string());
|
||||||
|
} else {
|
||||||
|
panic!("Expected OAuthError::AuthCodeNotFound");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1896,6 +1896,7 @@ async fn main() {
|
||||||
&setup.session_config.client_id,
|
&setup.session_config.client_id,
|
||||||
&format!("http://127.0.0.1{port_str}/login"),
|
&format!("http://127.0.0.1{port_str}/login"),
|
||||||
OAUTH_SCOPES.to_vec(),
|
OAUTH_SCOPES.to_vec(),
|
||||||
|
None,
|
||||||
) {
|
) {
|
||||||
Ok(token) => token.access_token,
|
Ok(token) => token.access_token,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
|
Loading…
Reference in a new issue