1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
use std::ascii::AsciiExt;
use std::str::{from_utf8};
use super::{Head};
use websocket::Accept;
/// Client side of a websocket opening handshake, as extracted from a
/// request head by `get_handshake`.
#[derive(Debug)]
pub struct WebsocketHandshake {
/// Built from the whitespace-trimmed `Sec-WebSocket-Key` request header via
/// `Accept::from_key_bytes` (presumably the value for the
/// `Sec-WebSocket-Accept` response header — confirm in `websocket::Accept`).
pub accept: Accept,
/// Non-empty, whitespace-trimmed comma-separated tokens collected from every
/// `Sec-WebSocket-Protocol` header, in the order they appeared.
pub protocols: Vec<String>,
/// Non-empty, whitespace-trimmed comma-separated tokens collected from every
/// `Sec-WebSocket-Extensions` header, in the order they appeared.
pub extensions: Vec<String>,
}
/// Strips leading and trailing `\r`, `\n`, space and tab bytes from `x`.
///
/// Returns a sub-slice of the input; nothing is copied. An input made up
/// only of those bytes (or an empty input) yields an empty slice.
fn bytes_trim(x: &[u8]) -> &[u8] {
    let is_space = |b: &u8| matches!(*b, b'\r' | b'\n' | b' ' | b'\t');
    match x.iter().position(|b| !is_space(b)) {
        // Nothing but whitespace: everything is trimmed away.
        None => &x[x.len()..],
        Some(start) => {
            // A non-space byte exists at `start`, so the reverse search
            // always finds one at or after it; the fallback is defensive.
            let end = x
                .iter()
                .rposition(|b| !is_space(b))
                .map_or(start, |last| last + 1);
            &x[start..end]
        }
    }
}
/// Extracts a websocket opening handshake (RFC 6455) from a request head.
///
/// # Returns
///
/// * `Ok(None)` if the request does not ask for a websocket upgrade: the
///   `Connection` header is missing or has no `upgrade` token, or an
///   `Upgrade` header is present without a `websocket` token.
/// * `Ok(Some(handshake))` for a well-formed websocket handshake.
///
/// # Errors
///
/// Returns `Err(())` (after logging the reason with `debug!`) when the
/// request asks for a connection upgrade but the handshake is invalid:
/// `req.path()` is `None`, `Sec-WebSocket-Key` is repeated,
/// `Sec-WebSocket-Version` is not `13`, a protocol/extension list is not
/// valid UTF-8, the request has a body, or the `Upgrade`,
/// `Sec-WebSocket-Version` or `Sec-WebSocket-Key` header is missing.
pub fn get_handshake(req: &Head) -> Result<Option<WebsocketHandshake>, ()> {
    // `Connection` is a comma-separated token list, e.g. "keep-alive, Upgrade".
    let conn_upgrade = req.connection_header().map(|x| {
        x.split(',').any(|tok| tok.trim().eq_ignore_ascii_case("upgrade"))
    });
    if !conn_upgrade.unwrap_or(false) {
        return Ok(None);
    }
    if req.path().is_none() {
        debug!("Invalid request-target for websocket request");
        return Err(());
    }
    let mut upgrade = false;
    let mut version = false;
    let mut accept = None;
    let mut protocols = Vec::new();
    let mut extensions = Vec::new();
    for h in req.all_headers() {
        if h.name.eq_ignore_ascii_case("Sec-WebSocket-Key") {
            if accept.is_some() {
                debug!("Duplicate Sec-WebSocket-Key");
                return Err(());
            }
            accept = Some(Accept::from_key_bytes(bytes_trim(h.value)));
        } else if h.name.eq_ignore_ascii_case("Sec-WebSocket-Version") {
            if bytes_trim(h.value) != b"13" {
                debug!("Bad websocket version {:?}",
                    String::from_utf8_lossy(h.value));
                return Err(());
            }
            version = true;
        } else if h.name.eq_ignore_ascii_case("Sec-WebSocket-Protocol") {
            extend_header_tokens(&mut protocols, h.value,
                "Sec-Websocket-Protocol")?;
        } else if h.name.eq_ignore_ascii_case("Sec-WebSocket-Extensions") {
            extend_header_tokens(&mut extensions, h.value,
                "Sec-Websocket-Extensions")?;
        } else if h.name.eq_ignore_ascii_case("Upgrade") {
            // `Upgrade` is a comma-separated protocol list (RFC 7230 §6.7)
            // and, like the other headers handled here, may carry
            // surrounding whitespace; look for a `websocket` token.
            let wants_websocket = h.value
                .split(|&b| b == b',')
                .any(|tok| bytes_trim(tok).eq_ignore_ascii_case(b"websocket"));
            if !wants_websocket {
                // Upgrade to some other protocol: not a websocket request.
                return Ok(None);
            }
            upgrade = true;
        }
    }
    if req.has_body() {
        debug!("Websocket handshake has payload");
        return Err(());
    }
    if !upgrade {
        debug!("No upgrade header for a websocket");
        return Err(());
    }
    // Both the version and the key are mandatory for a valid handshake.
    let accept = match (accept, version) {
        (Some(accept), true) => accept,
        _ => {
            debug!("No required headers for a websocket");
            return Err(());
        }
    };
    Ok(Some(WebsocketHandshake {
        accept: accept,
        protocols: protocols,
        extensions: extensions,
    }))
}

/// Appends the comma-separated tokens of a list-valued header to `out`,
/// trimming whitespace around each token and skipping empty entries.
///
/// Fails (after logging via `debug!`) if `value` is not valid UTF-8;
/// `header` is used only in that log message.
fn extend_header_tokens(out: &mut Vec<String>, value: &[u8], header: &str)
    -> Result<(), ()>
{
    let tokens = from_utf8(value)
        .map_err(|_| debug!("Bad utf-8 in {}", header))?;
    out.extend(tokens.split(',')
        .map(|x| x.trim())
        .filter(|x| !x.is_empty())
        .map(|x| x.to_string()));
    Ok(())
}