mod authenticator;
mod supplicant;
use crate::key::exchange;
use crate::key::gtk::{Gtk, GtkProvider};
use crate::key::igtk::{Igtk, IgtkProvider};
use crate::key::ptk::Ptk;
use crate::nonce::NonceReader;
use crate::rsna::{Dot11VerifiedKeyFrame, NegotiatedProtection, Role, UpdateSink};
use crate::{rsn_ensure, Error, ProtectionInfo};
use ieee80211::MacAddr;
use std::ops::Deref;
use std::sync::{Arc, Mutex};
use wlan_common::ie::rsn::cipher::Cipher;
use wlan_common::ie::rsn::rsne::Rsne;
use wlan_common::ie::rsn::suite_filter::DEFAULT_GROUP_MGMT_CIPHER;
use wlan_statemachine::StateMachine;
use zerocopy::SplitByteSlice;
#[derive(Debug, PartialEq)]
pub enum MessageNumber {
Message1 = 1,
Message2 = 2,
Message3 = 3,
Message4 = 4,
}
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct AuthenticatorKeyReplayCounter(u64);
impl Deref for AuthenticatorKeyReplayCounter {
type Target = u64;
fn deref(&self) -> &u64 {
&self.0
}
}
impl AuthenticatorKeyReplayCounter {
pub fn next_after(key_replay_counter: u64) -> Self {
Self(key_replay_counter + 1)
}
}
#[derive(Debug, Clone, Copy)]
pub struct SupplicantKeyReplayCounter(u64);
impl Deref for SupplicantKeyReplayCounter {
type Target = u64;
fn deref(&self) -> &u64 {
&self.0
}
}
impl From<u64> for SupplicantKeyReplayCounter {
fn from(a: u64) -> Self {
Self(a)
}
}
pub struct FourwayHandshakeFrame<B: SplitByteSlice>(Dot11VerifiedKeyFrame<B>);
impl<B: SplitByteSlice> FourwayHandshakeFrame<B> {
#[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
pub fn from_verified(
frame: Dot11VerifiedKeyFrame<B>,
role: Role,
nonce: Option<&[u8]>,
) -> Result<FourwayHandshakeFrame<B>, Error> {
let raw_frame = frame.unsafe_get_raw();
let msg_no = message_number(raw_frame);
match role {
Role::Authenticator => match msg_no {
MessageNumber::Message2 | MessageNumber::Message4 => {}
_ => return Err(Error::UnexpectedHandshakeMessage(msg_no.into()).into()),
},
Role::Supplicant => match msg_no {
MessageNumber::Message1 | MessageNumber::Message3 => {}
_ => return Err(Error::UnexpectedHandshakeMessage(msg_no.into()).into()),
},
};
match msg_no {
MessageNumber::Message1 => validate_message_1(raw_frame),
MessageNumber::Message2 => validate_message_2(raw_frame),
MessageNumber::Message3 => validate_message_3(raw_frame, nonce),
MessageNumber::Message4 => validate_message_4(raw_frame),
}?;
Ok(FourwayHandshakeFrame(frame))
}
pub fn get(self) -> Dot11VerifiedKeyFrame<B> {
self.0
}
fn message_number(&self) -> MessageNumber {
message_number(self.unsafe_get_raw())
}
}
impl<B: SplitByteSlice> std::ops::Deref for FourwayHandshakeFrame<B> {
type Target = Dot11VerifiedKeyFrame<B>;
fn deref(&self) -> &Dot11VerifiedKeyFrame<B> {
&self.0
}
}
#[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
pub fn get_group_mgmt_cipher(
s_protection: &ProtectionInfo,
a_protection: &ProtectionInfo,
) -> Result<Option<Cipher>, Error> {
match (s_protection, a_protection) {
(
ProtectionInfo::Rsne(Rsne {
rsn_capabilities: Some(s_rsn_capabilities),
group_mgmt_cipher_suite: s_group_mgmt_cipher_suite,
..
}),
ProtectionInfo::Rsne(Rsne {
rsn_capabilities: Some(a_rsn_capabilities),
group_mgmt_cipher_suite: a_group_mgmt_cipher_suite,
..
}),
) => {
if !s_rsn_capabilities.mgmt_frame_protection_cap()
&& s_rsn_capabilities.mgmt_frame_protection_req()
{
return Err(Error::InvalidClientMgmtFrameProtectionCapabilityBit);
}
if !a_rsn_capabilities.mgmt_frame_protection_cap()
&& a_rsn_capabilities.mgmt_frame_protection_req()
{
return Err(Error::InvalidApMgmtFrameProtectionCapabilityBit);
}
if s_rsn_capabilities.mgmt_frame_protection_req()
&& !a_rsn_capabilities.mgmt_frame_protection_cap()
{
return Err(Error::MgmtFrameProtectionRequiredByClient);
}
if !s_rsn_capabilities.mgmt_frame_protection_cap()
&& a_rsn_capabilities.mgmt_frame_protection_req()
{
return Err(Error::MgmtFrameProtectionRequiredByAp);
}
if s_rsn_capabilities.mgmt_frame_protection_cap()
&& a_rsn_capabilities.mgmt_frame_protection_cap()
{
let s_group_mgmt_cipher_suite =
s_group_mgmt_cipher_suite.unwrap_or(DEFAULT_GROUP_MGMT_CIPHER);
let a_group_mgmt_cipher_suite =
a_group_mgmt_cipher_suite.unwrap_or(DEFAULT_GROUP_MGMT_CIPHER);
if s_group_mgmt_cipher_suite != a_group_mgmt_cipher_suite {
return Err(Error::GroupMgmtCipherMismatch(
s_group_mgmt_cipher_suite,
a_group_mgmt_cipher_suite,
));
}
return Ok(Some(s_group_mgmt_cipher_suite));
}
}
(_, ProtectionInfo::Rsne(Rsne { rsn_capabilities: Some(a_rsn_capabilities), .. })) => {
if a_rsn_capabilities.mgmt_frame_protection_req() {
return Err(Error::MgmtFrameProtectionRequiredByAp);
}
}
(ProtectionInfo::Rsne(Rsne { rsn_capabilities: Some(s_rsn_capabilities), .. }), _) => {
if s_rsn_capabilities.mgmt_frame_protection_req() {
return Err(Error::MgmtFrameProtectionRequiredByClient);
}
}
(ProtectionInfo::Rsne(Rsne { rsn_capabilities: None, .. }), _)
| (ProtectionInfo::LegacyWpa(_), _) => {}
}
Ok(None)
}
#[derive(Debug, Clone)]
pub struct Config {
pub role: Role,
pub s_addr: MacAddr,
pub s_protection: ProtectionInfo,
pub a_addr: MacAddr,
pub a_protection: ProtectionInfo,
pub nonce_rdr: Arc<NonceReader>,
pub gtk_provider: Option<Arc<Mutex<GtkProvider>>>,
pub igtk_provider: Option<Arc<Mutex<IgtkProvider>>>,
_private: (),
}
impl Config {
#[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
pub fn new(
role: Role,
s_addr: MacAddr,
s_protection: ProtectionInfo,
a_addr: MacAddr,
a_protection: ProtectionInfo,
nonce_rdr: Arc<NonceReader>,
gtk_provider: Option<Arc<Mutex<GtkProvider>>>,
igtk_provider: Option<Arc<Mutex<IgtkProvider>>>,
) -> Result<Config, Error> {
match (&s_protection, &a_protection) {
(ProtectionInfo::Rsne(s_rsne), ProtectionInfo::Rsne(a_rsne)) => {
match s_rsne.is_valid_subset_of(a_rsne) {
Ok(true) => {}
Ok(false) => {
return Err(Error::RsneInvalidSubset(s_rsne.clone(), a_rsne.clone()))
}
Err(e) => return Err(e.into()),
};
}
(_, ProtectionInfo::LegacyWpa(_)) => {}
(ProtectionInfo::LegacyWpa(_), _) => {}
}
match NegotiatedProtection::from_protection(&s_protection) {
Ok(negotiated_protection) => negotiated_protection,
Err(e) => return Err(Error::InvalidSupplicantProtection(format!("{:?}", e))),
};
if role == Role::Authenticator {
rsn_ensure!(gtk_provider.is_some(), Error::MissingGtkProvider);
match get_group_mgmt_cipher(&s_protection, &a_protection)? {
Some(group_mgmt_cipher) => match igtk_provider.as_ref() {
None => return Err(Error::MissingIgtkProvider),
Some(igtk_provider) => {
let igtk_provider_cipher = igtk_provider.lock().unwrap().cipher();
rsn_ensure!(
group_mgmt_cipher == igtk_provider_cipher,
Error::WrongIgtkProviderCipher(group_mgmt_cipher, igtk_provider_cipher),
);
}
},
None => {}
}
}
Ok(Config {
role,
s_addr,
s_protection,
a_addr,
a_protection,
nonce_rdr,
gtk_provider,
igtk_provider,
_private: (),
})
}
}
impl PartialEq for Config {
fn eq(&self, other: &Config) -> bool {
self.role == other.role
&& self.s_addr == other.s_addr
&& self.s_protection == other.s_protection
&& self.a_addr == other.a_addr
&& self.a_protection == other.a_protection
}
}
#[derive(Debug, PartialEq)]
pub enum Fourway {
Authenticator(StateMachine<authenticator::State>),
Supplicant(StateMachine<supplicant::State>),
}
impl Fourway {
pub fn new(cfg: Config, pmk: Vec<u8>) -> Result<Fourway, anyhow::Error> {
let fourway = match &cfg.role {
Role::Supplicant => {
let state = supplicant::new(cfg, pmk);
Fourway::Supplicant(StateMachine::new(state))
}
Role::Authenticator => {
let state = authenticator::new(cfg, pmk);
Fourway::Authenticator(StateMachine::new(state))
}
};
Ok(fourway)
}
#[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
pub fn on_eapol_key_frame<B: SplitByteSlice>(
&mut self,
update_sink: &mut UpdateSink,
frame: Dot11VerifiedKeyFrame<B>,
) -> Result<(), Error> {
match self {
Fourway::Authenticator(state_machine) => {
let frame = FourwayHandshakeFrame::from_verified(frame, Role::Authenticator, None)?;
state_machine.replace_state(|state| state.on_eapol_key_frame(update_sink, frame));
Ok(())
}
Fourway::Supplicant(state_machine) => {
let anonce = state_machine.as_ref().anonce();
let frame = FourwayHandshakeFrame::from_verified(frame, Role::Supplicant, anonce)?;
state_machine.replace_state(|state| state.on_eapol_key_frame(update_sink, frame));
Ok(())
}
}
}
#[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
pub fn on_rsna_response_timeout(&self) -> Result<(), Error> {
match self {
Fourway::Authenticator(state_machine) => state_machine.on_rsna_response_timeout(),
Fourway::Supplicant(state_machine) => state_machine.on_rsna_response_timeout(),
}
}
pub fn ptk(&self) -> Option<Ptk> {
match self {
Fourway::Authenticator(state_machine) => state_machine.as_ref().ptk(),
Fourway::Supplicant(state_machine) => state_machine.as_ref().ptk(),
}
}
pub fn gtk(&self) -> Option<Gtk> {
match self {
Fourway::Authenticator(state_machine) => state_machine.as_ref().gtk(),
Fourway::Supplicant(state_machine) => state_machine.as_ref().gtk(),
}
}
pub fn igtk(&self) -> Option<Igtk> {
match self {
Fourway::Authenticator(state_machine) => state_machine.as_ref().igtk(),
Fourway::Supplicant(state_machine) => state_machine.as_ref().igtk(),
}
}
#[cfg(test)]
pub fn get_config(&self) -> Config {
match self {
Fourway::Supplicant(state_machine) => match state_machine.as_ref() {
supplicant::State::AwaitingMsg1 { cfg, .. }
| supplicant::State::AwaitingMsg3 { cfg, .. }
| supplicant::State::KeysInstalled { cfg, .. } => cfg.clone(),
},
Fourway::Authenticator(state_machine) => match state_machine.as_ref() {
authenticator::State::Idle { cfg, .. }
| authenticator::State::AwaitingMsg2 { cfg, .. }
| authenticator::State::AwaitingMsg4 { cfg, .. }
| authenticator::State::KeysInstalled { cfg, .. } => cfg.clone(),
},
}
}
pub fn destroy(self) -> exchange::Config {
let cfg = match self {
Fourway::Supplicant(state_machine) => state_machine.into_state().destroy(),
Fourway::Authenticator(state_machine) => state_machine.into_state().destroy(),
};
exchange::Config::FourWayHandshake(cfg)
}
}
#[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
fn validate_message_1<B: SplitByteSlice>(frame: &eapol::KeyFrameRx<B>) -> Result<(), Error> {
let key_info = frame.key_frame_fields.key_info();
rsn_ensure!(!key_info.install(), Error::InvalidInstallBitValue(MessageNumber::Message1.into()));
rsn_ensure!(key_info.key_ack(), Error::InvalidKeyAckBitValue(MessageNumber::Message1.into()));
rsn_ensure!(!key_info.key_mic(), Error::InvalidKeyMicBitValue(MessageNumber::Message1.into()));
rsn_ensure!(!key_info.secure(), Error::InvalidSecureBitValue(MessageNumber::Message1.into()));
rsn_ensure!(!key_info.error(), Error::InvalidErrorBitValue(MessageNumber::Message1.into()));
rsn_ensure!(!key_info.request(), Error::InvalidRequestBitValue(MessageNumber::Message1.into()));
rsn_ensure!(
!key_info.encrypted_key_data(),
Error::InvalidEncryptedKeyDataBitValue(MessageNumber::Message1.into())
);
rsn_ensure!(
!is_zero(&frame.key_frame_fields.key_nonce[..]),
Error::InvalidNonce(MessageNumber::Message1.into())
);
rsn_ensure!(
is_zero(&frame.key_frame_fields.key_iv[..]),
Error::InvalidIv(frame.eapol_fields.version, MessageNumber::Message1.into())
);
rsn_ensure!(
frame.key_frame_fields.key_rsc.to_native() == 0,
Error::InvalidRsc(MessageNumber::Message1.into())
);
Ok(())
}
#[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
fn validate_message_2<B: SplitByteSlice>(frame: &eapol::KeyFrameRx<B>) -> Result<(), Error> {
let key_info = frame.key_frame_fields.key_info();
rsn_ensure!(!key_info.install(), Error::InvalidInstallBitValue(MessageNumber::Message2.into()));
rsn_ensure!(!key_info.key_ack(), Error::InvalidKeyAckBitValue(MessageNumber::Message2.into()));
rsn_ensure!(key_info.key_mic(), Error::InvalidKeyMicBitValue(MessageNumber::Message2.into()));
rsn_ensure!(!key_info.secure(), Error::InvalidSecureBitValue(MessageNumber::Message2.into()));
rsn_ensure!(!key_info.error(), Error::InvalidErrorBitValue(MessageNumber::Message2.into()));
rsn_ensure!(!key_info.request(), Error::InvalidRequestBitValue(MessageNumber::Message2.into()));
rsn_ensure!(
!key_info.encrypted_key_data(),
Error::InvalidEncryptedKeyDataBitValue(MessageNumber::Message2.into())
);
rsn_ensure!(
!is_zero(&frame.key_frame_fields.key_nonce[..]),
Error::InvalidNonce(MessageNumber::Message2.into())
);
rsn_ensure!(
is_zero(&frame.key_frame_fields.key_iv[..]),
Error::InvalidIv(frame.eapol_fields.version, MessageNumber::Message2.into())
);
rsn_ensure!(
frame.key_frame_fields.key_rsc.to_native() == 0,
Error::InvalidRsc(MessageNumber::Message2.into())
);
Ok(())
}
#[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
fn validate_message_3<B: SplitByteSlice>(
frame: &eapol::KeyFrameRx<B>,
nonce: Option<&[u8]>,
) -> Result<(), Error> {
let key_info = frame.key_frame_fields.key_info();
rsn_ensure!(key_info.install(), Error::InvalidInstallBitValue(MessageNumber::Message3.into()));
rsn_ensure!(key_info.key_ack(), Error::InvalidKeyAckBitValue(MessageNumber::Message3.into()));
if frame.key_frame_fields.descriptor_type != eapol::KeyDescriptor::LEGACY_WPA1 {
rsn_ensure!(
key_info.key_mic(),
Error::InvalidKeyMicBitValue(MessageNumber::Message3.into())
);
rsn_ensure!(
key_info.secure(),
Error::InvalidSecureBitValue(MessageNumber::Message3.into())
);
rsn_ensure!(
key_info.encrypted_key_data(),
Error::InvalidEncryptedKeyDataBitValue(MessageNumber::Message3.into())
);
}
rsn_ensure!(!key_info.error(), Error::InvalidErrorBitValue(MessageNumber::Message3.into()));
rsn_ensure!(!key_info.request(), Error::InvalidRequestBitValue(MessageNumber::Message3.into()));
if let Some(nonce) = nonce {
rsn_ensure!(
!is_zero(&frame.key_frame_fields.key_nonce[..])
&& &frame.key_frame_fields.key_nonce[..] == nonce,
Error::InvalidNonce(MessageNumber::Message3.into())
);
}
rsn_ensure!(
frame.eapol_fields.version < eapol::ProtocolVersion::IEEE802DOT1X2010
|| is_zero(&frame.key_frame_fields.key_iv[..]),
Error::InvalidIv(frame.eapol_fields.version, MessageNumber::Message3.into())
);
rsn_ensure!(frame.key_data.len() != 0, Error::EmptyKeyData(MessageNumber::Message3.into()));
Ok(())
}
#[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
fn validate_message_4<B: SplitByteSlice>(frame: &eapol::KeyFrameRx<B>) -> Result<(), Error> {
let key_info = frame.key_frame_fields.key_info();
rsn_ensure!(!key_info.install(), Error::InvalidInstallBitValue(MessageNumber::Message4.into()));
rsn_ensure!(!key_info.key_ack(), Error::InvalidKeyAckBitValue(MessageNumber::Message4.into()));
rsn_ensure!(key_info.key_mic(), Error::InvalidKeyMicBitValue(MessageNumber::Message4.into()));
rsn_ensure!(key_info.secure(), Error::InvalidSecureBitValue(MessageNumber::Message4.into()));
rsn_ensure!(!key_info.error(), Error::InvalidErrorBitValue(MessageNumber::Message4.into()));
rsn_ensure!(!key_info.request(), Error::InvalidRequestBitValue(MessageNumber::Message4.into()));
rsn_ensure!(
!key_info.encrypted_key_data(),
Error::InvalidEncryptedKeyDataBitValue(MessageNumber::Message4.into())
);
rsn_ensure!(
is_zero(&frame.key_frame_fields.key_iv[..]),
Error::InvalidIv(frame.eapol_fields.version, MessageNumber::Message4.into())
);
rsn_ensure!(
frame.key_frame_fields.key_rsc.to_native() == 0,
Error::InvalidRsc(MessageNumber::Message4.into())
);
Ok(())
}
fn message_number<B: SplitByteSlice>(rx_frame: &eapol::KeyFrameRx<B>) -> MessageNumber {
if rx_frame.key_frame_fields.key_info().key_ack() {
if rx_frame.key_frame_fields.key_info().install() {
MessageNumber::Message3
} else {
MessageNumber::Message1
}
} else {
if rx_frame.key_frame_fields.key_info().secure() {
MessageNumber::Message4
} else {
MessageNumber::Message2
}
}
}
fn is_zero(slice: &[u8]) -> bool {
slice.iter().all(|&x| x == 0)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::rsna::{test_util, SecAssocUpdate};
use crate::rsne::RsnCapabilities;
use wlan_common::ie::rsn::cipher::{CIPHER_BIP_CMAC_128, CIPHER_BIP_CMAC_256};
use wlan_common::ie::wpa::fake_wpa_ies::fake_deprecated_wpa1_vendor_ie;
#[test]
fn correct_value_returned_by_authenticator_key_replay_counter() {
let key_replay_counter = AuthenticatorKeyReplayCounter(5);
assert_eq!(*key_replay_counter, 5);
}
#[test]
fn correct_value_returned_by_supplicant_key_replay_counter() {
let key_replay_counter = SupplicantKeyReplayCounter(5);
assert_eq!(*key_replay_counter, 5);
}
#[test]
fn next_replay_counter_is_next_integer() {
let s_key_replay_counter = SupplicantKeyReplayCounter(5);
let next_key_replay_counter =
AuthenticatorKeyReplayCounter::next_after(*s_key_replay_counter);
assert_eq!(*next_key_replay_counter, 6);
}
#[test]
fn test_supplicant_with_authenticator() {
let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);
let msg1 = env.initiate(11.into());
assert_eq!(msg1.keyframe().eapol_fields.version, eapol::ProtocolVersion::IEEE802DOT1X2004);
let (msg2, _) = env.send_msg1_to_supplicant(msg1.keyframe(), 11.into());
assert_eq!(msg2.keyframe().eapol_fields.version, eapol::ProtocolVersion::IEEE802DOT1X2004);
let msg3 = env.send_msg2_to_authenticator(msg2.keyframe(), 12.into());
assert_eq!(msg3.keyframe().eapol_fields.version, eapol::ProtocolVersion::IEEE802DOT1X2004);
let (msg4, s_ptk, s_gtk) = env.send_msg3_to_supplicant(msg3.keyframe(), 12.into());
assert_eq!(msg4.keyframe().eapol_fields.version, eapol::ProtocolVersion::IEEE802DOT1X2004);
let (a_ptk, a_gtk) = env.send_msg4_to_authenticator(msg4.keyframe(), 13.into());
assert_eq!(s_ptk, a_ptk);
assert_eq!(s_gtk, a_gtk);
assert_eq!(s_gtk.key_id(), 1);
assert_eq!(s_gtk.key_rsc(), 3);
}
#[test]
fn test_wpa3_handshake_generates_igtk_real_authenticator() {
let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa3, 1, 754);
let msg1 = env.initiate(11.into());
assert_eq!(msg1.keyframe().eapol_fields.version, eapol::ProtocolVersion::IEEE802DOT1X2004);
let (msg2, _) = env.send_msg1_to_supplicant(msg1.keyframe(), 11.into());
assert_eq!(msg2.keyframe().eapol_fields.version, eapol::ProtocolVersion::IEEE802DOT1X2004);
let msg3 = env.send_msg2_to_authenticator(msg2.keyframe(), 12.into());
assert_eq!(msg3.keyframe().eapol_fields.version, eapol::ProtocolVersion::IEEE802DOT1X2004);
let (msg4, s_ptk, s_gtk) = env.send_msg3_to_supplicant(msg3.keyframe(), 12.into());
assert_eq!(msg4.keyframe().eapol_fields.version, eapol::ProtocolVersion::IEEE802DOT1X2004);
let (a_ptk, a_gtk) = env.send_msg4_to_authenticator(msg4.keyframe(), 13.into());
assert_eq!(s_ptk, a_ptk);
assert_eq!(s_gtk, a_gtk);
assert_eq!(s_gtk.key_id(), 1);
assert_eq!(s_gtk.key_rsc(), 754);
assert!(env.supplicant.igtk().is_some());
assert_eq!(env.supplicant.igtk(), env.authenticator.igtk());
}
fn run_wpa3_handshake_mock_authenticator(
gtk: &[u8],
igtk: Option<&[u8]>,
) -> (Ptk, Vec<SecAssocUpdate>) {
let anonce = [0xab; 32];
let mut supplicant = test_util::make_handshake(
test_util::HandshakeKind::Wpa3,
super::Role::Supplicant,
1,
3,
);
let msg1_buf = test_util::get_wpa3_4whs_msg1(&anonce[..]);
let msg1 = msg1_buf.keyframe();
let updates = test_util::send_msg_to_fourway(&mut supplicant, msg1, 0.into());
let msg2 = test_util::expect_eapol_resp(&updates[..]);
let a_ptk =
test_util::get_wpa3_ptk(&anonce[..], &msg2.keyframe().key_frame_fields.key_nonce[..]);
let msg3_buf = &test_util::get_wpa3_4whs_msg3(&a_ptk, &anonce[..], >k[..], igtk);
let msg3 = msg3_buf.keyframe();
(a_ptk, test_util::send_msg_to_fourway(&mut supplicant, msg3, 0.into()))
}
#[test]
fn test_wpa3_handshake_generates_igtk_mock_authenticator() {
let gtk = [0xbb; 32];
let igtk = [0xcc; 32];
let (ptk, updates) = run_wpa3_handshake_mock_authenticator(>k[..], Some(&igtk[..]));
test_util::expect_eapol_resp(&updates[..]);
let s_ptk = test_util::expect_reported_ptk(&updates[..]);
let s_gtk = test_util::expect_reported_gtk(&updates[..]);
let s_igtk = test_util::expect_reported_igtk(&updates[..]);
assert_eq!(s_ptk, ptk);
assert_eq!(&s_gtk.bytes[..], >k[..]);
assert_eq!(&s_igtk.igtk[..], &igtk[..]);
}
#[test]
fn test_wpa3_handshake_requires_igtk() {
let gtk = [0xbb; 32];
let (_ptk, updates) = run_wpa3_handshake_mock_authenticator(>k[..], None);
assert!(
test_util::get_eapol_resp(&updates[..]).is_none(),
"WPA3 should not send EAPOL msg4 without IGTK"
);
}
#[test]
fn test_wpa1_handshake() {
let pmk = test_util::get_pmk();
let cfg = test_util::make_wpa1_fourway_cfg();
let mut supplicant = Fourway::new(cfg, pmk).expect("error while creating 4-Way Handshake");
let anonce = [0xab; 32];
let msg1_buf = test_util::get_wpa1_4whs_msg1(&anonce[..]);
let msg1 = msg1_buf.keyframe();
let updates = test_util::send_msg_to_fourway(&mut supplicant, msg1, 0.into());
let msg2 = test_util::expect_eapol_resp(&updates[..]);
let a_ptk =
test_util::get_wpa1_ptk(&anonce[..], &msg2.keyframe().key_frame_fields.key_nonce[..]);
let msg3_buf = &test_util::get_wpa1_4whs_msg3(&a_ptk, &anonce[..]);
let msg3 = msg3_buf.keyframe();
let updates = test_util::send_msg_to_fourway(&mut supplicant, msg3, 0.into());
test_util::expect_eapol_resp(&updates[..]);
let s_ptk = test_util::expect_reported_ptk(&updates[..]);
assert_eq!(s_ptk, a_ptk);
}
#[test]
fn test_supplicant_replay_msg3() {
let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);
let msg1 = env.initiate(11.into());
let (msg2, _) = env.send_msg1_to_supplicant(msg1.keyframe(), 11.into());
let msg3 = env.send_msg2_to_authenticator(msg2.keyframe(), 12.into());
let (_, s_ptk, s_gtk) = env.send_msg3_to_supplicant(msg3.keyframe(), 12.into());
let mut update_sink = UpdateSink::default();
env.send_msg3_to_supplicant_capture_updates(msg3.keyframe(), 12.into(), &mut update_sink);
let msg4 = test_util::expect_eapol_resp(&update_sink[..]);
for update in update_sink {
if let SecAssocUpdate::Key(_) = update {
panic!("reinstalled key");
}
}
let (a_ptk, a_gtk) = env.send_msg4_to_authenticator(msg4.keyframe(), 13.into());
assert_eq!(s_ptk, a_ptk);
assert_eq!(s_gtk, a_gtk);
}
#[test]
fn test_supplicant_replay_msg3_different_gtk() {
let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);
let msg1 = env.initiate(11.into());
let anonce = msg1.keyframe().key_frame_fields.key_nonce.clone();
let (msg2, _) = env.send_msg1_to_supplicant(msg1.keyframe(), 11.into());
let msg3 = env.send_msg2_to_authenticator(msg2.keyframe(), 12.into());
let (_, s_ptk, s_gtk) = env.send_msg3_to_supplicant(msg3.keyframe(), 12.into());
let mut other_gtk = s_gtk.bytes.clone();
other_gtk[0] ^= 0xFF;
let msg3 = test_util::get_wpa2_4whs_msg3(&s_ptk, &anonce[..], &other_gtk[..], |msg3| {
msg3.key_frame_fields.key_replay_counter.set_from_native(42);
});
let mut update_sink = UpdateSink::default();
env.send_msg3_to_supplicant_capture_updates(msg3.keyframe(), 13.into(), &mut update_sink);
assert_eq!(update_sink.len(), 0);
}
#[test]
fn test_random_iv_msg1_v1() {
let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);
let msg1 = env.initiate(1.into());
let mut buf = vec![];
let mut msg1 = msg1.copy_keyframe_mut(&mut buf);
msg1.eapol_fields.version = eapol::ProtocolVersion::IEEE802DOT1X2001;
msg1.key_frame_fields.key_iv = [0xFFu8; 16];
env.send_msg1_to_supplicant_expect_err(msg1, 1.into());
}
#[test]
fn test_random_iv_msg1_v2() {
let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);
let msg1 = env.initiate(1.into());
let mut buf = vec![];
let mut msg1 = msg1.copy_keyframe_mut(&mut buf);
msg1.eapol_fields.version = eapol::ProtocolVersion::IEEE802DOT1X2004;
msg1.key_frame_fields.key_iv = [0xFFu8; 16];
env.send_msg1_to_supplicant_expect_err(msg1, 1.into());
}
#[test]
fn test_random_iv_msg3_v2001() {
let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);
let msg1 = env.initiate(11.into());
let (msg2, ptk) = env.send_msg1_to_supplicant(msg1.keyframe(), 11.into());
let msg3 = env.send_msg2_to_authenticator(msg2.keyframe(), 12.into());
let mut buf = vec![];
let mut msg3 = msg3.copy_keyframe_mut(&mut buf);
msg3.eapol_fields.version = eapol::ProtocolVersion::IEEE802DOT1X2001;
msg3.key_frame_fields.key_iv = [0xFFu8; 16];
env.finalize_key_frame(&mut msg3, Some(ptk.kck()));
let (msg4, s_ptk, s_gtk) = env.send_msg3_to_supplicant(msg3, 12.into());
let (a_ptk, a_gtk) = env.send_msg4_to_authenticator(msg4.keyframe(), 13.into());
assert_eq!(s_ptk, a_ptk);
assert_eq!(s_gtk, a_gtk);
}
#[test]
fn test_random_iv_msg3_v2004() {
let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);
let msg1 = env.initiate(11.into());
let (msg2, ptk) = env.send_msg1_to_supplicant(msg1.keyframe(), 11.into());
let msg3 = env.send_msg2_to_authenticator(msg2.keyframe(), 12.into());
let mut buf = vec![];
let mut msg3 = msg3.copy_keyframe_mut(&mut buf);
msg3.eapol_fields.version = eapol::ProtocolVersion::IEEE802DOT1X2004;
msg3.key_frame_fields.key_iv = [0xFFu8; 16];
env.finalize_key_frame(&mut msg3, Some(ptk.kck()));
let (msg4, s_ptk, s_gtk) = env.send_msg3_to_supplicant(msg3, 12.into());
let (a_ptk, a_gtk) = env.send_msg4_to_authenticator(msg4.keyframe(), 13.into());
assert_eq!(s_ptk, a_ptk);
assert_eq!(s_gtk, a_gtk);
}
#[test]
fn test_zeroed_iv_msg3_v2004() {
let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);
let msg1 = env.initiate(11.into());
let (msg2, ptk) = env.send_msg1_to_supplicant(msg1.keyframe(), 11.into());
let msg3 = env.send_msg2_to_authenticator(msg2.keyframe(), 12.into());
let mut buf = vec![];
let mut msg3 = msg3.copy_keyframe_mut(&mut buf);
msg3.eapol_fields.version = eapol::ProtocolVersion::IEEE802DOT1X2004;
msg3.key_frame_fields.key_iv = [0u8; 16];
env.finalize_key_frame(&mut msg3, Some(ptk.kck()));
let (msg4, s_ptk, s_gtk) = env.send_msg3_to_supplicant(msg3, 12.into());
let (a_ptk, a_gtk) = env.send_msg4_to_authenticator(msg4.keyframe(), 13.into());
assert_eq!(s_ptk, a_ptk);
assert_eq!(s_gtk, a_gtk);
}
#[test]
fn test_random_iv_msg3_v2010() {
let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);
let msg1 = env.initiate(11.into());
let (msg2, ptk) = env.send_msg1_to_supplicant(msg1.keyframe(), 11.into());
let msg3 = env.send_msg2_to_authenticator(msg2.keyframe(), 12.into());
let mut buf = vec![];
let mut msg3 = msg3.copy_keyframe_mut(&mut buf);
msg3.eapol_fields.version = eapol::ProtocolVersion::IEEE802DOT1X2010;
msg3.key_frame_fields.key_iv = [0xFFu8; 16];
env.finalize_key_frame(&mut msg3, Some(ptk.kck()));
env.send_msg3_to_supplicant_expect_err(msg3, 12.into());
}
fn make_protection_info_with_mfp_parameters(
mfp_bits: Option<(bool, bool)>, group_mgmt_cipher_suite: Option<Cipher>,
) -> ProtectionInfo {
ProtectionInfo::Rsne(Rsne {
rsn_capabilities: match mfp_bits {
None => None,
Some((mfpc, mfpr)) => Some(
RsnCapabilities(0)
.with_mgmt_frame_protection_cap(mfpc)
.with_mgmt_frame_protection_req(mfpr),
),
},
group_mgmt_cipher_suite,
..Default::default()
})
}
fn check_rsne_get_group_mgmt_cipher(
s_mfp_bits: Option<(bool, bool)>, s_cipher: Option<Cipher>,
a_mfp_bits: Option<(bool, bool)>, a_cipher: Option<Cipher>,
expected_result: Result<Option<Cipher>, Error>,
) {
let s_protection_info = make_protection_info_with_mfp_parameters(s_mfp_bits, s_cipher);
let a_protection_info = make_protection_info_with_mfp_parameters(a_mfp_bits, a_cipher);
assert_eq!(get_group_mgmt_cipher(&s_protection_info, &a_protection_info), expected_result);
}
#[test]
fn test_get_group_mgmt_cipher() {
assert!(
CIPHER_BIP_CMAC_256 != DEFAULT_GROUP_MGMT_CIPHER,
"default group mgmt cipher is CIPHER_BIP_CMAC_256"
);
for (s_mfpr, a_mfpr) in vec![(false, false), (false, true), (true, false), (true, true)] {
check_rsne_get_group_mgmt_cipher(
Some((true, s_mfpr)),
None,
Some((true, a_mfpr)),
None,
Ok(Some(DEFAULT_GROUP_MGMT_CIPHER)),
);
}
for (s_mfpr, a_mfpr) in vec![(false, false), (false, true), (true, false), (true, true)] {
check_rsne_get_group_mgmt_cipher(
Some((true, s_mfpr)),
Some(CIPHER_BIP_CMAC_256),
Some((true, a_mfpr)),
Some(CIPHER_BIP_CMAC_256),
Ok(Some(CIPHER_BIP_CMAC_256)),
);
}
for (s_mfpc, a_mfpc) in vec![(false, false), (false, true), (true, false)] {
check_rsne_get_group_mgmt_cipher(
Some((s_mfpc, false)),
Some(CIPHER_BIP_CMAC_128),
Some((a_mfpc, false)),
None,
Ok(None),
);
}
for (s_mfpc, a_mfpc) in vec![(false, false), (false, true), (true, false)] {
check_rsne_get_group_mgmt_cipher(
Some((s_mfpc, false)),
None,
Some((a_mfpc, false)),
None,
Ok(None),
);
}
let s_protection_info = ProtectionInfo::LegacyWpa(fake_deprecated_wpa1_vendor_ie());
let a_protection_info = make_protection_info_with_mfp_parameters(None, None);
assert_eq!(get_group_mgmt_cipher(&s_protection_info, &a_protection_info), Ok(None));
let s_protection_info = make_protection_info_with_mfp_parameters(None, None);
let a_protection_info = ProtectionInfo::LegacyWpa(fake_deprecated_wpa1_vendor_ie());
assert_eq!(get_group_mgmt_cipher(&s_protection_info, &a_protection_info), Ok(None));
}
#[test]
fn test_get_group_mgmt_cipher_errors() {
check_rsne_get_group_mgmt_cipher(
Some((false, true)),
None,
Some((false, false)),
None,
Err(Error::InvalidClientMgmtFrameProtectionCapabilityBit),
);
check_rsne_get_group_mgmt_cipher(
Some((false, false)),
None,
Some((false, true)),
None,
Err(Error::InvalidApMgmtFrameProtectionCapabilityBit),
);
check_rsne_get_group_mgmt_cipher(
Some((true, true)),
None,
Some((false, false)),
None,
Err(Error::MgmtFrameProtectionRequiredByClient),
);
check_rsne_get_group_mgmt_cipher(
Some((true, true)),
None,
None,
None,
Err(Error::MgmtFrameProtectionRequiredByClient),
);
let s_protection_info = make_protection_info_with_mfp_parameters(Some((true, true)), None);
let a_protection_info = ProtectionInfo::LegacyWpa(fake_deprecated_wpa1_vendor_ie());
assert_eq!(
get_group_mgmt_cipher(&s_protection_info, &a_protection_info),
Err(Error::MgmtFrameProtectionRequiredByClient)
);
check_rsne_get_group_mgmt_cipher(
Some((false, false)),
None,
Some((true, true)),
None,
Err(Error::MgmtFrameProtectionRequiredByAp),
);
check_rsne_get_group_mgmt_cipher(
None,
None,
Some((true, true)),
None,
Err(Error::MgmtFrameProtectionRequiredByAp),
);
let s_protection_info = ProtectionInfo::LegacyWpa(fake_deprecated_wpa1_vendor_ie());
let a_protection_info = make_protection_info_with_mfp_parameters(Some((true, true)), None);
assert_eq!(
get_group_mgmt_cipher(&s_protection_info, &a_protection_info),
Err(Error::MgmtFrameProtectionRequiredByAp)
);
check_rsne_get_group_mgmt_cipher(
Some((true, true)),
Some(CIPHER_BIP_CMAC_128),
Some((true, true)),
Some(CIPHER_BIP_CMAC_256),
Err(Error::GroupMgmtCipherMismatch(CIPHER_BIP_CMAC_128, CIPHER_BIP_CMAC_256)),
);
check_rsne_get_group_mgmt_cipher(
Some((true, true)),
Some(CIPHER_BIP_CMAC_256),
Some((true, true)),
None,
Err(Error::GroupMgmtCipherMismatch(CIPHER_BIP_CMAC_256, DEFAULT_GROUP_MGMT_CIPHER)),
);
}
}