#include "behavior_collector.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <dirent.h>
#include <time.h>
#include <errno.h>

// Size of buffers holding /proc paths and readlink() targets.
#define MAX_PATH_LEN 4096
// Size of the buffer for one line read from a /proc text file via fgets().
#define MAX_LINE_LEN 1024

/*
 * Parse a string of hexadecimal digits as an unsigned 16-bit value.
 * Parsing stops at the first non-hex character (0 if there are none);
 * callers pass at most four digits, so the cast cannot lose information.
 */
static uint16_t hex_to_uint16(const char *hex) {
    return (uint16_t)strtoul(hex, NULL, 16);
}

/*
 * Parse a /proc/net/{tcp,udp} address field of the form "AABBCCDD:EEFF".
 *
 * hex_str  - address field; the part before ':' is the IPv4 address as
 *            8 hex digits, the part after is the port as 4 hex digits.
 * ip_out   - receives a dotted-quad string; must hold at least 16 bytes.
 * port_out - receives the port.
 *
 * Both outputs are ALWAYS written: on malformed input (no ':' or fewer
 * than 8 address digits) ip_out is "0.0.0.0" and/or *port_out is 0, so
 * callers never read uninitialized data. NULL arguments are tolerated.
 */
static void parse_ip_port(const char *hex_str, char *ip_out, uint16_t *port_out) {
    // Establish defaults first so every early return leaves defined outputs.
    if (ip_out) {
        snprintf(ip_out, 16, "%s", "0.0.0.0");
    }
    if (port_out) {
        *port_out = 0;
    }
    if (!hex_str || !ip_out || !port_out) {
        return;
    }

    const char *colon = strchr(hex_str, ':');
    if (!colon) {
        return;
    }

    // Port: up to four hex digits following the colon.
    char hex_port[5];
    snprintf(hex_port, sizeof hex_port, "%.4s", colon + 1);
    *port_out = hex_to_uint16(hex_port);

    // Address: needs at least 8 hex digits before the colon; only the
    // first 8 are used. The digit pairs are emitted in reverse order
    // (AABBCCDD -> DD.CC.BB.AA).
    // NOTE(review): this reversal assumes the kernel printed the address
    // word in little-endian host order - confirm if big-endian hosts matter.
    if ((size_t)(colon - hex_str) < 8) {
        return;
    }

    int octet[4];
    for (int i = 0; i < 4; i++) {
        // Each octet is parsed independently, mirroring one strtoul per pair.
        const char pair[3] = { hex_str[6 - 2 * i], hex_str[7 - 2 * i], '\0' };
        octet[i] = (int)strtoul(pair, NULL, 16);
    }
    snprintf(ip_out, 16, "%d.%d.%d.%d", octet[0], octet[1], octet[2], octet[3]);
}

/* Ensure there is room for one more connection; doubles capacity when full.
 * Returns 0 on success, -1 on failure (the existing array stays valid). */
static int network_reserve_slot(NetworkData *network) {
    if (network->count < network->capacity) {
        return 0;
    }
    size_t new_capacity = network->capacity ? (size_t)network->capacity * 2 : 16;
    if (new_capacity > (size_t)-1 / sizeof *network->connections) {
        return -1;  // size computation would overflow
    }
    // Keep the old pointer until realloc succeeds so nothing leaks on failure.
    void *grown = realloc(network->connections, new_capacity * sizeof *network->connections);
    if (!grown) {
        return -1;
    }
    network->connections = grown;
    network->capacity = new_capacity;
    return 0;
}

/* Free every dest_ip and the connection array, leaving *network empty. */
static void network_release(NetworkData *network) {
    if (network->connections) {
        for (size_t i = 0; i < network->count; i++) {
            free(network->connections[i].dest_ip);
        }
    }
    free(network->connections);
    network->connections = NULL;
    network->count = 0;
    network->capacity = 0;
}

/* Append the entries of one /proc/net table (tcp or udp) to *network.
 * A table that cannot be opened is skipped (returns 0); returns -1 only
 * on allocation failure. */
