#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/net.h>
#include <linux/in.h>
#include <linux/init.h>
#include <linux/net.h>
#include <linux/inet.h>
#include <linux/socket.h>
#include <net/sock.h>
#include <linux/thread_migrate.h>
#include <linux/slab.h>
#include <linux/uaccess.h>
#include <linux/string.h>
#include <linux/inetdevice.h>
#include <linux/time.h>
#include <linux/ktime.h>

/*
 * Peer addresses of the two-VM setup. udp_module_init() sends to VM2_IP
 * unless VM2_IP is configured on one of this host's interfaces, in which
 * case it sends to VM1_IP instead.
 * NOTE(review): hard-coded for one specific test network.
 */
#define VM2_IP "192.168.123.79"
#define VM1_IP "192.168.123.78"  

#define DEST_PORT 1104           // Destination UDP port

/*
 * UDP client state. These are not static, so other translation units may
 * refer to them by name. Confirm before narrowing their linkage.
 *
 * sock:        kernel socket created by udp_module_init() and released
 *              (and reset to NULL) by udp_module_exit().
 * remote_addr: destination address/port filled in by udp_module_init().
 * msg, iov:    message header and data vector that call_remote_storage()
 *              rewrites on every send. Nothing here locks them, so this
 *              presumably assumes sends are serialized. Verify with callers.
 */
struct socket *sock;
struct sockaddr_in remote_addr = {0};
struct msghdr msg = {0};
struct kvec iov;

static uint32_t vmsg_ip_to_uint32(char *ip)
{
    uint32_t a, b, c, d;  // Variables to hold each octet of the IP address
    char ch;              // Variable to store the separator characters (dots)

    /* Parse the IP address string into four integers. The expected format is
     * "xxx.xxx.xxx.xxx", where each "xxx" is a number between 0 and 255.
     * The sscanf function reads the string and extracts the numbers, placing
     * them into a, b, c, and d. The 'ch' variable is used to ensure the correct
     * number of dots are present.
     */
    if (sscanf(ip, "%u%c%u%c%u%c%u", &a, &ch, &b, &ch, &c, &ch, &d) != 7) {
        pr_err("vmsg_ip_to_uint32: Invalid IP address format: %s\n", ip);  // Log an error if parsing fails
        return 0;  // Return 0 to indicate an invalid IP address
    }

    /* Validate each octet to ensure it is within the range [0, 255].
     * If any octet is out of range, log an error and return 0.
     */
    if (a > 255 || b > 255 || c > 255 || d > 255) {
        pr_err("vmsg_ip_to_uint32: IP address octet out of range: %s\n", ip);  // Log an error if any octet is invalid
        return 0;  // Return 0 for invalid input
    }

    /* Combine the four octets into a single 32-bit integer.
     * The result is a 32-bit value where each byte represents an octet
     * of the IP address, in network byte order (big-endian).
     */
    return (a << 24) | (b << 16) | (c << 8) | d;
}

/*
 * is_local_ip - check whether an IPv4 address is assigned on this host.
 * @ip: address in host byte order, as produced by vmsg_ip_to_uint32().
 *
 * Takes the RTNL lock and walks every net_device in init_net. Devices
 * without an IPv4 configuration are skipped. For the rest, each address's
 * ifa_local is compared against htonl(@ip). The walk stops at the first hit.
 *
 * Return: 1 if some interface in init_net carries @ip, 0 otherwise.
 */
static int is_local_ip(uint32_t ip)
{
    const uint32_t wanted = htonl(ip);   /* ifa_local is in network order */
    struct net_device *netdev;
    int found = 0;

    rtnl_lock();

    for_each_netdev(&init_net, netdev) {
        struct in_device *idev = __in_dev_get_rtnl(netdev);
        struct in_ifaddr *ifa;

        if (!idev)
            continue;   /* no IPv4 config on this device */

        in_dev_for_each_ifa_rtnl(ifa, idev) {
            if (ifa->ifa_local == wanted) {
                found = 1;
                break;
            }
        }

        if (found)
            break;
    }

    rtnl_unlock();

    return found;
}

/*
 * udp_module_init - create the kernel UDP socket and choose the peer.
 *
 * Creates a UDP/IPv4 socket in init_net and stores it in the global sock.
 * The destination written into the global remote_addr is VM2_IP:DEST_PORT,
 * unless VM2_IP is one of this host's own addresses, in which case it is
 * VM1_IP:DEST_PORT.
 *
 * Return: 0 on success, or the negative error from sock_create_kern().
 */
