// Syd: rock-solid application kernel
// src/kernel/net/socket.rs: socket(2) and socketpair(2) handlers
//
// Copyright (c) 2025, 2026 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0

#[expect(deprecated)]
use libc::SOCK_PACKET;
use libc::{
    c_int, AF_ALG, AF_INET, AF_INET6, AF_NETLINK, AF_PACKET, AF_UNIX, SOCK_CLOEXEC, SOCK_NONBLOCK,
    SOCK_RAW,
};
use libseccomp::ScmpNotifResp;
use nix::{
    errno::Errno,
    sys::socket::{AddressFamily, SockFlag, SockType},
};

use crate::{
    compat::{fstatx, STATX_INO},
    confine::is_valid_ptr,
    cookie::{safe_socket, safe_socketpair},
    debug,
    kernel::net::sandbox_addr_unnamed,
    log_enabled,
    req::UNotifyEventRequest,
    sandbox::{Flags, NetlinkFamily, Options, SandboxGuard},
    syslog::LogLevel,
};

/// Handle socket(2) for the sandbox process.
///
/// `args` holds the raw system call arguments: `args[0]` is the domain,
/// `args[1]` the socket type (optionally OR'ed with `SOCK_CLOEXEC` /
/// `SOCK_NONBLOCK`) and `args[2]` the protocol. The request is checked
/// against the sandbox `flags`, `options` and the allowlisted
/// `netlink_families`; on success the socket is created with
/// `safe_socket` and the descriptor is passed to `request.send_fd`.
///
/// # Errors
///
/// - `EAFNOSUPPORT`: the domain does not fit into a C `int`, the domain
///   is not permitted by the sandbox options, or the netlink family is
///   not part of `netlink_families`.
/// - `EINVAL`: the type or protocol does not fit into a C `int`, or the
///   netlink protocol is out of range.
/// - `EACCES`: an `AF_PACKET` socket was requested, or, when unsupported
///   sockets are allowed, a `SOCK_RAW`/`SOCK_PACKET` type was requested,
///   while unsafe sockets are not allowed.
/// - Any error returned by `safe_socket` or `request.send_fd`.
#[expect(clippy::cognitive_complexity)]
pub(crate) fn handle_socket(
    request: &UNotifyEventRequest,
    args: &[u64; 6],
    flags: Flags,
    options: Options,
    netlink_families: NetlinkFamily,
) -> Result<ScmpNotifResp, Errno> {
    let allow_unsafe_socket = options.allow_unsafe_socket();
    let allow_unsupp_socket = options.allow_unsupp_socket();
    let allow_safe_kcapi = options.allow_safe_kcapi();
    let force_cloexec = flags.force_cloexec();
    let force_rand_fd = flags.force_rand_fd();

    // An address family which is not representable as a C int can never
    // name a supported family: report EAFNOSUPPORT, which is also what
    // handle_socketpair returns for the same condition.
    let domain = c_int::try_from(args[0]).or(Err(Errno::EAFNOSUPPORT))?;
    let stype = c_int::try_from(args[1]).or(Err(Errno::EINVAL))?;
    let proto = c_int::try_from(args[2]).or(Err(Errno::EINVAL))?;

    // SAFETY:
    // 1. Limit available domains based on sandbox flags.
    // 2. Deny access to raw & packet sockets,
    //    unless trace/allow_unsafe_socket:1 is set.
    //    Both types require CAP_NET_RAW and use of
    //    SOCK_PACKET is strongly discouraged.
    #[expect(deprecated, reason = "SOCK_PACKET is deprecated")]
    if !allow_unsupp_socket {
        match domain {
            AF_UNIX | AF_INET | AF_INET6 => {}
            AF_ALG if allow_safe_kcapi => {}
            AF_NETLINK => {
                // Restrict AF_NETLINK to the allowlisted families.
                // For netlink sockets the protocol argument selects the family.
                let nlfam = u32::try_from(args[2]).or(Err(Errno::EINVAL))?;
                #[expect(clippy::cast_sign_loss)]
                if nlfam > NetlinkFamily::max() as u32 {
                    return Err(Errno::EINVAL);
                }
                let nlfam = NetlinkFamily::from_bits(1 << nlfam).ok_or(Errno::EINVAL)?;
                if !netlink_families.contains(nlfam) {
                    // SAFETY: Unsafe netlink family, deny.
                    return Err(Errno::EAFNOSUPPORT);
                }
            }
            AF_PACKET if !allow_unsafe_socket => return Err(Errno::EACCES),
            AF_PACKET => {}
            _ => return Err(Errno::EAFNOSUPPORT),
        }
    } else if !allow_safe_kcapi && domain == AF_ALG {
        return Err(Errno::EAFNOSUPPORT);
    } else if !allow_unsafe_socket
        && (domain == AF_PACKET
            // The first argument masks the type with the bitwise OR of both
            // constants; the second argument is an or-pattern which matches
            // when the masked value equals either SOCK_RAW or SOCK_PACKET.
            || matches!(stype & (SOCK_RAW | SOCK_PACKET), SOCK_RAW | SOCK_PACKET))
    {
        return Err(Errno::EACCES);
    } else {
        // SAFETY: allow_unsupp_socket:1
        // Safe domain, allow.
    }

    // The socket itself is always created with SOCK_CLOEXEC. Whether the
    // descriptor passed to send_fd should be close-on-exec is tracked
    // separately in `cloexec` (caller's request or force_cloexec).
    let cloexec = force_cloexec || (stype & SOCK_CLOEXEC != 0);
    let stype = stype | SOCK_CLOEXEC;
    let fd = safe_socket(domain, stype, proto)?;

    if log_enabled!(LogLevel::Debug) {
        // Inode lookup is best-effort: log 0 if fstatx fails.
        let inode = fstatx(&fd, STATX_INO)
            .map(|statx| statx.stx_ino)
            .unwrap_or(0);
        let domain = AddressFamily::from_i32(domain)
            .map(|af| format!("{af:?}"))
            .unwrap_or_else(|| "?".to_string());
        // `stype` already carries the forced SOCK_CLOEXEC at this point,
        // so the logged flags describe the socket as it was created.
        let flags = SockFlag::from_bits_retain(stype & (SOCK_CLOEXEC | SOCK_NONBLOCK));
        let stype = SockType::try_from(stype & !(SOCK_CLOEXEC | SOCK_NONBLOCK))
            .map(|st| format!("{st:?}"))
            .unwrap_or_else(|_| "?".to_string());
        debug!("ctx": "net", "op": "create_socket",
                "msg": format!("created {domain} {stype} socket with inode:{inode:#x}"),
                "domain": domain,
                "type": stype,
                "protocol": proto,
                "flags": flags.bits(),
                "inode": inode);
    }

    request.send_fd(fd, cloexec, force_rand_fd)
}

