package main

import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"sort"
	"strconv"
	"strings"

	"golang.org/x/sync/errgroup"
)

// Types used to decode `lscpu -e -J` output and to carry per-core ranking data.
type (
	// CPUInfo is one element of the "cpus" array emitted by `lscpu -e -J`.
	// Only the CPU number and its maximum frequency (in MHz) are decoded.
	CPUInfo struct {
		CPU    int     `json:"cpu"`
		MaxMHz float64 `json:"maxmhz"`
	}

	// LSCPUOutput is the top-level JSON object produced by `lscpu -e -J`.
	LSCPUOutput struct {
		CPUs []CPUInfo `json:"cpus"`
	}

	// CoreFreq pairs a CPU number with a numeric score used to rank cores;
	// larger Freq values are treated as better when cores are selected.
	CoreFreq struct {
		Core, Freq int
	}
)

func main() {
	args := os.Args[1:]
	if len(args) == 0 {
		fmt.Println("Please provide a command to run")
		os.Exit(1)
	}

	eg, ctx := errgroup.WithContext(context.Background())
	canPerf, cores, err := getBestCores()
	if err != nil {
		eg.Go(func() error {
			return runCommand(ctx, false, []int{}, args)
		})
		err = eg.Wait()
		if err != nil {
			fmt.Println(err)
			os.Exit(1)
		}
		os.Exit(0)
	}

	eg.Go(func() error {
		return runCommand(ctx, canPerf, cores, args)
	})
	err = eg.Wait()
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	os.Exit(0)
}

// runCommand runs args as a child process and waits for it to finish.
//
// When canPerf is true the command is wrapped in
// `powerprofilesctl launch -p performance`; when cores is non-empty it is
// additionally wrapped in `taskset -c <cores>`. The child shares this
// process's stdin, stdout and stderr, and is killed if ctx is cancelled.
//
// The child inherits the current environment plus POSIXLY_CORRECT=1.
// NOTE(review): that variable presumably exists to keep the wrappers' option
// parsers from consuming the wrapped command's flags, but it is also visible
// to the wrapped command itself — confirm this is intended.
//
// It returns an error if args is empty, and otherwise whatever (*exec.Cmd).Run
// returns (an *exec.ExitError for a non-zero exit status).
func runCommand(ctx context.Context, canPerf bool, cores []int, args []string) error {
	if len(args) == 0 {
		// Without this guard an empty args with no wrappers would panic on
		// commandToRun[0], and with wrappers would run them with no target.
		return fmt.Errorf("no command to run")
	}

	// Up to 7 wrapper tokens: 4 for powerprofilesctl, 3 for taskset.
	commandToRun := make([]string, 0, len(args)+7)
	if canPerf {
		commandToRun = append(commandToRun, "powerprofilesctl", "launch", "-p", "performance")
	}

	if len(cores) > 0 {
		coreVals := make([]string, len(cores))
		for i, core := range cores {
			coreVals[i] = strconv.Itoa(core)
		}
		commandToRun = append(commandToRun, "taskset", "-c", strings.Join(coreVals, ","))
	}

	commandToRun = append(commandToRun, args...)

	cmd := exec.CommandContext(ctx, commandToRun[0], commandToRun[1:]...)
	cmd.Env = append(cmd.Environ(), "POSIXLY_CORRECT=1")
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Stdin = os.Stdin

	return cmd.Run()
}

