// Syd: rock-solid application kernel
// src/kernel/net/recvfrom.rs: recvfrom(2) handler
//
// Copyright (c) 2025, 2026 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0

use std::{
    mem::MaybeUninit,
    os::{
        fd::{AsFd, AsRawFd, OwnedFd},
        unix::ffi::OsStrExt,
    },
};

use libseccomp::ScmpNotifResp;
use nix::{
    errno::Errno,
    sys::socket::{recv, SockaddrLike, SockaddrStorage},
};
use zeroize::Zeroizing;

use crate::{
    compat::MsgFlags,
    fd::{get_nonblock, has_recv_timeout},
    kernel::net::to_msgflags,
    path::XPath,
    req::UNotifyEventRequest,
};

/// Emulate recv(2) for the sandbox process by receiving on `fd`.
///
/// This also serves recvfrom(2) calls whose address-length pointer is NULL.
///
/// Arguments used from `args`:
/// - `args[1]` is the destination buffer address in sandbox memory.
/// - `args[2]` is the requested length (capped at 1000000 bytes).
/// - `args[3]` holds the receive flags.
///
/// On success the syscall returns the byte count reported by the host
/// recv(2). With `MSG_TRUNC` on datagram-oriented sockets this is the real
/// message length, which may exceed the number of bytes actually copied
/// into the sandbox buffer.
///
/// # Errors
///
/// - `EOPNOTSUPP` for `MSG_OOB` when `restrict_oob` is set.
/// - `ENOMEM` if the intermediate buffer cannot be reserved.
/// - Otherwise, errors from flag validation, blocking-call bookkeeping,
///   recv(2) itself, or writing sandbox memory are propagated.
pub(crate) fn handle_recv(
    fd: OwnedFd,
    args: &[u64; 6],
    request: &UNotifyEventRequest,
    restrict_oob: bool,
) -> Result<ScmpNotifResp, Errno> {
    // SAFETY: Reject undefined/invalid flags.
    let flags = to_msgflags(args[3])?;

    // SAFETY: Reject MSG_OOB as necessary.
    if restrict_oob && flags.contains(MsgFlags::MSG_OOB) {
        // Signal no support to let the sandbox process
        // handle the error gracefully. This is consistent
        // with the Linux kernel.
        return Err(Errno::EOPNOTSUPP);
    }

    // SAFETY:
    // 1. The length argument to the recv call
    //    must not be fully trusted, it can be overly large,
    //    and allocating a Vector of that capacity may overflow.
    // 2. It is valid for the length to be zero to receive an empty message.
    // 3. Buffer read from kernel MUST be zeroized on drop.
    let len = usize::try_from(args[2])
        .or(Err(Errno::EINVAL))?
        .min(1000000); // Cap count at 1mio.
    let mut buf = Zeroizing::new(Vec::new());
    if len > 0 {
        buf.try_reserve(len).or(Err(Errno::ENOMEM))?;
        buf.resize(len, 0);
    }

    // SAFETY: Record blocking call so it can get invalidated.
    let req = request.scmpreq;
    let is_blocking = if !flags.contains(MsgFlags::MSG_DONTWAIT) && !get_nonblock(&fd)? {
        let ignore_restart = has_recv_timeout(&fd)?;

        // Record the blocking call.
        request.cache.add_sys_block(req, ignore_restart)?;

        true
    } else {
        false
    };

    let result = recv(fd.as_raw_fd(), &mut buf, flags.into());

    // Remove invalidation record unless interrupted.
    if is_blocking {
        request
            .cache
            .del_sys_block(req.id, matches!(result, Err(Errno::EINTR)))?;
    }

    // Check for recv errors after invalidation.
    let n = result?;

    // Write buffer into sandbox process memory.
    //
    // With MSG_TRUNC, datagram-oriented sockets report the real length of
    // the message, which may be larger than our buffer. Clamp the copy to
    // the bytes we actually hold so slicing cannot panic, but still return
    // the kernel's value to the sandbox process.
    let copied = n.min(buf.len());
    request.write_mem(&buf[..copied], args[1])?;

    #[expect(clippy::cast_possible_wrap)]
    Ok(request.return_syscall(n as i64))
}