#[expect(clippy::cognitive_complexity)]
pub(crate) fn handle_socketpair(
    request: &UNotifyEventRequest,
    sandbox: SandboxGuard,
    args: &[u64; 6],
    op: u8,
) -> Result<ScmpNotifResp, Errno> {
    let flags = *sandbox.flags;
    let options = *sandbox.options;
    let force_cloexec = flags.force_cloexec();
    let force_rand_fd = flags.force_rand_fd();
    let allow_unsupp_socket = options.allow_unsupp_socket();

    let domain = c_int::try_from(args[0])
        .ok()
        .and_then(AddressFamily::from_i32)
        .ok_or(Errno::EAFNOSUPPORT)?;

    let stype = c_int::try_from(args[1]).or(Err(Errno::EINVAL))?;
    let sflag = SockFlag::from_bits(stype & (SOCK_CLOEXEC | SOCK_NONBLOCK)).ok_or(Errno::EINVAL)?;
    let stype =
        SockType::try_from(stype & !(SOCK_CLOEXEC | SOCK_NONBLOCK)).or(Err(Errno::EINVAL))?;

    let proto = c_int::try_from(args[2]).or(Err(Errno::EINVAL))?;

    // On Linux, the only supported domains for this call are AF_UNIX (or
    // synonymously, AF_LOCAL) and AF_TIPC (since Linux 4.12).
    let check_access = match domain {
        _ if stype == SockType::Raw => return Err(Errno::EPROTONOSUPPORT),
        AddressFamily::Unix if !matches!(proto, 0 | libc::AF_UNIX) => {
            return Err(Errno::EPROTONOSUPPORT)
        }
        AddressFamily::Unix => true,
        AddressFamily::Tipc if allow_unsupp_socket => false,
        _ if stype == SockType::Datagram && !matches!(proto, 0 | libc::IPPROTO_UDP) => {
            return Err(Errno::EPROTONOSUPPORT)
        }
        _ if stype == SockType::Stream && !matches!(proto, 0 | libc::IPPROTO_TCP) => {
            return Err(Errno::EPROTONOSUPPORT)
        }
        _ => return Err(Errno::EOPNOTSUPP),
    };

    // Check AF_UNIX sockets for bind access to dummy `!unnamed' path.
    if check_access {
        sandbox_addr_unnamed(request, &sandbox, op)?;
    }
    drop(sandbox); // release the read-lock.

    // SAFETY: Check pointer against mmap_min_addr.
    let fdptr = args[3];
    if !is_valid_ptr(fdptr, request.scmpreq.data.arch) {
        return Err(Errno::EFAULT);
    }

    let cloexec = force_cloexec || sflag.contains(SockFlag::SOCK_CLOEXEC);
    let sflag = sflag | SockFlag::SOCK_CLOEXEC;

    // Create the socket pair using the hardened helper.
    let (fd0, fd1) = safe_socketpair(domain, stype, proto, sflag)?;

    // Handle UNIX map after successful socketpair(2) for UNIX sockets.
    if domain == AddressFamily::Unix {
        // Record inode->PID mappings to the UNIX map.
        // We ignore errors because there's nothing we can do about them.
        let _ = request.add_unix(&fd0, request.scmpreq.pid(), None, None);
        let _ = request.add_unix(&fd1, request.scmpreq.pid(), None, None);
    }

    if log_enabled!(LogLevel::Debug) {
        let inode0 = fstatx(&fd0, STATX_INO)
            .map(|statx| statx.stx_ino)
            .unwrap_or(0);
        let inode1 = fstatx(&fd1, STATX_INO)
            .map(|statx| statx.stx_ino)
            .unwrap_or(0);
        let domain = format!("{domain:?}");
        let stypes = format!("{stype:?}");
        debug!("ctx": "net", "op": "create_socketpair",
               "msg": format!("created {domain} {stypes} socketpair with inodes:{inode0:#x},{inode1:#x}"),
               "domain": domain,
               "type": stypes,
               "protocol": proto,
               "flags": sflag.bits(),
               "inode0": inode0,
               "inode1": inode1);
    }

    // Install both fds into the sandbox process.
    // Move fds into the function and close on return.
    let newfd0 = request.add_fd(fd0, cloexec, force_rand_fd)?;
    let newfd1 = request.add_fd(fd1, cloexec, force_rand_fd)?;

    // Write the installed fds back to sandbox process memory.
    let a = newfd0.to_ne_bytes();
    let b = newfd1.to_ne_bytes();
    let out = [a[0], a[1], a[2], a[3], b[0], b[1], b[2], b[3]];

    // The caller provided `fdptr`:
    // Write back exactly 2 * sizeof(RawFd) bytes.
    request.write_mem(&out, fdptr)?;

    // socketpair(2) returns 0 on success.
    Ok(request.return_syscall(0))
}
