#!/usr/bin/env python3

# Assignment 5
# CMPU 366, Fall 2025

import csv
from pathlib import Path
from typing import Callable, Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DistilBertModel, DistilBertTokenizer

# Seed PyTorch's random number generator so runs are repeatable.
torch.manual_seed(0)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained DistilBERT tokenizer and encoder, loaded once at import time;
# the encoder is moved to `device`.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
bert = DistilBertModel.from_pretrained("distilbert-base-uncased").to(device)


# Training configuration shared by the functions below.
# Upper bound of the epoch loop sketched (commented out) in main().
EPOCHS = 5
# Batch size of the training DataLoader sketched in main().
BATCH_SIZE = 24
# Passed to prep_bert_data() and used as the classifier's input feature
# count in the pipeline sketched in main().
MAX_LENGTH = 10
# Learning rate handed to the SGD optimizer in make_or_restore_model().
LR = 0.001
# Directory where train() writes, and make_or_restore_model() looks for,
# per-epoch checkpoints.
CKPT_DIR = "./ckpt"
# Number of artist labels (matches the size of label_map).
# NOTE(review): not referenced anywhere in the visible code.
NUM_CLASSES = 3


# Artist name -> integer class id.
label_map = {"Beyoncé": 0, "Drake": 1, "Taylor Swift": 2}
# Inverse lookup (class id -> artist name), derived from label_map so the
# two tables cannot drift apart.
label_map_rev = {class_id: artist for artist, class_id in label_map.items()}


class NN(nn.Module):
    """Small feed-forward classifier producing per-class log probabilities.

    Architecture: Linear(n_features -> n_features) -> ReLU ->
    Linear(n_features -> n_features // 2) -> ReLU ->
    Linear(n_features // 2 -> n_classes) -> LogSoftmax.

    The ReLU activations are what make the hidden layers useful: without a
    nonlinearity between them, consecutive Linear layers compose into a
    single affine map. The activations have no parameters, so the
    state_dict keys and shapes are the same as for the purely linear stack
    and existing checkpoints remain loadable.
    """

    def __init__(self, n_features: int, n_classes: int = 3):
        """Construct the pieces of the neural network.

        Args:
            n_features: Size of each input feature vector.
            n_classes: Number of output labels; defaults to the three
                artist labels.
        """
        # Initialize the parent class (nn.Module)
        super().__init__()

        self.layer1 = nn.Linear(n_features, n_features)
        self.layer2 = nn.Linear(n_features, n_features // 2)

        # Reduce to one score per label (artist name)
        self.output_layer = nn.Linear(n_features // 2, n_classes)

        # Log probabilities of each class, normalized over dimension 1
        # (the class dimension of a (batch, n_classes) score tensor)
        self.out = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """The forward pass of the model: Transform the input data x into
        the output predictions (the log probabilities for each label).

        x is expected to be a 2-D tensor whose second dimension has
        n_features entries; the result has one row per input row and one
        column per class, suitable for nn.NLLLoss.
        """
        hidden1 = nn.functional.relu(self.layer1(x))
        hidden2 = nn.functional.relu(self.layer2(hidden1))
        scores = self.output_layer(hidden2)
        return self.out(scores)


####


def make_data(fname: str, label_map: dict) -> Tuple[list[str], list[int]]:
    """Load texts and their integer labels from the data file `fname`.

    NOTE(review): the body is only a placeholder (`...`) in the reviewed
    source, so it currently returns None. The description here is inferred
    from the signature and from the commented-out call in main()
    (`train_data, train_labels = make_data(train_f, label_map)`):
    presumably it returns the texts and, in parallel, their class ids
    looked up through `label_map` -- confirm once implemented.
    """
    ...


def prep_bert_data(
    data: list[str], max_length: int
) -> list[torch.Tensor]:
    """Turn each text in `data` into a feature tensor for the classifier.

    NOTE(review): the body is only a placeholder (`...`) in the reviewed
    source, so it currently returns None. Per the annotation it should
    return one tensor per input text. The commented-out pipeline in main()
    passes MAX_LENGTH both here and as the classifier's feature count, so
    each tensor presumably has `max_length` entries -- TODO confirm once
    implemented.
    """
    ...


####


def get_predicted_label_from_predictions(predictions):
    """Return the index of the highest-scoring class as a Python int.

    `predictions` must be a 2-D tensor with a single row (e.g. one element
    of the list returned by predict()); `.item()` raises for more rows.
    """
    return torch.argmax(predictions, dim=1).item()


def sample_and_print_predictions(feats, data, labels, model):
    """Print the model's predictions for a sample of examples.

    NOTE(review): the body is only a placeholder (`...`) in the reviewed
    source, so it currently does nothing. The commented-out call in main()
    passes the test features, the test texts, the test labels and the
    model, in that order; what is sampled and printed is not visible here
    -- confirm once implemented.
    """
    ...


####


def train(dataloader, model, optimizer, epoch: int):
    """Run an epoch of training the model on the provided data, using the
    specified optimizer, then write a checkpoint for that epoch.

    Args:
        dataloader: Iterable of (inputs, targets) batches.
        model: Module whose output is log probabilities (it is scored with
            nn.NLLLoss).
        optimizer: Optimizer that updates the model's parameters.
        epoch: Epoch number; stored in the checkpoint and used in its file
            name (``ckpt_<epoch>.pt`` inside CKPT_DIR).

    The checkpoint holds the epoch, the model and optimizer state dicts,
    and the loss of the last batch (None if the dataloader was empty).
    """
    loss_fn = nn.NLLLoss()
    model.train()

    # Stays None when the dataloader yields no batches, so the checkpoint
    # write below never refers to an unbound name.
    last_loss = None
    with tqdm(dataloader, unit="batch") as tbatch:
        for X, y in tbatch:
            X = X.to(device)
            y = y.to(device)

            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Keep only the value, detached from the autograd graph.
            last_loss = loss.detach()

    # Create the directory here as well so train() also works when called
    # without main() having prepared it.
    ckpt_dir = Path(CKPT_DIR)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": last_loss,
        },
        ckpt_dir / f"ckpt_{epoch}.pt",
    )


