// ABOUTME: greetd IPC protocol implementation — communicates via Unix socket. // ABOUTME: Uses length-prefixed JSON encoding as specified by the greetd IPC protocol. use std::io::{self, Read, Write}; use std::os::unix::net::UnixStream; const MAX_PAYLOAD_SIZE: usize = 65536; /// Errors from greetd IPC communication. #[derive(Debug)] pub enum IpcError { Io(io::Error), PayloadTooLarge(usize), Json(serde_json::Error), ConnectionClosed, } impl std::fmt::Display for IpcError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { IpcError::Io(e) => write!(f, "IPC I/O error: {e}"), IpcError::PayloadTooLarge(size) => { write!(f, "Payload too large: {size} bytes (max {MAX_PAYLOAD_SIZE})") } IpcError::Json(e) => write!(f, "IPC JSON error: {e}"), IpcError::ConnectionClosed => write!(f, "Connection closed while reading data"), } } } impl std::error::Error for IpcError {} impl From for IpcError { fn from(e: io::Error) -> Self { IpcError::Io(e) } } impl From for IpcError { fn from(e: serde_json::Error) -> Self { IpcError::Json(e) } } /// Read exactly 4 bytes (length header) from the stream into a stack array. fn recv_header(stream: &mut UnixStream) -> Result<[u8; 4], IpcError> { let mut buf = [0u8; 4]; let mut filled = 0; while filled < 4 { let bytes_read = stream.read(&mut buf[filled..])?; if bytes_read == 0 { return Err(IpcError::ConnectionClosed); } filled += bytes_read; } Ok(buf) } /// Receive exactly n bytes from the stream, looping on partial reads. fn recv_payload(stream: &mut UnixStream, n: usize) -> Result, IpcError> { let mut buf = vec![0u8; n]; let mut filled = 0; while filled < n { let bytes_read = stream.read(&mut buf[filled..])?; if bytes_read == 0 { return Err(IpcError::ConnectionClosed); } filled += bytes_read; } Ok(buf) } /// Send a length-prefixed JSON message to the greetd socket. 
///
/// Each message is a `u32` payload length in native byte order followed by the
/// serialized JSON payload, as described in greetd-ipc(7).
///
/// # Errors
///
/// Returns [`IpcError::Json`] if `msg` cannot be serialized,
/// [`IpcError::PayloadTooLarge`] if the encoded payload is larger than
/// `MAX_PAYLOAD_SIZE` bytes, or [`IpcError::Io`] if writing to the socket fails.
pub fn send_message(
    stream: &mut UnixStream,
    msg: &serde_json::Value,
) -> Result<(), IpcError> {
    let payload = serde_json::to_vec(msg)?;
    if payload.len() > MAX_PAYLOAD_SIZE {
        return Err(IpcError::PayloadTooLarge(payload.len()));
    }
    let msg_type = msg.get("type").and_then(|v| v.as_str()).unwrap_or("unknown");
    // Only the message type and size are logged: payloads may carry passwords.
    log::debug!("IPC send: type={msg_type}, size={} bytes", payload.len());
    // Checked conversion instead of `as u32`, so the header can never silently
    // truncate even if MAX_PAYLOAD_SIZE is ever raised past u32::MAX.
    let length = u32::try_from(payload.len())
        .map_err(|_| IpcError::PayloadTooLarge(payload.len()))?;
    // greetd frames messages with a native-endian length, not a fixed little-endian one.
    stream.write_all(&length.to_ne_bytes())?;
    stream.write_all(&payload)?;
    Ok(())
}

/// Receive a length-prefixed JSON message from the greetd socket.
///
/// The length header is checked against `MAX_PAYLOAD_SIZE` before any payload
/// buffer is allocated, so a corrupt or hostile header cannot force a large allocation.
///
/// # Errors
///
/// Returns [`IpcError::ConnectionClosed`] if the peer closes before a full message
/// arrives, [`IpcError::Io`] on other socket failures, [`IpcError::PayloadTooLarge`]
/// if the advertised length exceeds the limit, or [`IpcError::Json`] if the payload
/// is not valid JSON.
pub fn recv_message(stream: &mut UnixStream) -> Result<serde_json::Value, IpcError> {
    let header = recv_header(stream)?;
    // Native byte order per greetd-ipc(7); u32 -> usize is lossless on Unix targets.
    let length = u32::from_ne_bytes(header) as usize;
    if length > MAX_PAYLOAD_SIZE {
        return Err(IpcError::PayloadTooLarge(length));
    }
    let payload = recv_payload(stream, length)?;
    let value: serde_json::Value = serde_json::from_slice(&payload)?;
    let msg_type = value.get("type").and_then(|v| v.as_str()).unwrap_or("unknown");
    log::debug!("IPC recv: type={msg_type}, size={length} bytes");
    Ok(value)
}

/// Send a create_session request to greetd and return the response.
///
/// The reply is returned as an unparsed JSON value; callers inspect its `"type"` field.
pub fn create_session(
    stream: &mut UnixStream,
    username: &str,
) -> Result<serde_json::Value, IpcError> {
    let msg = serde_json::json!({
        "type": "create_session",
        "username": username,
    });
    send_message(stream, &msg)?;
    recv_message(stream)
}

