use http::{self, Response, StatusCode};
use thiserror::Error;
use tokio_tungstenite::tungstenite as ws;
use crate::client::Body;
/// WebSocket subprotocol name that [`verify_response`] requires in the
/// response's `Sec-WebSocket-Protocol` header.
pub const WS_PROTOCOL: &str = "v4.channel.k8s.io";
/// Reasons an HTTP response can be rejected as a WebSocket upgrade.
///
/// All variants except [`UpgradeConnectionError::GetPendingUpgrade`] are
/// produced by [`verify_response`] while validating the server's handshake
/// reply.
#[cfg(feature = "ws")]
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
#[derive(Debug, Error)]
pub enum UpgradeConnectionError {
/// The response status was not `101 Switching Protocols`; carries the
/// status that was actually received.
#[error("failed to switch protocol: {0}")]
ProtocolSwitch(http::status::StatusCode),
/// The `Upgrade` header was absent, unreadable as a string, or not
/// `websocket` (compared ignoring ASCII case).
#[error("upgrade header was not set to websocket")]
MissingUpgradeWebSocketHeader,
/// The `Connection` header was absent, unreadable as a string, or did not
/// indicate `Upgrade`.
#[error("connection header was not set to Upgrade")]
MissingConnectionUpgradeHeader,
/// The `Sec-WebSocket-Accept` header was absent or did not equal the accept
/// key derived from the request key.
#[error("Sec-WebSocket-Accept key mismatched")]
SecWebSocketAcceptKeyMismatch,
/// The `Sec-WebSocket-Protocol` header was absent or did not equal
/// [`WS_PROTOCOL`].
#[error("Sec-WebSocket-Protocol mismatched")]
SecWebSocketProtocolMismatch,
/// Wraps a `hyper::Error` as its source.
// NOTE(review): nothing in this part of the file constructs this variant;
// presumably the caller raises it when taking over the upgraded connection
// fails — confirm against the call site.
#[error("failed to get pending HTTP upgrade: {0}")]
GetPendingUpgrade(#[source] hyper::Error),
}
/// Validates that `res` is an acceptable server reply to a WebSocket upgrade
/// request that was sent with the `Sec-WebSocket-Key` value `key`.
///
/// The checks run in the order listed below and the first failure decides the
/// returned error.
///
/// # Errors
///
/// * [`UpgradeConnectionError::ProtocolSwitch`] - the status is not
///   `101 Switching Protocols`.
/// * [`UpgradeConnectionError::MissingUpgradeWebSocketHeader`] - the `Upgrade`
///   header is absent, not a valid string, or not `websocket` (ASCII
///   case-insensitive).
/// * [`UpgradeConnectionError::MissingConnectionUpgradeHeader`] - no
///   `Connection` header value contains an `Upgrade` token (ASCII
///   case-insensitive).
/// * [`UpgradeConnectionError::SecWebSocketAcceptKeyMismatch`] - the
///   `Sec-WebSocket-Accept` header is absent or differs from the accept key
///   derived from `key`.
/// * [`UpgradeConnectionError::SecWebSocketProtocolMismatch`] - the
///   `Sec-WebSocket-Protocol` header is absent or differs from
///   [`WS_PROTOCOL`].
pub fn verify_response(res: &Response<Body>, key: &str) -> Result<(), UpgradeConnectionError> {
    let status = res.status();
    if status != StatusCode::SWITCHING_PROTOCOLS {
        return Err(UpgradeConnectionError::ProtocolSwitch(status));
    }

    let headers = res.headers();

    let upgrade_is_websocket = headers
        .get(http::header::UPGRADE)
        .and_then(|h| h.to_str().ok())
        .map(|h| h.eq_ignore_ascii_case("websocket"))
        .unwrap_or(false);
    if !upgrade_is_websocket {
        return Err(UpgradeConnectionError::MissingUpgradeWebSocketHeader);
    }

    // `Connection` is a comma-separated list of options and may be split over
    // several header lines, so a reply such as `Connection: keep-alive, Upgrade`
    // is valid. Look for the `Upgrade` token anywhere in the list instead of
    // requiring the whole header value to equal it.
    let connection_has_upgrade = headers
        .get_all(http::header::CONNECTION)
        .iter()
        .filter_map(|h| h.to_str().ok())
        .flat_map(|h| h.split(','))
        .any(|token| token.trim().eq_ignore_ascii_case("Upgrade"));
    if !connection_has_upgrade {
        return Err(UpgradeConnectionError::MissingConnectionUpgradeHeader);
    }

    // The server proves it understood the handshake by returning the accept
    // key derived from the key this client sent.
    let accept_key = ws::handshake::derive_accept_key(key.as_bytes());
    let accept_matches = headers
        .get(http::header::SEC_WEBSOCKET_ACCEPT)
        .map(|h| h == &accept_key)
        .unwrap_or(false);
    if !accept_matches {
        return Err(UpgradeConnectionError::SecWebSocketAcceptKeyMismatch);
    }

    let protocol_matches = headers
        .get(http::header::SEC_WEBSOCKET_PROTOCOL)
        .map(|h| h == WS_PROTOCOL)
        .unwrap_or(false);
    if !protocol_matches {
        return Err(UpgradeConnectionError::SecWebSocketProtocolMismatch);
    }

    Ok(())
}
/// Produces a fresh key for a WebSocket handshake: 16 random bytes encoded
/// with the standard (padded) base64 alphabet.
pub fn sec_websocket_key() -> String {
    use base64::{engine::general_purpose::STANDARD, Engine};

    let nonce = rand::random::<[u8; 16]>();
    STANDARD.encode(nonce)
}