static int read_proc_net_table(const char *path, char protocol, NetworkData *network) {
    FILE *file = fopen(path, "r");
    if (!file) {
        return 0;
    }

    int rc = 0;
    char line[MAX_LINE_LEN];
    // The first line is a column header.
    if (fgets(line, sizeof(line), file)) {
        while (fgets(line, sizeof(line), file)) {
            char local_addr[32], remote_addr[32];

            // Format: sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt ...
            int scanned = sscanf(line, "%*d: %31s %31s %*x %*x:%*x %*x:%*x %*x %*x %*x %*d %*d %*d",
                                 local_addr, remote_addr);
            if (scanned < 2) {
                continue;
            }

            if (network_reserve_slot(network) != 0) {
                rc = -1;
                break;
            }

            // 16 bytes fits the longest dotted quad "255.255.255.255" + NUL.
            // Zeroed so it is a valid string even if parsing leaves it untouched.
            char *dest_ip = calloc(1, 16);
            if (!dest_ip) {
                rc = -1;
                break;
            }

            char local_ip[16];
            uint16_t local_port = 0, remote_port = 0;
            parse_ip_port(local_addr, local_ip, &local_port);
            parse_ip_port(remote_addr, dest_ip, &remote_port);

            NetworkConnection *conn = &network->connections[network->count];
            conn->dest_ip = dest_ip;
            conn->local_port = local_port;
            conn->dest_port = remote_port;
            conn->protocol = protocol;
            network->count++;
        }
    }
    fclose(file);
    return rc;
}

/*
 * Collect IPv4 TCP and UDP socket entries from /proc/<pid>/net/tcp and
 * /proc/<pid>/net/udp into *network (protocol 'T' or 'U').
 *
 * NOTE(review): per proc(5), /proc/<pid>/net reflects the process's whole
 * network namespace, so these entries are not necessarily sockets owned
 * by this process.
 *
 * Any previous contents of *network are overwritten (not freed). On
 * success returns 0 and *network owns `connections` plus each entry's
 * `dest_ip`. Returns -1 if network is NULL or an allocation fails; in the
 * allocation case *network is left empty (NULL array, count 0) with nothing
 * leaked. Tables that cannot be opened are silently skipped.
 */
int behavior_collector_get_network(pid_t pid, NetworkData *network) {
    if (!network) {
        return -1;
    }

    network->count = 0;
    network->capacity = 16;
    network->connections = calloc(network->capacity, sizeof *network->connections);
    if (!network->connections) {
        network->capacity = 0;
        return -1;
    }

    static const struct {
        const char *table;
        char protocol;
    } tables[] = {
        { "tcp", 'T' },
        { "udp", 'U' },
    };

    for (size_t t = 0; t < sizeof tables / sizeof tables[0]; t++) {
        char path[MAX_PATH_LEN];
        snprintf(path, sizeof(path), "/proc/%d/net/%s", (int)pid, tables[t].table);
        if (read_proc_net_table(path, tables[t].protocol, network) != 0) {
            network_release(network);
            return -1;
        }
    }

    return 0;
}

/*
 * Collect the targets of the process's open file descriptors by resolving
 * each /proc/<pid>/fd/<n> symlink with readlink().
 *
 * Only targets that begin with '/' are kept. Non-filesystem descriptors
 * (sockets, pipes, anonymous inodes) resolve to names like "socket:[...]"
 * and are skipped. Absolute paths such as device nodes are still included.
 *
 * Any previous contents of *files are overwritten (not freed). On success
 * returns 0 and *files owns `files` plus each strdup'd path. Returns -1 if
 * files is NULL, the fd directory cannot be opened, or an allocation fails.
 * In the last two cases *files is left empty (NULL array, count 0) with
 * nothing leaked.
 */
int behavior_collector_get_files(pid_t pid, FileAccessData *files) {
    if (!files) {
        return -1;
    }

    files->count = 0;
    files->capacity = 0;
    files->files = NULL;

    char fd_dir_path[MAX_PATH_LEN];
    snprintf(fd_dir_path, sizeof(fd_dir_path), "/proc/%d/fd", (int)pid);

    DIR *fd_dir = opendir(fd_dir_path);
    if (!fd_dir) {
        return -1;
    }

    files->capacity = 16;
    files->files = calloc(files->capacity, sizeof *files->files);
    if (!files->files) {
        files->capacity = 0;
        closedir(fd_dir);
        return -1;
    }

    struct dirent *entry;
    while ((entry = readdir(fd_dir)) != NULL) {
        // Skip "." and ".."; descriptor entries are numeric.
        if (entry->d_name[0] == '.') {
            continue;
        }

        char fd_path[MAX_PATH_LEN];
        char link_target[MAX_PATH_LEN];
        snprintf(fd_path, sizeof(fd_path), "%s/%s", fd_dir_path, entry->d_name);

        // readlink() does not NUL-terminate; reserve one byte for it.
        ssize_t len = readlink(fd_path, link_target, sizeof(link_target) - 1);
        if (len < 0) {
            continue;  // fd may have been closed since readdir()
        }
        link_target[len] = '\0';

        if (link_target[0] != '/') {
            continue;
        }

        if (files->count >= files->capacity) {
            size_t new_capacity = (size_t)files->capacity * 2;
            if (new_capacity > (size_t)-1 / sizeof *files->files) {
                goto fail;
            }
            // Keep the old pointer until realloc succeeds so nothing leaks.
            void *grown = realloc(files->files, new_capacity * sizeof *files->files);
            if (!grown) {
                goto fail;
            }
            files->files = grown;
            files->capacity = new_capacity;
        }

        char *copy = strdup(link_target);
        if (!copy) {
            goto fail;
        }
        files->files[files->count] = copy;
        files->count++;
    }

    closedir(fd_dir);
    return 0;

fail:
    closedir(fd_dir);
    for (size_t i = 0; i < files->count; i++) {
        free(files->files[i]);
    }
    free(files->files);
    files->files = NULL;
    files->count = 0;
    files->capacity = 0;
    return -1;
}

