diff --git a/Cargo.toml b/Cargo.toml index 73af883..c2f9d08 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,4 +14,4 @@ tokio = "1.39.2" tokio-rustls = "0.26.0" webpki = "0.22.4" webpki-roots = "0.26.3" -x509-parser = "0.16.0" +x509-parser = { version = "0.16.0", features = ["verify"]} diff --git a/src/bin/allow_any_cert_client.rs b/src/bin/allow_any_cert_client.rs index 40e5625..0e3122f 100644 --- a/src/bin/allow_any_cert_client.rs +++ b/src/bin/allow_any_cert_client.rs @@ -64,7 +64,7 @@ impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification { _: &[u8], _: rustls::pki_types::UnixTime, ) -> Result { - let ret_deserial = X509Certificate::from_der(&end_entity.iter().as_slice()); + let ret_deserial = X509Certificate::from_der(end_entity.iter().as_slice()); let x509 = match ret_deserial { Ok((_, x509)) => x509, _ => panic!("wtf"), @@ -73,7 +73,8 @@ impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification { .subject() .iter_common_name() .next() - .and_then(|cn| cn.as_str().ok()).unwrap(); + .and_then(|cn| cn.as_str().ok()) + .unwrap(); info!( "Server Cert: CN: {}, CA: {}, serverName : {:?}", cn, @@ -83,7 +84,7 @@ impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification { // end_entity for (idx, ica) in intermediates.iter().enumerate() { - let ret_deserial = X509Certificate::from_der(&ica.iter().as_slice()); + let ret_deserial = X509Certificate::from_der(ica.iter().as_slice()); let x509 = match ret_deserial { Ok((_, x509)) => x509, _ => continue, @@ -92,12 +93,8 @@ impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification { .subject() .iter_common_name() .next() - .and_then(|cn| cn.as_str().ok()); - let cn = match cn { - Some(name) => name, - _ => "", - }; - + .and_then(|cn| cn.as_str().ok()) + .unwrap_or_default(); info!("[{idx}] CN: {}, CA: {}", cn, x509.is_ca()); } info!("verify cert done"); @@ -105,22 +102,21 @@ impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification { } fn supported_verify_schemes(&self) -> Vec { - let mut ss = Vec::::new(); - ss.push(rustls::SignatureScheme::RSA_PKCS1_SHA1); - ss.push(rustls::SignatureScheme::ECDSA_SHA1_Legacy); - ss.push(rustls::SignatureScheme::RSA_PKCS1_SHA256); - ss.push(rustls::SignatureScheme::ECDSA_NISTP256_SHA256); - ss.push(rustls::SignatureScheme::RSA_PKCS1_SHA384); - ss.push(rustls::SignatureScheme::ECDSA_NISTP384_SHA384); - ss.push(rustls::SignatureScheme::RSA_PKCS1_SHA512); - ss.push(rustls::SignatureScheme::ECDSA_NISTP521_SHA512); - ss.push(rustls::SignatureScheme::RSA_PSS_SHA256); - ss.push(rustls::SignatureScheme::RSA_PSS_SHA384); - ss.push(rustls::SignatureScheme::RSA_PSS_SHA512); - ss.push(rustls::SignatureScheme::ED25519); - ss.push(rustls::SignatureScheme::ED448); - - ss + vec![ + rustls::SignatureScheme::RSA_PKCS1_SHA1, + rustls::SignatureScheme::ECDSA_SHA1_Legacy, + rustls::SignatureScheme::RSA_PKCS1_SHA256, + rustls::SignatureScheme::ECDSA_NISTP256_SHA256, + rustls::SignatureScheme::RSA_PKCS1_SHA384, + rustls::SignatureScheme::ECDSA_NISTP384_SHA384, + rustls::SignatureScheme::RSA_PKCS1_SHA512, + rustls::SignatureScheme::ECDSA_NISTP521_SHA512, + rustls::SignatureScheme::RSA_PSS_SHA256, + rustls::SignatureScheme::RSA_PSS_SHA384, + rustls::SignatureScheme::RSA_PSS_SHA512, + rustls::SignatureScheme::ED25519, + rustls::SignatureScheme::ED448, + ] } } @@ -142,15 +138,15 @@ fn main() -> Result<(), Box> { let mut conn = rustls::ClientConnection::new(Arc::new(config), server_name)?; let mut sock = std::net::TcpStream::connect("localhost:10080")?; let mut tls_conn = rustls::Stream::new(&mut conn, &mut sock); - let mut rbuf : [u8; 1024] = [0; 1024]; + let mut rbuf: [u8; 1024] = [0; 1024]; - let wstring = "Hello Rust!"; + let wstring = "Hello Rust!"; let wsize = match tls_conn.write(wstring.as_bytes()) { Ok(size) => size, Err(ec) => { error!("Write Error: {}", ec.kind().to_string()); usize::MAX - }, + } }; info!("Send Data ({wsize}): {wstring}"); @@ -159,7 +155,7 @@ fn main() -> Result<(), Box> { Err(ec) => { error!("Read Error: {}", ec.kind().to_string()); 0 - }, + } }; let utf8string = String::from_utf8(rbuf[0..rsize].to_vec()).expect("could not encoded utf8"); diff --git a/src/bin/client.rs b/src/bin/client.rs index 8a4151d..cbd116c 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -189,6 +189,7 @@ fn load_use_certificate(crt_name: &str, store: &mut RootCertStore) { let deserialized_cert = X509Certificate::from_der(cert.iter().as_slice()); match deserialized_cert { Ok((_, x509)) => { + x509.verify_signature(public_key) let cn = x509 .subject() .iter_common_name() @@ -259,6 +260,12 @@ fn main() -> Result<(), Box> { let mut tls_conn = rustls::Stream::new(&mut conn, &mut sock); let mut rbuf: [u8; 1024] = [0; 1024]; + if let Err(e) = tls_conn.conn.complete_io(tls_conn.sock) { + error!("TLS handshake failure - {}:{}", e.kind(), e.to_string()); + tls_conn.sock.shutdown(std::net::Shutdown::Both).expect("stream shutdown failure"); + return Err(Box::new(e)) + } + let wstring = "Hello Rust!"; let wsize = match tls_conn.write(wstring.as_bytes()) { Ok(size) => size, diff --git a/src/main.rs b/src/main.rs index d601d67..211290f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ extern crate pretty_env_logger; #[macro_use] extern crate log; use clap::{Parser, ValueEnum}; +use rustls::server::danger::ClientCertVerifier; use std::error::Error as StdError; use std::fs::File; use std::io::{BufReader, Read, Write}; @@ -11,12 +12,15 @@ use std::sync::Arc; #[derive(Parser)] #[command(version = "1.0")] struct Cli { - #[arg(short, long, value_name = "Certificate")] /// fullchain certificate name. format PEM + #[arg(short, long, value_name = "Certificate")] cert: String, - #[arg(short, long, value_name = "PrivateKey")] /// server privateKey name. format PEM + #[arg(short, long, value_name = "PrivateKey")] key: String, + /// mTLS Mode + #[arg(short = 'm')] + mtls: Option, /// set log level #[arg(short, long, value_name = "Log Level", value_enum)] level: Option, @@ -52,12 +56,6 @@ impl Level { } } -#[derive(Copy, Clone)] -enum KeyPath { - Cert = 0, - Pkey = 1, -} - fn initialize_log(options: &Cli) { let level = match options.level { Some(level) => level.to_level_filter(), @@ -88,39 +86,119 @@ fn initialize_log(options: &Cli) { info!("Hi -"); } -fn parse_args() -> Vec { +#[derive(Debug)] +struct CliCertVerifier; +// { + +// supported: Arc, +// } + +// impl CliCertVerifier { +// fn new(roots: Arc) -> Self { +// Self { +// supported : Arc::clone(crypto::CryptoProvider::get_default_or_install_from_crate_features()), +// } +// } +// } + +impl ClientCertVerifier for CliCertVerifier { + fn offer_client_auth(&self) -> bool { + true + } + + fn client_auth_mandatory(&self) -> bool { + self.offer_client_auth() + } + + fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] { + todo!() + } + + fn verify_client_cert( + &self, + _end_entity: &rustls::pki_types::CertificateDer<'_>, + _intermediates: &[rustls::pki_types::CertificateDer<'_>], + _now: rustls::pki_types::UnixTime, + ) -> Result { + todo!() + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + // NOTE: This should be in priority order, with the most preferred first. + // enum에서 복사 했기 때문에 우선순위가 반영되어있지 않음 + vec![ + rustls::SignatureScheme::RSA_PKCS1_SHA1, + rustls::SignatureScheme::ECDSA_SHA1_Legacy, + rustls::SignatureScheme::RSA_PKCS1_SHA256, + rustls::SignatureScheme::ECDSA_NISTP256_SHA256, + rustls::SignatureScheme::RSA_PKCS1_SHA384, + rustls::SignatureScheme::ECDSA_NISTP384_SHA384, + rustls::SignatureScheme::RSA_PKCS1_SHA512, + rustls::SignatureScheme::ECDSA_NISTP521_SHA512, + rustls::SignatureScheme::RSA_PSS_SHA256, + rustls::SignatureScheme::RSA_PSS_SHA384, + rustls::SignatureScheme::RSA_PSS_SHA512, + rustls::SignatureScheme::ED25519, + rustls::SignatureScheme::ED448, + ] + } +} + +fn main() -> Result<(), Box> { let cli = Cli::parse(); initialize_log(&cli); - vec![cli.cert.clone(), cli.key.clone()] -} + let certs = rustls_pemfile::certs(&mut BufReader::new(&mut File::open(cli.cert)?)) + .collect::, _>>()?; + let private_key = + rustls_pemfile::private_key(&mut BufReader::new(&mut File::open(cli.key)?))?.unwrap(); -fn main() -> Result<(), Box> { - let key_path = parse_args(); + let use_mtls = match cli.mtls { + Some(mtls) => mtls, + None => false, + }; - let certs = rustls_pemfile::certs(&mut BufReader::new(&mut File::open( - &key_path[KeyPath::Cert as usize][0..], - )?)) - .collect::, _>>()?; - let private_key = rustls_pemfile::private_key(&mut BufReader::new(&mut File::open( - &key_path[KeyPath::Pkey as usize][0..], - )?))? - .unwrap(); - let config = rustls::ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(certs, private_key)?; + let config = if use_mtls { + rustls::ServerConfig::builder() + .with_client_cert_verifier(Arc::new(CliCertVerifier)) + } else { + rustls::ServerConfig::builder() + .with_no_client_auth() + }.with_single_cert(certs, private_key)?; let listener = TcpListener::bind(format!("[::]:{}", 10080)).unwrap(); trace!("Bound TCP Server"); let mut rbuf = [0u8; 1024 * 64]; // 64kb for r_tcp in listener.incoming() { + debug!("wait Client"); let mut stream = r_tcp.expect("Accept Failure"); stream.set_nodelay(true).expect("Set no delay Failure"); let mut tls_conn = rustls::ServerConnection::new(Arc::new(config.clone())).expect("TLS Server Error"); - tls_conn.complete_io(&mut stream)?; + if let Err(e) = tls_conn.complete_io(&mut stream) { + error!("TLS handshake failure - {}:{}", e.kind(), e.to_string()); + continue; + } debug!("handshake done!?"); let rsize = match tls_conn.reader().read(&mut rbuf) { @@ -139,7 +217,7 @@ fn main() -> Result<(), Box> { .writer() .write_fmt(format_args!("Echo, Client say: {}", utf8string))?; debug!("Send Data"); - tls_conn.complete_io(&mut stream)?; + tls_conn.complete_io(&mut stream).expect("Send Flush Error"); // -> drop tls_conn }