// getBestCores decides how the wrapped command should be launched. It returns
// whether powerprofilesctl reports a "performance" profile, and the CPU
// numbers the command should be pinned to.
//
// If AMD P-State preferred-core ranking files exist, the highest-ranked cores
// are selected:
//   - half of them when exactly 32, 24 or 16 rankings are present;
//   - otherwise all of them.
//
// NOTE(review): the halving presumably targets the better-ranked half of the
// package on those thread counts — confirm against the intended hardware.
//
// Without AMD rankings, the CPUs with the highest and second-highest lscpu max
// frequency are used.
//
// Probe failures are returned rather than terminating the process, so the
// caller can fall back to running the command without pinning.
func getBestCores() (bool, []int, error) {
	perfMode, amdPrefCores, amdFlag, highestCores, err := GetPPStatusAndHighestCores()
	if err != nil {
		return false, nil, fmt.Errorf("probing CPU cores: %w", err)
	}

	if !amdFlag {
		return perfMode, highestCores, nil
	}

	n := len(amdPrefCores)
	switch n {
	case 32, 24, 16:
		n /= 2
	}
	return perfMode, getHighestCores(n, amdPrefCores), nil
}

// getHighestCores returns the Core values of the n entries in coreFreqs with
// the largest Freq, sorted in ascending core order.
//
// Entries with equal Freq are ranked by lower Core first, so the selection is
// deterministic. If n exceeds len(coreFreqs), every core is returned. If n <= 0
// or coreFreqs is empty, the result is an empty, non-nil slice. coreFreqs itself
// is not modified.
func getHighestCores(n int, coreFreqs []CoreFreq) []int {
	if n <= 0 {
		return []int{}
	}
	if n > len(coreFreqs) {
		n = len(coreFreqs)
	}

	// Rank a copy so the caller's slice is not reordered as a side effect.
	ranked := make([]CoreFreq, len(coreFreqs))
	copy(ranked, coreFreqs)
	sort.Slice(ranked, func(i, j int) bool {
		if ranked[i].Freq != ranked[j].Freq {
			return ranked[i].Freq > ranked[j].Freq
		}
		// sort.Slice is not stable; break ties explicitly for reproducibility.
		return ranked[i].Core < ranked[j].Core
	})

	result := make([]int, n)
	for i := range result {
		result[i] = ranked[i].Core
	}
	sort.Ints(result)

	return result
}

// GetPPStatusAndHighestCores probes the machine concurrently and returns, in
// order:
//   - whether `powerprofilesctl list` output mentions "performance" (false if
//     the command cannot be run — this probe never fails the call);
//   - one CoreFreq per AMD P-State amd_pstate_prefcore_ranking file found in
//     sysfs, with Core set to the CPU number and Freq to the ranking value;
//   - whether any such ranking file exists;
//   - the CPUs whose lscpu max frequency equals the highest or second-highest
//     max frequency reported, in ascending order;
//   - an error if the ranking files cannot be read or parsed, or if lscpu fails
//     or emits unparseable JSON. In that case all other results are zero values.
func GetPPStatusAndHighestCores() (bool, []CoreFreq, bool, []int, error) {
	var (
		performanceMode bool
		amdPstateFreqs  []CoreFreq
		amdFlag         bool
		highestCores    []int
	)

	// Each goroutine writes only its own result variables; g.Wait() orders
	// those writes before the reads below.
	g, ctx := errgroup.WithContext(context.Background())
	g.Go(func() error {
		performanceMode = hasPerformanceProfile(ctx)
		return nil
	})
	g.Go(func() error {
		var err error
		amdPstateFreqs, amdFlag, err = readAMDPrefcoreRankings()
		return err
	})
	g.Go(func() error {
		var err error
		highestCores, err = lscpuFastestCPUs(ctx)
		return err
	})

	if err := g.Wait(); err != nil {
		return false, nil, false, nil, err
	}

	return performanceMode, amdPstateFreqs, amdFlag, highestCores, nil
}

// hasPerformanceProfile reports whether `powerprofilesctl list` output contains
// "performance". Any failure to run the command is treated as "no".
func hasPerformanceProfile(ctx context.Context) bool {
	output, err := exec.CommandContext(ctx, "powerprofilesctl", "list").Output()
	if err != nil {
		return false
	}
	return strings.Contains(string(output), "performance")
}

