From 28a00dd2f3b8d2873d2960e827e3cc0da28c79f6 Mon Sep 17 00:00:00 2001 From: Fam Zheng Date: Tue, 17 Mar 2026 03:42:38 +0000 Subject: [PATCH] feat: configurable OAuth (Google + TikTok SSO), project membership, inline file preview - Auth: configurable OAuthProvider enum supporting Google OAuth and TikTok SSO - Auth: /auth/provider endpoint for frontend to detect active provider - Auth: user role system (admin via ADMIN_USERS env var sees all projects) - Projects: project_members many-to-many table with role (owner/member) - Projects: membership-based access control, auto-add creator as owner - Projects: member management API (list/add/remove) - Files: remove Content-Disposition attachment header, let browser decide - Health: public /tori/api/health endpoint for k8s probes --- src/api/auth.rs | 228 +++++++++++++++++++++++++------ src/api/files.rs | 16 +-- src/api/mod.rs | 1 + src/api/projects.rs | 196 +++++++++++++++++++++++--- src/db.rs | 73 ++++++++++ src/main.rs | 50 +++++-- web/src/components/LoginPage.vue | 38 +++++- 7 files changed, 504 insertions(+), 98 deletions(-) diff --git a/src/api/auth.rs b/src/api/auth.rs index 0f0d1b5..677fd3f 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -19,10 +19,70 @@ const CSRF_COOKIE: &str = "tori_session_csrf"; const COOKIE_PATH: &str = "/"; const SESSION_SECS: i64 = 7 * 86400; +#[derive(Debug, Clone)] +pub enum OAuthProvider { + Google { + client_id: String, + client_secret: String, + }, + TikTokSso { + client_id: String, + client_secret: String, + }, +} + +impl OAuthProvider { + fn authorize_url(&self) -> &str { + match self { + Self::Google { .. } => "https://accounts.google.com/o/oauth2/v2/auth", + Self::TikTokSso { .. } => "https://sso.tiktok-intl.com/oauth2/authorize", + } + } + + fn token_url(&self) -> &str { + match self { + Self::Google { .. } => "https://oauth2.googleapis.com/token", + Self::TikTokSso { .. } => "https://sso.tiktok-intl.com/oauth2/access_token", + } + } + + fn userinfo_url(&self) -> Option<&str> { + match self { + Self::Google { .. } => None, // uses id_token + Self::TikTokSso { .. } => Some("https://sso.tiktok-intl.com/oauth2/userinfo"), + } + } + + fn client_id(&self) -> &str { + match self { + Self::Google { client_id, .. } | Self::TikTokSso { client_id, .. } => client_id, + } + } + + fn client_secret(&self) -> &str { + match self { + Self::Google { client_secret, .. } | Self::TikTokSso { client_secret, .. } => client_secret, + } + } + + fn scope(&self) -> &str { + match self { + Self::Google { .. } => "openid%20email%20profile", + Self::TikTokSso { .. } => "read", + } + } + + fn name(&self) -> &str { + match self { + Self::Google { .. } => "google", + Self::TikTokSso { .. } => "tiktok-sso", + } + } +} + #[derive(Debug, Clone)] pub struct AuthConfig { - pub google_client_id: String, - pub google_client_secret: String, + pub provider: OAuthProvider, pub jwt_secret: String, pub public_url: String, } @@ -110,7 +170,7 @@ async fn generate_token( } } -// --- Google OAuth --- +// --- OAuth login/callback --- pub fn router(state: Arc) -> Router { Router::new() @@ -119,9 +179,15 @@ pub fn router(state: Arc) -> Router { .route("/me", get(me)) .route("/logout", post(logout)) .route("/token", post(generate_token)) + .route("/provider", get(get_provider)) .with_state(state) } +async fn get_provider(State(state): State>) -> impl IntoResponse { + let provider = state.auth.as_ref().map(|a| a.provider.name()); + Json(serde_json::json!({ "provider": provider })) +} + fn build_cookie(name: &str, value: String, max_age_secs: i64) -> Cookie<'static> { let mut c = Cookie::new(name.to_owned(), value); c.set_path(COOKIE_PATH); @@ -144,12 +210,14 @@ async fn login(State(state): State>) -> Response { let csrf = uuid::Uuid::new_v4().to_string(); let redirect_uri = format!("{}/tori/api/auth/callback", auth.public_url); + let provider = &auth.provider; + let url = format!( - "https://accounts.google.com/o/oauth2/v2/auth?\ - client_id={}&redirect_uri={}&response_type=code&\ - scope=openid%20email%20profile&access_type=online&state={}", - pct_encode(&auth.google_client_id), + "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}", + provider.authorize_url(), + pct_encode(provider.client_id()), pct_encode(&redirect_uri), + provider.scope(), pct_encode(&csrf), ); @@ -165,16 +233,14 @@ struct CallbackParams { #[derive(Deserialize)] struct TokenResponse { + access_token: Option, id_token: Option, } -#[derive(Deserialize)] -struct GoogleUserInfo { +struct UserInfo { sub: String, email: String, - #[serde(default)] name: String, - #[serde(default)] picture: String, } @@ -187,6 +253,7 @@ async fn callback( Some(a) => a, None => return (StatusCode::SERVICE_UNAVAILABLE, "Auth not configured").into_response(), }; + let provider = &auth.provider; // CSRF check match jar.get(CSRF_COOKIE) { @@ -199,11 +266,11 @@ async fn callback( // Exchange code for token let client = reqwest::Client::new(); let token_res = client - .post("https://oauth2.googleapis.com/token") + .post(provider.token_url()) .form(&[ ("code", params.code.as_str()), - ("client_id", &auth.google_client_id), - ("client_secret", &auth.google_client_secret), + ("client_id", provider.client_id()), + ("client_secret", provider.client_secret()), ("redirect_uri", &redirect_uri), ("grant_type", "authorization_code"), ]) @@ -217,44 +284,72 @@ async fn callback( }, Ok(r) => { let body = r.text().await.unwrap_or_default(); - tracing::error!("Google token exchange failed: {}", body); - return (StatusCode::BAD_GATEWAY, "Google token exchange failed").into_response(); + tracing::error!("{} token exchange failed: {}", provider.name(), body); + return (StatusCode::BAD_GATEWAY, "Token exchange failed").into_response(); } Err(e) => return (StatusCode::BAD_GATEWAY, format!("Token request failed: {}", e)).into_response(), }; - let id_token = match token_body.id_token { - Some(t) => t, - None => return (StatusCode::BAD_GATEWAY, "No id_token in response").into_response(), - }; - - // Decode id_token payload (no verification needed - just received from Google over HTTPS) - let user_info = match decode_google_id_token(&id_token) { - Some(u) => u, - None => return (StatusCode::BAD_GATEWAY, "Failed to decode id_token").into_response(), + // Get user info — provider-specific + let user_info = match provider.userinfo_url() { + Some(userinfo_url) => { + // TikTok SSO: call userinfo endpoint with access_token + let access_token = match &token_body.access_token { + Some(t) => t, + None => return (StatusCode::BAD_GATEWAY, "No access_token in response").into_response(), + }; + match fetch_userinfo(&client, userinfo_url, access_token).await { + Ok(u) => u, + Err(e) => return (StatusCode::BAD_GATEWAY, e).into_response(), + } + } + None => { + // Google: decode id_token + let id_token = match &token_body.id_token { + Some(t) => t, + None => return (StatusCode::BAD_GATEWAY, "No id_token in response").into_response(), + }; + match decode_jwt_payload(id_token) { + Some(u) => u, + None => return (StatusCode::BAD_GATEWAY, "Failed to decode id_token").into_response(), + } + } }; // Upsert user - let user_id = format!("google:{}", user_info.sub); + let user_id = format!("{}:{}", provider.name(), user_info.sub); + + // Determine user role: check ADMIN_USERS env var (comma-separated emails or usernames) + let role = { + let admin_list = std::env::var("ADMIN_USERS").unwrap_or_default(); + let is_admin = !admin_list.is_empty() && admin_list.split(',').any(|a| { + let a = a.trim(); + a == user_info.email || a == user_info.name || a == user_info.sub + }); + if is_admin { "admin" } else { "user" } + }; + let _ = sqlx::query( - "INSERT INTO users (id, email, name, picture) - VALUES (?, ?, ?, ?) + "INSERT INTO users (id, email, name, picture, role) + VALUES (?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET email = excluded.email, name = excluded.name, picture = excluded.picture, + role = excluded.role, last_login_at = datetime('now')" ) .bind(&user_id) .bind(&user_info.email) .bind(&user_info.name) .bind(&user_info.picture) + .bind(role) .execute(&state.db.pool) .await; tracing::info!("User logged in: {} ({})", user_info.email, user_id); - // Sign JWT + // Sign session JWT let exp = chrono::Utc::now().timestamp() + SESSION_SECS; let claims = Claims { sub: user_id, @@ -288,14 +383,14 @@ async fn me(State(state): State>, jar: CookieJar) -> Response { }; #[derive(Serialize)] - struct UserInfo { + struct MeResponse { id: String, email: String, name: String, picture: String, } - let user: Option = sqlx::query_as::<_, (String, String, String, String)>( + let user: Option = sqlx::query_as::<_, (String, String, String, String)>( "SELECT id, email, name, picture FROM users WHERE id = ?" ) .bind(&claims.sub) @@ -303,7 +398,7 @@ async fn me(State(state): State>, jar: CookieJar) -> Response { .await .ok() .flatten() - .map(|(id, email, name, picture)| UserInfo { id, email, name, picture }); + .map(|(id, email, name, picture)| MeResponse { id, email, name, picture }); match user { Some(u) => Json(u).into_response(), @@ -349,8 +444,57 @@ fn extract_claims(jar: &CookieJar, jwt_secret: &str) -> Option { .map(|d| d.claims) } -fn decode_google_id_token(id_token: &str) -> Option { - let parts: Vec<&str> = id_token.split('.').collect(); +/// Fetch user info from an OAuth userinfo endpoint (TikTok SSO style) +async fn fetch_userinfo( + client: &reqwest::Client, + url: &str, + access_token: &str, +) -> Result { + #[derive(Deserialize)] + struct Raw { + #[serde(default)] + sub: String, + #[serde(default)] + email: String, + #[serde(default)] + name: String, + } + + let resp = client + .get(url) + .bearer_auth(access_token) + .send() + .await + .map_err(|e| format!("Userinfo request failed: {}", e))?; + + if !resp.status().is_success() { + let body = resp.text().await.unwrap_or_default(); + return Err(format!("Userinfo failed: {}", body)); + } + + let raw: Raw = resp.json().await.map_err(|e| format!("Userinfo parse error: {}", e))?; + Ok(UserInfo { + sub: raw.sub, + email: raw.email, + name: raw.name, + picture: String::new(), + }) +} + +/// Decode JWT payload without verification (for Google id_token received over HTTPS) +fn decode_jwt_payload(jwt: &str) -> Option { + #[derive(Deserialize)] + struct Raw { + sub: String, + #[serde(default)] + email: String, + #[serde(default)] + name: String, + #[serde(default)] + picture: String, + } + + let parts: Vec<&str> = jwt.split('.').collect(); if parts.len() != 3 { return None; } @@ -359,14 +503,16 @@ fn decode_google_id_token(id_token: &str) -> Option { 3 => format!("{}=", parts[1]), _ => parts[1].to_string(), }; - let payload = base64_decode_url_safe(&padded)?; - serde_json::from_slice(&payload).ok() -} - -fn base64_decode_url_safe(input: &str) -> Option> { - let standard = input.replace('-', "+").replace('_', "/"); + let standard = padded.replace('-', "+").replace('_', "/"); use base64::Engine; - base64::engine::general_purpose::STANDARD.decode(&standard).ok() + let payload = base64::engine::general_purpose::STANDARD.decode(&standard).ok()?; + let raw: Raw = serde_json::from_slice(&payload).ok()?; + Some(UserInfo { + sub: raw.sub, + email: raw.email, + name: raw.name, + picture: raw.picture, + }) } fn pct_encode(s: &str) -> String { diff --git a/src/api/files.rs b/src/api/files.rs index ff8a0ce..f8feb7e 100644 --- a/src/api/files.rs +++ b/src/api/files.rs @@ -101,21 +101,7 @@ async fn get_file( let mime = mime_guess::from_path(&full) .first_or_octet_stream() .to_string(); - let filename = full - .file_name() - .and_then(|n| n.to_str()) - .unwrap_or("file"); - ( - [ - (axum::http::header::CONTENT_TYPE, mime), - ( - axum::http::header::CONTENT_DISPOSITION, - format!("attachment; filename=\"{}\"", filename), - ), - ], - bytes, - ) - .into_response() + ([(axum::http::header::CONTENT_TYPE, mime)], bytes).into_response() } Err(_) => (StatusCode::NOT_FOUND, "File not found").into_response(), } diff --git a/src/api/mod.rs b/src/api/mod.rs index 7ad226e..ddff9f4 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -106,6 +106,7 @@ async fn proxy_impl( } +#[allow(dead_code)] fn render_markdown_page(markdown: &str, title: &str) -> String { use pulldown_cmark::{Parser, Options, html}; let mut opts = Options::empty(); diff --git a/src/api/projects.rs b/src/api/projects.rs index bf932f0..b7719eb 100644 --- a/src/api/projects.rs +++ b/src/api/projects.rs @@ -5,13 +5,13 @@ use axum::{ Json, Router, }; use axum::http::Extensions; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use crate::AppState; use crate::db::Project; use super::{ApiResult, db_err}; use super::auth::Claims; -fn owner_id(ext: &Extensions) -> &str { +fn user_id(ext: &Extensions) -> &str { ext.get::().map(|c| c.sub.as_str()).unwrap_or("") } @@ -28,20 +28,86 @@ pub struct UpdateProject { pub description: Option, } +#[derive(Serialize)] +struct MemberResponse { + user_id: String, + role: String, + email: String, + name: String, +} + +#[derive(Deserialize)] +struct AddMemberRequest { + user_id: String, + #[serde(default = "default_role")] + role: String, +} + +fn default_role() -> String { + "owner".to_string() +} + pub fn router(state: Arc) -> Router { Router::new() .route("/projects", get(list_projects).post(create_project)) .route("/projects/{id}", get(get_project).put(update_project).delete(delete_project)) + .route("/projects/{id}/members", get(list_members).post(add_member)) + .route("/projects/{id}/members/{user_id}", axum::routing::delete(remove_member)) .with_state(state) } +/// Check if user is admin (users.role = 'admin') +async fn is_admin(pool: &sqlx::SqlitePool, uid: &str) -> bool { + if uid.is_empty() { + return true; // auth not configured + } + sqlx::query_scalar::<_, bool>( + "SELECT COALESCE((SELECT role = 'admin' FROM users WHERE id = ?), 0)" + ) + .bind(uid) + .fetch_one(pool) + .await + .unwrap_or(false) +} + +/// Check if user can access a project. Admin users can access all projects. +async fn can_access(pool: &sqlx::SqlitePool, project_id: &str, uid: &str) -> bool { + if uid.is_empty() { + return true; // auth not configured + } + if is_admin(pool, uid).await { + return true; + } + sqlx::query_scalar::<_, bool>( + "SELECT COUNT(*) > 0 FROM project_members WHERE project_id = ? AND user_id = ?" + ) + .bind(project_id) + .bind(uid) + .fetch_one(pool) + .await + .unwrap_or(false) +} + async fn list_projects( State(state): State>, ext: Extensions, ) -> ApiResult> { - let uid = owner_id(&ext); + let uid = user_id(&ext); + if uid.is_empty() || is_admin(&state.db.pool, uid).await { + // Auth not configured or admin user — show all + return sqlx::query_as::<_, Project>( + "SELECT * FROM projects WHERE deleted = 0 ORDER BY updated_at DESC" + ) + .fetch_all(&state.db.pool) + .await + .map(Json) + .map_err(db_err); + } sqlx::query_as::<_, Project>( - "SELECT * FROM projects WHERE deleted = 0 AND (owner_id = ? OR owner_id = '') ORDER BY updated_at DESC" + "SELECT p.* FROM projects p \ + JOIN project_members pm ON p.id = pm.project_id \ + WHERE p.deleted = 0 AND pm.user_id = ? \ + ORDER BY p.updated_at DESC" ) .bind(uid) .fetch_all(&state.db.pool) @@ -56,8 +122,8 @@ async fn create_project( Json(input): Json, ) -> ApiResult { let id = uuid::Uuid::new_v4().to_string(); - let uid = owner_id(&ext); - sqlx::query_as::<_, Project>( + let uid = user_id(&ext); + let project = sqlx::query_as::<_, Project>( "INSERT INTO projects (id, name, description, owner_id) VALUES (?, ?, ?, ?) RETURNING *" ) .bind(&id) @@ -66,8 +132,20 @@ async fn create_project( .bind(uid) .fetch_one(&state.db.pool) .await - .map(Json) - .map_err(db_err) + .map_err(db_err)?; + + // Auto-add creator as admin member + if !uid.is_empty() { + let _ = sqlx::query( + "INSERT OR IGNORE INTO project_members (project_id, user_id, role) VALUES (?, ?, 'owner')" + ) + .bind(&id) + .bind(uid) + .execute(&state.db.pool) + .await; + } + + Ok(Json(project)) } async fn get_project( @@ -75,12 +153,14 @@ async fn get_project( ext: Extensions, Path(id): Path, ) -> ApiResult> { - let uid = owner_id(&ext); + let uid = user_id(&ext); + if !uid.is_empty() && !can_access(&state.db.pool, &id, uid).await { + return Ok(Json(None)); + } sqlx::query_as::<_, Project>( - "SELECT * FROM projects WHERE id = ? AND (owner_id = ? OR owner_id = '')" + "SELECT * FROM projects WHERE id = ?" ) .bind(&id) - .bind(uid) .fetch_optional(&state.db.pool) .await .map(Json) @@ -93,30 +173,30 @@ async fn update_project( Path(id): Path, Json(input): Json, ) -> ApiResult> { - let uid = owner_id(&ext); + let uid = user_id(&ext); + if !uid.is_empty() && !can_access(&state.db.pool, &id, uid).await { + return Ok(Json(None)); + } if let Some(name) = &input.name { - sqlx::query("UPDATE projects SET name = ?, updated_at = datetime('now') WHERE id = ? AND (owner_id = ? OR owner_id = '')") + sqlx::query("UPDATE projects SET name = ?, updated_at = datetime('now') WHERE id = ?") .bind(name) .bind(&id) - .bind(uid) .execute(&state.db.pool) .await .map_err(db_err)?; } if let Some(desc) = &input.description { - sqlx::query("UPDATE projects SET description = ?, updated_at = datetime('now') WHERE id = ? AND (owner_id = ? OR owner_id = '')") + sqlx::query("UPDATE projects SET description = ?, updated_at = datetime('now') WHERE id = ?") .bind(desc) .bind(&id) - .bind(uid) .execute(&state.db.pool) .await .map_err(db_err)?; } sqlx::query_as::<_, Project>( - "SELECT * FROM projects WHERE id = ? AND (owner_id = ? OR owner_id = '')" + "SELECT * FROM projects WHERE id = ?" ) .bind(&id) - .bind(uid) .fetch_optional(&state.db.pool) .await .map(Json) @@ -128,12 +208,14 @@ async fn delete_project( ext: Extensions, Path(id): Path, ) -> ApiResult { - let uid = owner_id(&ext); + let uid = user_id(&ext); + if !uid.is_empty() && !can_access(&state.db.pool, &id, uid).await { + return Ok(Json(false)); + } let result = sqlx::query( - "UPDATE projects SET deleted = 1, updated_at = datetime('now') WHERE id = ? AND deleted = 0 AND (owner_id = ? OR owner_id = '')" + "UPDATE projects SET deleted = 1, updated_at = datetime('now') WHERE id = ? AND deleted = 0" ) .bind(&id) - .bind(uid) .execute(&state.db.pool) .await .map_err(db_err)?; @@ -158,3 +240,75 @@ async fn delete_project( Ok(Json(true)) } + +// --- Member management --- + +async fn list_members( + State(state): State>, + ext: Extensions, + Path(id): Path, +) -> ApiResult> { + let uid = user_id(&ext); + if !uid.is_empty() && !can_access(&state.db.pool, &id, uid).await { + return Ok(Json(vec![])); + } + let members: Vec<(String, String, String, String)> = sqlx::query_as( + "SELECT pm.user_id, pm.role, COALESCE(u.email, ''), COALESCE(u.name, '') \ + FROM project_members pm \ + LEFT JOIN users u ON pm.user_id = u.id \ + WHERE pm.project_id = ? \ + ORDER BY pm.created_at ASC" + ) + .bind(&id) + .fetch_all(&state.db.pool) + .await + .map_err(db_err)?; + + Ok(Json(members.into_iter().map(|(user_id, role, email, name)| { + MemberResponse { user_id, role, email, name } + }).collect())) +} + +async fn add_member( + State(state): State>, + ext: Extensions, + Path(id): Path, + Json(input): Json, +) -> ApiResult { + let uid = user_id(&ext); + if !uid.is_empty() && !can_access(&state.db.pool, &id, uid).await { + return Ok(Json(false)); + } + let result = sqlx::query( + "INSERT OR IGNORE INTO project_members (project_id, user_id, role) VALUES (?, ?, ?)" + ) + .bind(&id) + .bind(&input.user_id) + .bind(&input.role) + .execute(&state.db.pool) + .await + .map_err(db_err)?; + + Ok(Json(result.rows_affected() > 0)) +} + +async fn remove_member( + State(state): State>, + ext: Extensions, + Path((id, member_id)): Path<(String, String)>, +) -> ApiResult { + let uid = user_id(&ext); + if !uid.is_empty() && !can_access(&state.db.pool, &id, uid).await { + return Ok(Json(false)); + } + let result = sqlx::query( + "DELETE FROM project_members WHERE project_id = ? AND user_id = ?" + ) + .bind(&id) + .bind(&member_id) + .execute(&state.db.pool) + .await + .map_err(db_err)?; + + Ok(Json(result.rows_affected() > 0)) +} diff --git a/src/db.rs b/src/db.rs index 9dc8c2f..54922e0 100644 --- a/src/db.rs +++ b/src/db.rs @@ -249,6 +249,79 @@ impl Database { .execute(&self.pool) .await?; + // Migration: add role column to users (admin = see all projects) + let _ = sqlx::query( + "ALTER TABLE users ADD COLUMN role TEXT NOT NULL DEFAULT 'user'" + ) + .execute(&self.pool) + .await; + + sqlx::query( + "CREATE TABLE IF NOT EXISTS project_members ( + project_id TEXT NOT NULL REFERENCES projects(id), + user_id TEXT NOT NULL, + role TEXT NOT NULL DEFAULT 'owner', + created_at TEXT NOT NULL DEFAULT (datetime('now')), + PRIMARY KEY (project_id, user_id) + )" + ) + .execute(&self.pool) + .await?; + + // Migration: assign all existing memberless projects to the first user (or leave for manual assignment) + // When auth is not configured, owner_id is empty — these projects are visible to everyone + // When a user logs in and creates projects, they get auto-added as admin + { + // Find existing projects with owner_id set but no members yet + let owned: Vec<(String, String)> = sqlx::query_as( + "SELECT p.id, p.owner_id FROM projects p \ + WHERE p.deleted = 0 AND p.owner_id != '' \ + AND NOT EXISTS (SELECT 1 FROM project_members pm WHERE pm.project_id = p.id)" + ) + .fetch_all(&self.pool) + .await + .unwrap_or_default(); + + for (pid, uid) in owned { + let _ = sqlx::query( + "INSERT OR IGNORE INTO project_members (project_id, user_id, role) VALUES (?, ?, 'owner')" + ) + .bind(&pid) + .bind(&uid) + .execute(&self.pool) + .await; + } + + // For orphan projects (no owner, no members), assign to first user if one exists + let first_user: Option<(String,)> = sqlx::query_as( + "SELECT id FROM users ORDER BY created_at ASC LIMIT 1" + ) + .fetch_optional(&self.pool) + .await + .unwrap_or(None); + + if let Some((first_uid,)) = first_user { + let orphans: Vec<(String,)> = sqlx::query_as( + "SELECT p.id FROM projects p \ + WHERE p.deleted = 0 \ + AND NOT EXISTS (SELECT 1 FROM project_members pm WHERE pm.project_id = p.id)" + ) + .fetch_all(&self.pool) + .await + .unwrap_or_default(); + + for (pid,) in orphans { + let _ = sqlx::query( + "INSERT OR IGNORE INTO project_members (project_id, user_id, role) VALUES (?, ?, 'owner')" + ) + .bind(&pid) + .bind(&first_uid) + .execute(&self.pool) + .await; + } + } + } + Ok(()) } } diff --git a/src/main.rs b/src/main.rs index 2797a8d..d70f685 100644 --- a/src/main.rs +++ b/src/main.rs @@ -123,25 +123,41 @@ async fn main() -> anyhow::Result<()> { let obj_root = std::env::var("OBJ_ROOT").unwrap_or_else(|_| "/data/obj".to_string()); - let auth_config = match ( - std::env::var("GOOGLE_CLIENT_ID"), - std::env::var("GOOGLE_CLIENT_SECRET"), - ) { - (Ok(client_id), Ok(client_secret)) => { - let jwt_secret = std::env::var("JWT_SECRET") - .unwrap_or_else(|_| uuid::Uuid::new_v4().to_string()); - let public_url = std::env::var("PUBLIC_URL") - .unwrap_or_else(|_| "https://tori.euphon.cloud".to_string()); - tracing::info!("Google OAuth enabled (public_url={})", public_url); + let auth_config = { + let jwt_secret = std::env::var("JWT_SECRET") + .unwrap_or_else(|_| uuid::Uuid::new_v4().to_string()); + let public_url = std::env::var("PUBLIC_URL") + .unwrap_or_else(|_| "https://tori.euphon.cloud".to_string()); + + // Try TikTok SSO first, then Google OAuth + if let (Ok(id), Ok(secret)) = ( + std::env::var("SSO_CLIENT_ID"), + std::env::var("SSO_CLIENT_SECRET"), + ) { + tracing::info!("TikTok SSO enabled (public_url={})", public_url); Some(api::auth::AuthConfig { - google_client_id: client_id, - google_client_secret: client_secret, + provider: api::auth::OAuthProvider::TikTokSso { + client_id: id, + client_secret: secret, + }, jwt_secret, public_url, }) - } - _ => { - tracing::warn!("GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET not set, auth disabled"); + } else if let (Ok(id), Ok(secret)) = ( + std::env::var("GOOGLE_CLIENT_ID"), + std::env::var("GOOGLE_CLIENT_SECRET"), + ) { + tracing::info!("Google OAuth enabled (public_url={})", public_url); + Some(api::auth::AuthConfig { + provider: api::auth::OAuthProvider::Google { + client_id: id, + client_secret: secret, + }, + jwt_secret, + public_url, + }) + } else { + tracing::warn!("No OAuth configured (set SSO_CLIENT_ID/SSO_CLIENT_SECRET or GOOGLE_CLIENT_ID/GOOGLE_CLIENT_SECRET)"); None } }; @@ -156,6 +172,10 @@ async fn main() -> anyhow::Result<()> { }); let app = Router::new() + // Health check (public, for k8s probes) + .route("/tori/api/health", axum::routing::get(|| async { + axum::Json(serde_json::json!({"status": "ok"})) + })) // Auth routes are public .nest("/tori/api/auth", api::auth::router(state.clone())) // Protected API routes diff --git a/web/src/components/LoginPage.vue b/web/src/components/LoginPage.vue index d2a7c01..802c607 100644 --- a/web/src/components/LoginPage.vue +++ b/web/src/components/LoginPage.vue @@ -1,5 +1,25 @@