/// Emulate recvfrom(2) for the sandbox process by receiving on `fd`.
///
/// Arguments used from `args`:
/// - `args[1]` and `args[2]` describe the sandbox data buffer.
/// - `args[3]` holds the flags.
/// - `args[4]` is the source-address buffer.
/// - `args[5]` is the value-result address-length pointer.
///
/// When `args[5]` is NULL the call is delegated to [`handle_recv`].
///
/// As in the kernel, the source address is truncated to the length the
/// caller supplied. The length written back is the full, untruncated
/// address length, so the caller can detect truncation. If the peer is a
/// UNIX domain socket whose path starts with `./`, its address is replaced
/// by the one `request.find_unix_addr` returns for the path's base name,
/// when found.
///
/// On success the syscall returns the byte count reported by the host
/// recvfrom(2). With `MSG_TRUNC` this may exceed the number of bytes
/// copied into the sandbox buffer.
///
/// # Errors
///
/// - `EINVAL` for a negative or unreadable address length, or for a
///   positive address length with a NULL address.
/// - `EOPNOTSUPP` for `MSG_OOB` when `restrict_oob` is set.
/// - Errors from flag validation, recvfrom(2) and sandbox memory access
///   are propagated.
pub(crate) fn handle_recvfrom(
    fd: OwnedFd,
    args: &[u64; 6],
    request: &UNotifyEventRequest,
    restrict_oob: bool,
) -> Result<ScmpNotifResp, Errno> {
    // Determine address length if specified.
    let addrlen = if args[5] != 0 {
        const SIZEOF_SOCKLEN_T: usize = size_of::<libc::socklen_t>();
        let mut buf = [0u8; SIZEOF_SOCKLEN_T];
        if request.read_mem(&mut buf, args[5], SIZEOF_SOCKLEN_T)? == SIZEOF_SOCKLEN_T {
            // libc defines socklen_t as u32,
            // however we should check for negative values
            // and return EINVAL as necessary.
            let len = i32::from_ne_bytes(buf);
            let len = libc::socklen_t::try_from(len).or(Err(Errno::EINVAL))?;
            if len > 0 && args[4] == 0 {
                // address length is positive however address is NULL:
                // Return EINVAL and NOT EFAULT here, see LTP accept01 check.
                return Err(Errno::EINVAL);
            }
            len
        } else {
            // Invalid/short read, assume invalid address length.
            return Err(Errno::EINVAL);
        }
    } else {
        // No address-length pointer, so no address can be reported:
        // this is equivalent to recv(2).
        return handle_recv(fd, args, request, restrict_oob);
    };

    // SAFETY: Reject undefined/invalid flags.
    let flags = to_msgflags(args[3])?;

    // SAFETY: Reject MSG_OOB as necessary.
    if restrict_oob && flags.contains(MsgFlags::MSG_OOB) {
        // Signal no support to let the sandbox process
        // handle the error gracefully. This is consistent
        // with the Linux kernel.
        return Err(Errno::EOPNOTSUPP);
    }

    // Check whether we should block and ignore restarts.
    //
    // MSG_DONTWAIT makes this call non-blocking even on a blocking
    // socket, so there is nothing to record in that case. This matches
    // the check in `handle_recv`.
    let (is_blocking, ignore_restart) =
        if !flags.contains(MsgFlags::MSG_DONTWAIT) && !get_nonblock(&fd)? {
            (true, has_recv_timeout(&fd)?)
        } else {
            (false, false)
        };

    // Do the recvfrom call.
    let (buf, n, mut addr) =
        do_recvfrom(fd, request, flags, args[2], is_blocking, ignore_restart)?;

    // Change peer address as necessary for UNIX domain sockets.
    if let Some(peer_addr) = addr
        .as_ref()
        .and_then(|addr| addr.0.as_unix_addr())
        .and_then(|unix| unix.path())
        .map(|path| XPath::from_bytes(path.as_os_str().as_bytes()))
        .filter(|path| path.starts_with(b"./"))
        .map(|path| path.split().1)
        .and_then(|base| request.find_unix_addr(base).ok())
        .and_then(|addr| {
            // SAFETY: addr is a valid UnixAddr.
            unsafe { SockaddrStorage::from_raw(addr.as_ptr().cast(), Some(addr.len())) }
        })
    {
        addr = Some((peer_addr, peer_addr.len()));
    }

    // Write received data into sandbox process memory.
    //
    // `buf` never holds more than the (capped) requested length, even
    // when `n` reports more, e.g. with MSG_TRUNC.
    request.write_mem(&buf, args[1])?;

    // Write address into sandbox process memory as necessary.
    // The address may be None for connection-mode sockets.
    let len = if let Some((addr, addrlen_out)) = addr {
        // Create a byte slice from the socket address pointer.
        //
        // SAFETY: SockaddrStorage type ensures that the memory pointed
        // to by `addr.as_ptr()` is valid and properly aligned.
        let addr_bytes =
            unsafe { std::slice::from_raw_parts(addr.as_ptr().cast(), addr.len() as usize) };

        // Write the truncated socket address into memory.
        //
        // SAFETY: We truncate late to avoid potential UB in
        // std::slice::slice_from_raw_parts(), and never past the
        // bytes the address actually occupies.
        let len = (addrlen.min(addrlen_out) as usize).min(addr_bytes.len());
        request.write_mem(&addr_bytes[..len], args[4])?;

        // Report the untruncated length, as the kernel does ("fromlen
        // shall refer to the value before truncation"), so the caller
        // can tell its buffer was too small.
        addrlen_out
    } else {
        // Connection-mode socket, write 0 to length argument.
        0
    };

    // Convert `len` into a vector of bytes.
    // SAFETY: This must be socklen_t and _not_ usize!
    let buf = len.to_ne_bytes();

    // Write `len` into memory.
    request.write_mem(&buf, args[5])?;

    // Return the kernel's byte count, consistent with `handle_recv`.
    #[expect(clippy::cast_possible_wrap)]
    Ok(request.return_syscall(n as i64))
}