// readAMDPrefcoreRankings reads every per-CPU amd_pstate_prefcore_ranking file.
// It also reports whether any such file exists, even when none yields a value.
func readAMDPrefcoreRankings() ([]CoreFreq, bool, error) {
	const pattern = "/sys/devices/system/cpu/cpu*/cpufreq/amd_pstate_prefcore_ranking"
	matches, err := filepath.Glob(pattern)
	if err != nil {
		return nil, false, fmt.Errorf("failed to glob AMD pstate files: %w", err)
	}

	var rankings []CoreFreq
	for _, match := range matches {
		cf, ok, err := readRankingFile(match)
		if err != nil {
			return nil, false, err
		}
		if ok {
			rankings = append(rankings, cf)
		}
	}
	return rankings, len(matches) > 0, nil
}

// readRankingFile parses one ranking file. The CPU number is taken from the
// "cpuN" directory two levels above the file, and the ranking from the file's
// first line. ok is false when the file is empty.
//
// It is a separate function so the deferred Close runs per file rather than
// once at the end of the whole loop.
func readRankingFile(path string) (cf CoreFreq, ok bool, err error) {
	file, err := os.Open(path)
	if err != nil {
		return CoreFreq{}, false, fmt.Errorf("failed to open file %s: %w", path, err)
	}
	defer file.Close()

	cpuDir := filepath.Base(filepath.Dir(filepath.Dir(path)))
	coreNum, err := strconv.Atoi(strings.TrimPrefix(cpuDir, "cpu"))
	if err != nil {
		return CoreFreq{}, false, fmt.Errorf("failed to parse core number from path %s: %w", path, err)
	}

	scanner := bufio.NewScanner(file)
	if scanner.Scan() {
		value, err := strconv.Atoi(strings.TrimSpace(scanner.Text()))
		if err != nil {
			return CoreFreq{}, false, fmt.Errorf("failed to parse value from file %s: %w", path, err)
		}
		return CoreFreq{Core: coreNum, Freq: value}, true, nil
	}
	if err := scanner.Err(); err != nil {
		return CoreFreq{}, false, fmt.Errorf("error reading file %s: %w", path, err)
	}
	return CoreFreq{}, false, nil
}

// lscpuFastestCPUs runs `lscpu -e -J`. It returns, in ascending order, the CPUs
// whose max frequency equals the highest or second-highest distinct max
// frequency reported. If lscpu lists no CPUs, it returns nil.
func lscpuFastestCPUs(ctx context.Context) ([]int, error) {
	cmd := exec.CommandContext(ctx, "lscpu", "-e", "-J")
	// Keep the inherited environment (appending to the nil cmd.Env would
	// drop it entirely). Pin the locale with LC_ALL, which overrides any
	// user-set LC_*/LANG. This is presumably so lscpu's numeric output stays
	// locale-independent and parseable as JSON — confirm.
	cmd.Env = append(cmd.Environ(), "LC_ALL=C")
	output, err := cmd.Output()
	if err != nil {
		return nil, fmt.Errorf("failed to run lscpu command: %w", err)
	}

	var lscpuOutput LSCPUOutput
	if err := json.Unmarshal(output, &lscpuOutput); err != nil {
		return nil, fmt.Errorf("failed to unmarshal lscpu output: %w", err)
	}

	cpus := lscpuOutput.CPUs
	if len(cpus) == 0 {
		return nil, nil
	}

	sort.Slice(cpus, func(i, j int) bool {
		return cpus[i].MaxMHz > cpus[j].MaxMHz
	})

	topMHz := cpus[0].MaxMHz
	secondMHz := topMHz // stays equal to topMHz if every CPU shares one max
	for _, cpu := range cpus {
		if cpu.MaxMHz < topMHz {
			secondMHz = cpu.MaxMHz
			break
		}
	}

	var fastest []int
	for _, cpu := range cpus {
		if cpu.MaxMHz == topMHz || cpu.MaxMHz == secondMHz {
			fastest = append(fastest, cpu.CPU)
		}
	}
	// The MHz sort above is unstable; order the result for determinism.
	sort.Ints(fastest)

	return fastest, nil
}