/// Send an authentication response (e.g. password) to greetd.
///
/// `None` is sent as an explicit JSON `null` in the `"response"` field.
pub fn post_auth_response(
    stream: &mut UnixStream,
    response: Option<&str>,
) -> Result<serde_json::Value, IpcError> {
    let msg = serde_json::json!({
        "type": "post_auth_message_response",
        "response": response,
    });
    send_message(stream, &msg)?;
    recv_message(stream)
}

/// Send a start_session request to launch the user's session.
///
/// `cmd` is sent as a JSON array of strings; greetd's reply is returned as raw JSON.
pub fn start_session(
    stream: &mut UnixStream,
    cmd: &[String],
) -> Result<serde_json::Value, IpcError> {
    let msg = serde_json::json!({
        "type": "start_session",
        "cmd": cmd,
    });
    send_message(stream, &msg)?;
    recv_message(stream)
}

/// Cancel the current authentication session.
pub fn cancel_session(stream: &mut UnixStream) -> Result { let msg = serde_json::json!({"type": "cancel_session"}); send_message(stream, &msg)?; recv_message(stream) } #[cfg(test)] mod tests { use super::*; use std::os::unix::net::UnixStream; /// Create a connected pair of Unix sockets for testing. fn socket_pair() -> (UnixStream, UnixStream) { UnixStream::pair().unwrap() } #[test] fn send_and_receive_message() { let (mut client, mut server) = socket_pair(); let msg = serde_json::json!({"type": "create_session", "username": "test"}); send_message(&mut client, &msg).unwrap(); let received = recv_message(&mut server).unwrap(); assert_eq!(received["type"], "create_session"); assert_eq!(received["username"], "test"); } #[test] fn create_session_roundtrip() { let (mut client, mut server) = socket_pair(); // Simulate greetd response in a thread let handle = std::thread::spawn(move || { let msg = recv_message(&mut server).unwrap(); assert_eq!(msg["type"], "create_session"); assert_eq!(msg["username"], "alice"); let response = serde_json::json!({ "type": "auth_message", "auth_message_type": "visible", "auth_message": "Password: ", }); send_message(&mut server, &response).unwrap(); }); let response = create_session(&mut client, "alice").unwrap(); assert_eq!(response["type"], "auth_message"); handle.join().unwrap(); } #[test] fn post_auth_response_roundtrip() { let (mut client, mut server) = socket_pair(); let handle = std::thread::spawn(move || { let msg = recv_message(&mut server).unwrap(); assert_eq!(msg["type"], "post_auth_message_response"); assert_eq!(msg["response"], "secret123"); let response = serde_json::json!({"type": "success"}); send_message(&mut server, &response).unwrap(); }); let response = post_auth_response(&mut client, Some("secret123")).unwrap(); assert_eq!(response["type"], "success"); handle.join().unwrap(); } #[test] fn start_session_roundtrip() { let (mut client, mut server) = socket_pair(); let handle = std::thread::spawn(move || { let msg = 
recv_message(&mut server).unwrap(); assert_eq!(msg["type"], "start_session"); assert_eq!(msg["cmd"], serde_json::json!(["niri-session"])); let response = serde_json::json!({"type": "success"}); send_message(&mut server, &response).unwrap(); }); let cmd = vec!["niri-session".to_string()]; let response = start_session(&mut client, &cmd).unwrap(); assert_eq!(response["type"], "success"); handle.join().unwrap(); } #[test] fn cancel_session_roundtrip() { let (mut client, mut server) = socket_pair(); let handle = std::thread::spawn(move || { let msg = recv_message(&mut server).unwrap(); assert_eq!(msg["type"], "cancel_session"); let response = serde_json::json!({"type": "success"}); send_message(&mut server, &response).unwrap(); }); let response = cancel_session(&mut client).unwrap(); assert_eq!(response["type"], "success"); handle.join().unwrap(); } #[test] fn connection_closed_returns_error() { let (mut client, server) = socket_pair(); drop(server); let result = recv_message(&mut client); assert!(result.is_err()); } #[test] fn oversized_payload_rejected_on_send() { let (mut client, _server) = socket_pair(); let big_string = "x".repeat(MAX_PAYLOAD_SIZE + 1); let msg = serde_json::json!({"data": big_string}); let result = send_message(&mut client, &msg); assert!(result.is_err()); } #[test] fn oversized_payload_rejected_on_receive() { let (mut client, mut server) = socket_pair(); // Manually send a header claiming a huge payload let fake_length: u32 = (MAX_PAYLOAD_SIZE as u32) + 1; server.write_all(&fake_length.to_le_bytes()).unwrap(); let result = recv_message(&mut client); assert!(matches!(result, Err(IpcError::PayloadTooLarge(_)))); } #[test] fn ipc_error_display() { let err = IpcError::ConnectionClosed; assert_eq!(err.to_string(), "Connection closed while reading data"); let err = IpcError::PayloadTooLarge(99999); assert!(err.to_string().contains("99999")); } }