/// Call recvfrom(2) on `fd` into a freshly allocated, zeroize-on-drop buffer.
///
/// `len` is the requested data length and is capped at 1000000 bytes.
/// When `is_blocking` is set, the call is recorded with the request cache
/// beforehand (passing `ignore_restart`). The record is removed afterwards
/// unless the call failed with `EINTR`.
///
/// Returns a tuple of:
/// - the received data, truncated to the received size but never longer
///   than the buffer;
/// - the byte count reported by the kernel, which may exceed the buffer
///   length with `MSG_TRUNC`;
/// - the source address with the untruncated length the kernel reported,
///   or `None` when no address was returned.
#[expect(clippy::type_complexity)]
fn do_recvfrom<Fd: AsFd>(
    fd: Fd,
    request: &UNotifyEventRequest,
    flags: MsgFlags,
    len: u64,
    is_blocking: bool,
    ignore_restart: bool,
) -> Result<
    (
        Zeroizing<Vec<u8>>,
        usize,
        Option<(SockaddrStorage, libc::socklen_t)>,
    ),
    Errno,
> {
    // SAFETY:
    // 1. The length argument to the recvfrom call
    //    must not be fully trusted, it can be overly large,
    //    and allocating a Vector of that capacity may overflow.
    // 2. It is valid for the length to be zero to receive an empty message.
    // 3. Buffer read from kernel MUST be zeroized on drop.
    let buflen = usize::try_from(len).or(Err(Errno::EINVAL))?.min(1000000); // Cap count at 1mio.
    let mut buf = Zeroizing::new(Vec::new());
    if buflen > 0 {
        buf.try_reserve(buflen).or(Err(Errno::ENOMEM))?;
        buf.resize(buflen, 0);
    }

    // Allocate properly aligned storage for the address.
    let mut addr = MaybeUninit::<SockaddrStorage>::zeroed();
    #[expect(clippy::cast_possible_truncation)]
    let mut addrlen = size_of::<SockaddrStorage>() as libc::socklen_t;

    // Cast the aligned storage to a sockaddr pointer.
    let ptr = addr.as_mut_ptr() as *mut libc::sockaddr;

    // SAFETY: Record blocking call so it can get invalidated.
    if is_blocking {
        request
            .cache
            .add_sys_block(request.scmpreq, ignore_restart)?;
    }

    // Make the recvfrom(2) call.
    //
    // SAFETY: buf, ptr and addrlen are valid pointers.
    let result = Errno::result(unsafe {
        libc::recvfrom(
            fd.as_fd().as_raw_fd(),
            buf.as_mut_ptr().cast(),
            buf.len() as libc::size_t,
            flags.bits(),
            ptr,
            &raw mut addrlen,
        )
    });

    // Remove invalidation record unless interrupted.
    if is_blocking {
        request
            .cache
            .del_sys_block(request.scmpreq.id, matches!(result, Err(Errno::EINTR)))?;
    }

    // Check for recvfrom errors after invalidation.
    #[expect(clippy::cast_sign_loss)]
    let n = result? as usize;

    // Truncate buffer to the received size. With MSG_TRUNC, `n` may
    // exceed the buffer length, in which case this is a no-op.
    buf.truncate(n);

    // SAFETY:
    // Convert the raw address into a SockaddrStorage structure.
    // recvfrom returned success so the pointer is valid.
    // Address may be None for connection-mode sockets.
    let addr =
        unsafe { SockaddrStorage::from_raw(ptr, Some(addrlen)) }.map(|addr| (addr, addrlen));

    Ok((buf, n, addr))
}