def predict(data, model):
    """Run the model on every item of `data` and collect its outputs.

    Args:
        data: Dataset (or sequence) of input tensors, without labels.
        model: Trained module to evaluate.

    Returns:
        A list with one output tensor per item. Items are fed through a
        DataLoader with batch_size=1, so each output has a leading batch
        dimension of 1.

    The model is switched to evaluation mode first (as test() does) so
    inference does not depend on the mode a caller left it in, and
    gradients are disabled while predicting.
    """
    model.eval()
    dataloader = DataLoader(data, batch_size=1)
    with torch.no_grad():
        return [model(X.to(device)) for X in dataloader]


def test(dataloader, model, dataset_name):
    """Evaluate the model on `dataloader` and print accuracy and mean loss.

    Accuracy is the fraction of correctly classified examples in the
    underlying dataset; the loss is the NLL loss averaged over batches.
    `dataset_name` is only used as the label of the printed line.
    """
    criterion = nn.NLLLoss()
    n_examples = len(dataloader.dataset)
    n_batches = len(dataloader)

    model.eval()
    loss_total = 0
    n_hits = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            log_probs = model(inputs)
            loss_total += criterion(log_probs, targets).item()
            hits = log_probs.argmax(1) == targets
            n_hits += hits.type(torch.float).sum().item()

    test_loss = loss_total / n_batches
    correct = n_hits / n_examples
    print(
        f"{dataset_name} Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>6f}\n"
    )


####


def make_or_restore_model(
    nfeat: int,
) -> Tuple[nn.Module, torch.optim.Optimizer, int]:
    """Either restore the latest model, or create a fresh one.

    Args:
        nfeat: Number of input features for the NN classifier.

    Returns:
        (model, optimizer, start_epoch). When a checkpoint named
        ``ckpt_<epoch>.pt`` exists in CKPT_DIR, the one with the highest
        epoch is loaded into the model and optimizer and start_epoch is
        that epoch + 1; otherwise the model is fresh and start_epoch is 0.
    """
    model = NN(nfeat).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)

    # Map each usable checkpoint to its epoch number. Files that do not
    # follow the ckpt_<epoch>.pt naming are skipped instead of crashing
    # the lookup.
    epoch_by_path = {}
    for path in Path(CKPT_DIR).glob("*.pt"):
        try:
            epoch_by_path[path] = int(path.stem.split("_")[1])
        except (IndexError, ValueError):
            continue

    if not epoch_by_path:
        print("Creating a new model")
        return model, optimizer, 0

    latest_checkpoint = max(epoch_by_path, key=epoch_by_path.get)
    print("Restoring from", latest_checkpoint)
    # map_location lets a checkpoint written on a GPU machine be restored
    # on a CPU-only one (and vice versa).
    ckpt = torch.load(latest_checkpoint, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return model, optimizer, ckpt["epoch"] + 1


####


def main():
    """Run the song classification.

    Only the checkpoint-directory setup and the data file names are active
    at the moment; the rest of the pipeline (load data, build features,
    train, evaluate, print predictions) is kept below as commented-out
    code, in the order it is meant to run.
    """

    # Make sure the checkpoint directory exists before anything is saved.
    Path(CKPT_DIR).mkdir(exist_ok=True)

    # Data file names.
    # NOTE(review): unused while the pipeline below stays commented out.
    train_f = "train.csv"
    test_f = "test.csv"

    # Step 1: load texts and labels for both splits.
    # train_data, train_labels = make_data(train_f, label_map)
    # test_data, test_labels = make_data(test_f, label_map)

    # Step 2: report how many training examples each class has.
    # for i in label_map_rev:
    #     print(f"Lyrics in Class {i} ({label_map_rev[i] + '):':14}",
    #           len([t for t in train_labels if t == i]))

    # print()

    # Step 3: convert the texts into feature tensors.
    # train_feats = prep_bert_data(train_data, MAX_LENGTH)
    # test_feats = prep_bert_data(test_data, MAX_LENGTH)

    # Step 4: pair features with labels and wrap them in DataLoaders.
    # train_dataset = list(zip(train_feats, train_labels))
    # test_dataset = list(zip(test_feats, test_labels))

    # train_dataloader = DataLoader(
    #     train_dataset, batch_size=BATCH_SIZE, shuffle=True
    # )
    # test_dataloader = DataLoader(test_dataset, batch_size=1)

    # Step 5: create the model, or resume from the newest checkpoint.
    # model, optimizer, epoch_start = make_or_restore_model(MAX_LENGTH)

    # Step 6: train for the remaining epochs, evaluating after each one.
    # for e in range(epoch_start, EPOCHS):
    #     print()
    #     print("Epoch", e)
    #     print("-------")
    #
    #     model.train()
    #     train(train_dataloader, model, optimizer, e)
    #
    #     print()
    #
    #     model.eval()
    #     test(train_dataloader, model, "Train")
    #     test(test_dataloader, model, "Test")
    #
    # Step 7: per-class performance on the test set.
    # NOTE(review): print_performance_by_class is not defined in the
    # visible part of this file -- confirm it exists before enabling.
    # test_predictions = predict(test_feats, model)
    # print_performance_by_class(test_labels, test_predictions)
    # print()

    # Step 8: show a sample of individual predictions.
    # sample_and_print_predictions(test_feats, test_data, test_labels,model)


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