static int udp_module_init(void) {
    const char *peer;
    int err = sock_create_kern(&init_net, AF_INET, SOCK_DGRAM, IPPROTO_UDP, &sock);

    if (err < 0) {
        printk(KERN_ERR "Failed to create socket\n");
        return err;
    }

    /* If we are VM2 ourselves, talk to VM1; otherwise talk to VM2. */
    peer = is_local_ip(vmsg_ip_to_uint32(VM2_IP)) ? VM1_IP : VM2_IP;

    remote_addr.sin_family = AF_INET;
    remote_addr.sin_port = htons(DEST_PORT);
    remote_addr.sin_addr.s_addr = in_aton(peer);

    return 0;
}

/*
 * udp_module_exit - release the global UDP socket, if one exists.
 *
 * sock_release() dereferences its argument, so it must never be handed a
 * NULL socket. This helper is therefore a no-op when no socket has been
 * created. On return, sock is NULL, so a later udp_module_init() can
 * create a fresh one.
 */
static void udp_module_exit(void) {
    if (!sock)
        return;

    sock_release(sock);
    sock = NULL;
}

int call_remote_storage(struct remote_request *request) {
    char *data, *mems, *vma_descs;
    // int retries = 5;  // Number of retries
    // ktime_t start, end;
    // s64 diff;

    // start = ktime_get();

    data = kmalloc(11264, GFP_KERNEL);
    if (!data) {
        printk(KERN_ERR "Remote: Failed to allocate memory for buffer\n");
        return -ENOMEM;
    }

    mems = kmalloc(1024, GFP_KERNEL);
    if (!mems) {
        printk(KERN_ERR "Remote: Failed to allocate memory for buffer\n");
        return -ENOMEM;
    }

    vma_descs = kmalloc(8192, GFP_KERNEL);
    if (!vma_descs) {
        printk(KERN_ERR "Remote: Failed to allocate memory for buffer\n");
        return -ENOMEM;
    }
    pr_info("Starting to serialize remote wrapper\n");
    // Append remote wrapper data to the message
    sprintf(mems, "%d,%d,%d,%s,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%u,%lu,%s,%lu,%lu,%d",
            request->remote->pid,
            request->remote->opid,
            request->remote->otgid,
            request->remote_name,
            request->remote->task_size,
            request->remote->stack_start,
            request->remote->env_start,
            request->remote->env_end,
            request->remote->arg_start,
            request->remote->arg_end,
            request->remote->start_brk,
            request->remote->brk,
            request->remote->start_code,
            request->remote->end_code,
            request->remote->start_data,
            request->remote->end_data,
            request->remote->personality,
            request->remote->def_flags,
            request->remote->exe_path,
            request->remote->arch.fs,
            request->remote->arch.gs,
            request->remote->map_count);
    pr_info("Serialized mems\n");
    // String to serialize pt_regs
    char pt_regs_str[1024];
    pt_regs_str[0] = '\0';
#if defined(__x86_64__)
    sprintf(pt_regs_str, "%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%u,%lu,%lu,%u",
            request->remote->arch.regsets.r15,
            request->remote->arch.regsets.r14,
            request->remote->arch.regsets.r13,
            request->remote->arch.regsets.r12,
            request->remote->arch.regsets.bp,
            request->remote->arch.regsets.bx,
            request->remote->arch.regsets.r11,
            request->remote->arch.regsets.r10,
            request->remote->arch.regsets.r9,
            request->remote->arch.regsets.r8,
            request->remote->arch.regsets.ax,
            request->remote->arch.regsets.cx,
            request->remote->arch.regsets.dx,
            request->remote->arch.regsets.si,
            request->remote->arch.regsets.di,
            request->remote->arch.regsets.orig_ax,
            request->remote->arch.regsets.ip,
            request->remote->arch.regsets.cs,
            request->remote->arch.regsets.flags,
            request->remote->arch.regsets.sp,
            request->remote->arch.regsets.ss);
    pr_info("Serialized pt_regs\n");
            pr_info("=== SERIALIZATION DEBUG (Node A) ===\n");
            pr_info("Original values from remote_wrapper:\n");
            pr_info("  IP: 0x%lx\n", request->remote->arch.regsets.ip);
            pr_info("  SP: 0x%lx\n", request->remote->arch.regsets.sp);
            pr_info("  CS: 0x%x, SS: 0x%x\n", request->remote->arch.regsets.cs, request->remote->arch.regsets.ss);

            pr_info("Serialized pt_regs_str (first 200 chars): %.200s\n", pt_regs_str);
            pr_info("Final data packet (first 300 chars): %.300s\n", data);

#elif defined(__i386__)
    sprintf(pt_regs_str, "%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu,%lu",
            request->remote->arch.regsets.bx,
            request->remote->arch.regsets.cx,
            request->remote->arch.regsets.dx,
            request->remote->arch.regsets.si,
            request->remote->arch.regsets.di,
            request->remote->arch.regsets.bp,
            request->remote->arch.regsets.ax,
            request->remote->arch.regsets.ds,
            request->remote->arch.regsets.es,
            request->remote->arch.regsets.fs,
            request->remote->arch.regsets.gs,
            request->remote->arch.regsets.orig_ax,
            request->remote->arch.regsets.ip,
            request->remote->arch.regsets.cs,
            request->remote->arch.regsets.flags,
            request->remote->arch.regsets.sp,
            request->remote->arch.regsets.ss);
#else
    // If this code is compiled for an x86 target that is neither __x86_64__ nor __i386__,
    // (or if these macros are not defined by the compiler), the build will fail.
    // This is safer than silently not serializing registers.
    pt_regs_str[0] = '\0'; // Ensure pt_regs_str is an empty string if error is removed.
#error "Cannot serialize pt_regs: unsupported x86 architecture. __x86_64__ or __i386__ must be defined."
#endif
    char **temp_vma = kmalloc(request->remote->map_count * sizeof(char *), GFP_KERNEL);
    if (!temp_vma) {
        printk(KERN_ERR "Remote: Failed to allocate memory for temp vma\n");
        return -ENOMEM;
    }
    for (int i = 0; i < request->remote->map_count; i++) {
        temp_vma[i] = kmalloc(256, GFP_KERNEL);
        if (!temp_vma[i]) {
            for (int j = 0; j < i; j++) {
                kfree(temp_vma[j]);
            }
            kfree(temp_vma);
            return -ENOMEM;
        }
    }

    pr_info("Allocated temp vma %d\n", request->remote->map_count);

    for (int i = 0; i < request->remote->map_count; i++) {
        pr_info("Filling temp vma %d %lu\n", i, request->remote->vma_descs[i].start);
        sprintf(temp_vma[i], "{%lu,%lu,%lu,%lu,%s}",
                request->remote->vma_descs[i].start,
                request->remote->vma_descs[i].end,
                request->remote->vma_descs[i].flags,
                request->remote->vma_descs[i].pgoff,
                request->remote->vma_descs[i].file_path);
    }
    pr_info("Filled temp vma strings %d\n", request->remote->map_count);
    
    for (int i = 0; i < request->remote->map_count; i++) {
        strcat(vma_descs, temp_vma[i]);
        if (i < request->remote->map_count - 1) {
            strcat(vma_descs, ",");
        }
    }

    pr_info("Filled vma descs strings\n");

    sprintf(data, "%s,%s,[%s]", mems, pt_regs_str, vma_descs);
    // sprintf(data, "%s,%s", mems, pt_regs_str);
    // size_t data_len = strlen(data);
    int ret = 0;
    size_t data_len = strlen(data);

    //Initialize the socket
    if (!sock) {
        ret = udp_module_init();
        if (ret < 0) {
            printk(KERN_ERR "Remote: UDP Client: Failed to initialize socket, error %d\n", ret);
            return ret;
        }
    }

    iov.iov_base = data; // Message data
    iov.iov_len = data_len;

    msg.msg_name = &remote_addr; // Set destination address
    msg.msg_namelen = sizeof(remote_addr);
 
    
    ret = kernel_sendmsg(sock, &msg, &iov, 1, data_len);
    
    if (ret < 0) {
        printk(KERN_ERR "Remote: Client: Failed to send message, error %d\n", ret);
        goto error;
    }

    
    // end = ktime_get();

    // diff = ktime_to_ns(ktime_sub(end, start));

    // pr_info("Time taken to send message: %lld ns + %d retries\n", diff, 5 - retries);

error:
    for (int i = 0; i < request->remote->map_count; i++) {
        kfree(temp_vma[i]);
    }
    kfree(temp_vma);
    kfree(mems);
    kfree(vma_descs);
    kfree(data);
    udp_module_exit();
    return ret;
}
EXPORT_SYMBOL(call_remote_storage);