diff --git a/src/config.rs b/src/config.rs index 44f79b0..38c0088 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,18 +1,19 @@ use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub(crate) struct Config { pub(crate) general: GeneralConfig, pub(crate) database: DBConfig } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub(crate) struct GeneralConfig { pub(crate) listen_address: String, pub(crate) port: u16, + pub(crate) jail_dir: String } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] #[serde(tag = "driver")] pub(crate) enum DBConfig { #[serde(rename = "sqlite")] @@ -43,7 +44,8 @@ impl Default for Config { Config { general: GeneralConfig { listen_address: String::from("0.0.0.0"), - port: 2222 + port: 2222, + jail_dir: String::from("/srv/sftp") }, database: DBConfig::Sqlite { path: String::from("/var/lib/flux-sftp/auth.db") diff --git a/src/main.rs b/src/main.rs index 774c0a6..1953a9c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,14 +2,29 @@ mod sftp; mod config; use std::{io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration}; -use config::Config; +use config::{Config, DBConfig}; use russh::{keys::ssh_key::{rand_core::OsRng, PublicKey}, server::{Auth, Handler as SshHandler, Msg, Server, Session}, Channel, ChannelId}; use sftp::SftpSession; -use sqlx::{sqlite::SqlitePoolOptions, Pool, Row, Sqlite}; +use sqlx::{mysql::MySqlPoolOptions, postgres::PgPoolOptions, sqlite::SqlitePoolOptions, MySql, Pool, Postgres, Row, Sqlite}; use tokio::fs; +macro_rules! fetch_pub_key { + ($pool:ident, $query:literal, $user:ident) => { + { + let row_res = sqlx::query($query) + .bind($user) + .fetch_one($pool).await; + match row_res { + Ok(row) => Some(row.get("public_key")), + Err(_) => None + } + } + }; +} + struct SftpServer { - pool: Arc> + pool: Arc, + config: Arc } impl Server for SftpServer { @@ -17,14 +32,16 @@ impl Server for SftpServer { fn new_client(&mut self, _peer_addr: Option) -> Self::Handler { let session_pool = self.pool.clone(); - SshSession { channel: None, user: None, pool: session_pool } + let config = self.config.clone(); + SshSession { channel: None, user: None, pool: session_pool, config } } } struct SshSession { channel: Option>, user: Option, - pool: Arc> + pool: Arc, + config: Arc } impl SshHandler for SshSession { @@ -37,26 +54,28 @@ impl SshHandler for SshSession { ) -> Result { self.user = Some(user.to_string()); - let row_res = sqlx::query("SELECT * FROM users WHERE username = ?") - .bind(user) - .fetch_one(&*self.pool).await; + let offered_key = public_key.to_string(); - match row_res { - Ok(row) => { - let stored_key: String = row.get("public_key"); - let offered_key = public_key.to_string(); - if stored_key == offered_key { - Ok(Auth::Accept) - } - else { - Ok(Auth::reject()) - } + let stored_key_opt: Option = match &*self.pool { + DBPool::Sqlite(pool) => fetch_pub_key!(pool, "SELECT * FROM users WHERE username = ?", user), + DBPool::Postgres(pool) => fetch_pub_key!(pool, "SELECT * FROM users WHERE username = $1", user), + DBPool::Mysql(pool) => fetch_pub_key!(pool, "SELECT * FROM users WHERE username = ?", user) + }; + + if let Some(stored_key) = stored_key_opt { + if stored_key == offered_key { + Ok(Auth::Accept) } - Err(e) => { - println!("User Not found: {}", e); + else { + println!("invalid key"); Ok(Auth::reject()) } } + else { + println!("user not found"); + Ok(Auth::reject()) + } + } async fn auth_publickey( @@ -92,7 +111,7 @@ impl SshHandler for SshSession { ) -> Result<(), Self::Error> { if name == "sftp" { session.channel_success(channel_id)?; - let jail_dir = format!("/srv/sftp/{}", self.user.as_ref().unwrap()); + let jail_dir = format!("{}/{}", self.config.general.jail_dir, self.user.as_ref().unwrap()); let sftp_handler = SftpSession::new(jail_dir); russh_sftp::server::run(self.channel.take().ok_or(Self::Error::WrongChannel)?.into_stream(), sftp_handler).await; } @@ -104,16 +123,22 @@ impl SshHandler for SshSession { } +enum DBPool { + Sqlite(Pool), + Postgres(Pool), + Mysql(Pool) +} + #[tokio::main] async fn main() -> Result<(), sqlx::Error> { const CONFIG_PATH: &str = "/etc/flux-sftp/config.toml"; - let config: Config; + let config: Arc; match fs::read_to_string(CONFIG_PATH).await { Ok(toml) => { match toml::from_str::(&toml) { - Ok(c) => config = c, + Ok(c) => config = Arc::new(c), Err(e) => { println!("error parsing config file: {}\n please make sure config file is valid", e); return Ok(()) @@ -130,17 +155,19 @@ async fn main() -> Result<(), sqlx::Error> { } } - // let url = match &config.database { - // DBConfig::Sqlite { path } => format!("sqlite:{}", path), - // DBConfig::Postgres { host, port, user, password, dbname } => format!("postgres://{}:{}@{}:{}/{}", user, password, host, port, dbname), - // DBConfig::Mysql { host, port, user, password, dbname } => format!("mysql://{}:{}@{}:{}/{}", user, password, host, port, dbname), - // }; + let url = match &config.database { + DBConfig::Sqlite { path } => format!("sqlite:{}", path), + DBConfig::Postgres { host, port, user, password, dbname } => format!("postgres://{}:{}@{}:{}/{}", user, password, host, port, dbname), + DBConfig::Mysql { host, port, user, password, dbname } => format!("mysql://{}:{}@{}:{}/{}", user, password, host, port, dbname), + }; - - let pool = SqlitePoolOptions::new() - .max_connections(3) - .connect("sqlite:/home/rafayahmad/Stuff/Coding/Rust/flux-sftp/auth.db").await?; - let mut server = SftpServer { pool: Arc::new(pool) }; + let pool = match &config.database { + DBConfig::Sqlite { .. } => DBPool::Sqlite(SqlitePoolOptions::new().max_connections(3).connect(&url).await?), + DBConfig::Postgres { .. } => DBPool::Postgres(PgPoolOptions::new().max_connections(3).connect(&url).await?), + DBConfig::Mysql { .. } => DBPool::Mysql(MySqlPoolOptions::new().max_connections(3).connect(&url).await?) + }; + + let mut server = SftpServer { pool: Arc::new(pool), config: config.clone() }; let russh_config = russh::server::Config { auth_rejection_time: Duration::from_secs(3), @@ -151,6 +178,6 @@ async fn main() -> Result<(), sqlx::Error> { ..Default::default() }; - server.run_on_address(Arc::new(russh_config), (config.general.listen_address, config.general.port)).await.unwrap(); + server.run_on_address(Arc::new(russh_config), (&config.general.listen_address as &str, config.general.port)).await.unwrap(); Ok(()) }