/*
 * Fill *resources from /proc/<pid>/stat and /proc/<pid>/status.
 *
 * memory_rss / memory_vsize come from the VmRSS / VmSize lines of
 * /proc/<pid>/status (reported in kB, converted to bytes here).
 * cpu_percent is currently always 0.0. utime/stime are parsed from
 * /proc/<pid>/stat, but a real percentage needs a previous sample.
 *
 * All three fields are reset to zero first, so a missing line (e.g. a
 * kernel thread without VmRSS) or an unreadable file yields 0 rather than
 * stale or uninitialized data. Returns -1 only if resources is NULL,
 * otherwise 0 (collection is best-effort).
 */
int behavior_collector_get_resources(pid_t pid, ResourceData *resources) {
    if (!resources) {
        return -1;
    }

    resources->cpu_percent = 0.0;
    resources->memory_rss = 0;
    resources->memory_vsize = 0;

    char path[MAX_PATH_LEN];
    char line[MAX_LINE_LEN];
    FILE *file;

    // CPU ticks from /proc/[pid]/stat: pid (comm) state ppid ... utime stime ...
    snprintf(path, sizeof(path), "/proc/%d/stat", (int)pid);
    file = fopen(path, "r");
    if (file) {
        if (fgets(line, sizeof(line), file)) {
            // comm is parenthesized and may contain spaces or ')', so a %s
            // scan misaligns every later field. Resume after the LAST ')'.
            const char *after_comm = strrchr(line, ')');
            unsigned long utime, stime;
            if (after_comm &&
                sscanf(after_comm + 1, " %*c %*d %*d %*d %*d %*d %*u %*u %*u %*u %*u %lu %lu",
                       &utime, &stime) == 2) {
                // TODO: compute cpu_percent from the delta against a previous
                // sample; cpu_percent stays at its 0.0 placeholder until then.
                (void)utime;
                (void)stime;
            }
        }
        fclose(file);
    }

    // Memory from /proc/[pid]/status; values are in kB.
    snprintf(path, sizeof(path), "/proc/%d/status", (int)pid);
    file = fopen(path, "r");
    if (file) {
        while (fgets(line, sizeof(line), file)) {
            unsigned long kb;
            // Only store a value when the number actually parsed.
            if (sscanf(line, "VmRSS: %lu", &kb) == 1) {
                resources->memory_rss = kb * 1024;
            } else if (sscanf(line, "VmSize: %lu", &kb) == 1) {
                resources->memory_vsize = kb * 1024;
            }
        }
        fclose(file);
    }

    return 0;
}

BehaviorSnapshot* behavior_collector_collect(pid_t pid) {
    BehaviorSnapshot *snapshot = calloc(1, sizeof(BehaviorSnapshot));
    if (!snapshot) {
        return NULL;
    }

    snapshot->pid = pid;
    snapshot->timestamp = time(NULL);

    // Initialize network data
    snapshot->network.count = 0;
    snapshot->network.capacity = 0;
    snapshot->network.connections = NULL;

    // Initialize file data
    snapshot->files.count = 0;
    snapshot->files.capacity = 0;
    snapshot->files.files = NULL;

    behavior_collector_get_network(pid, &snapshot->network);
    behavior_collector_get_files(pid, &snapshot->files);
    behavior_collector_get_resources(pid, &snapshot->resources);

    return snapshot;
}

/*
 * Release a snapshot and everything it owns. This covers each connection's
 * dest_ip string, the connection array, each file path string, and the
 * file path array. Passing NULL is a no-op.
 */
void behavior_snapshot_free(BehaviorSnapshot *snapshot) {
    if (snapshot == NULL) {
        return;
    }

    NetworkData *net = &snapshot->network;
    // Per-entry strings can only be walked when the array exists.
    if (net->connections != NULL) {
        for (size_t idx = 0; idx < net->count; ++idx) {
            free(net->connections[idx].dest_ip);
        }
    }
    free(net->connections);  // free(NULL) is a no-op

    FileAccessData *fa = &snapshot->files;
    if (fa->files != NULL) {
        for (size_t idx = 0; idx < fa->count; ++idx) {
            free(fa->files[idx]);
        }
    }
    free(fa->files);

    free(snapshot);
}

