From 641b79ff59a836cf93bce532ffd9720090ce4c32 Mon Sep 17 00:00:00 2001 From: ferreo Date: Thu, 15 Aug 2024 19:21:16 +0100 Subject: [PATCH] Update patches and config --- config | 14 +- patches/0001-cachyos-base-all.patch | 54703 ++++++++++++++++++++++++++ patches/0002-sched-ext.patch | 15304 +++++++ patches/0003-bore-cachy-ext.patch | 990 + patches/series | 3 + 5 files changed, 71007 insertions(+), 7 deletions(-) create mode 100644 patches/0001-cachyos-base-all.patch create mode 100644 patches/0002-sched-ext.patch create mode 100644 patches/0003-bore-cachy-ext.patch create mode 100644 patches/series diff --git a/config b/config index 7780f29..a66f036 100644 --- a/config +++ b/config @@ -1,15 +1,15 @@ # # Automatically generated file; DO NOT EDIT. -# Linux/x86 6.10.0 Kernel Configuration +# Linux/x86 6.10.5 Kernel Configuration # -CONFIG_CC_VERSION_TEXT="gcc (GCC) 14.1.1 20240522" +CONFIG_CC_VERSION_TEXT="gcc (GCC) 14.2.1 20240805" CONFIG_CC_IS_GCC=y -CONFIG_GCC_VERSION=140101 +CONFIG_GCC_VERSION=140201 CONFIG_CLANG_VERSION=0 CONFIG_AS_IS_GNU=y -CONFIG_AS_VERSION=24200 +CONFIG_AS_VERSION=24300 CONFIG_LD_IS_BFD=y -CONFIG_LD_VERSION=24200 +CONFIG_LD_VERSION=24300 CONFIG_LLD_VERSION=0 CONFIG_RUST_IS_AVAILABLE=y CONFIG_CC_CAN_LINK=y @@ -130,6 +130,7 @@ CONFIG_BPF_JIT_DEFAULT_ON=y CONFIG_BPF_UNPRIV_DEFAULT_OFF=y # CONFIG_BPF_PRELOAD is not set CONFIG_BPF_LSM=y +CONFIG_DEBUG_INFO_BTF=y # end of BPF subsystem CONFIG_PREEMPT_BUILD=y @@ -140,7 +141,7 @@ CONFIG_PREEMPT_COUNT=y CONFIG_PREEMPTION=y CONFIG_PREEMPT_DYNAMIC=y CONFIG_SCHED_CORE=y -# CONFIG_SCHED_CLASS_EXT is not set +CONFIG_SCHED_CLASS_EXT=y # # CPU/Task time and stats accounting @@ -230,7 +231,6 @@ CONFIG_CGROUP_SCHED=y CONFIG_FAIR_GROUP_SCHED=y CONFIG_CFS_BANDWIDTH=y # CONFIG_RT_GROUP_SCHED is not set -# CONFIG_EXT_GROUP_SCHED is not set CONFIG_SCHED_MM_CID=y CONFIG_UCLAMP_TASK_GROUP=y CONFIG_CGROUP_PIDS=y diff --git a/patches/0001-cachyos-base-all.patch b/patches/0001-cachyos-base-all.patch new file mode 100644 
index 0000000..f8fb7aa --- /dev/null +++ b/patches/0001-cachyos-base-all.patch @@ -0,0 +1,54703 @@ +From d6a6d104d46aedec0a853aaeac60ab4bcce7c9d4 Mon Sep 17 00:00:00 2001 +From: Peter Jung +Date: Wed, 14 Aug 2024 16:18:27 +0200 +Subject: [PATCH 01/12] amd-pstate + +Signed-off-by: Peter Jung +--- + Documentation/admin-guide/pm/amd-pstate.rst | 18 +- + arch/x86/include/asm/cpufeatures.h | 1 + + arch/x86/include/asm/msr-index.h | 2 + + arch/x86/kernel/cpu/scattered.c | 1 + + drivers/cpufreq/Kconfig.x86 | 1 + + drivers/cpufreq/acpi-cpufreq.c | 3 +- + drivers/cpufreq/amd-pstate.c | 307 +++++++++++++++----- + drivers/cpufreq/amd-pstate.h | 1 + + drivers/cpufreq/cpufreq.c | 11 +- + 9 files changed, 260 insertions(+), 85 deletions(-) + +diff --git a/Documentation/admin-guide/pm/amd-pstate.rst b/Documentation/admin-guide/pm/amd-pstate.rst +index 1e0d101b020a..d0324d44f548 100644 +--- a/Documentation/admin-guide/pm/amd-pstate.rst ++++ b/Documentation/admin-guide/pm/amd-pstate.rst +@@ -281,6 +281,22 @@ integer values defined between 0 to 255 when EPP feature is enabled by platform + firmware, if EPP feature is disabled, driver will ignore the written value + This attribute is read-write. + ++``boost`` ++The `boost` sysfs attribute provides control over the CPU core ++performance boost, allowing users to manage the maximum frequency limitation ++of the CPU. This attribute can be used to enable or disable the boost feature ++on individual CPUs. ++ ++When the boost feature is enabled, the CPU can dynamically increase its frequency ++beyond the base frequency, providing enhanced performance for demanding workloads. ++On the other hand, disabling the boost feature restricts the CPU to operate at the ++base frequency, which may be desirable in certain scenarios to prioritize power ++efficiency or manage temperature. 
++ ++To manipulate the `boost` attribute, users can write a value of `0` to disable the ++boost or `1` to enable it, for the respective CPU using the sysfs path ++`/sys/devices/system/cpu/cpuX/cpufreq/boost`, where `X` represents the CPU number. ++ + Other performance and frequency values can be read back from + ``/sys/devices/system/cpu/cpuX/acpi_cppc/``, see :ref:`cppc_sysfs`. + +@@ -406,7 +422,7 @@ control its functionality at the system level. They are located in the + ``/sys/devices/system/cpu/amd_pstate/`` directory and affect all CPUs. + + ``status`` +- Operation mode of the driver: "active", "passive" or "disable". ++ Operation mode of the driver: "active", "passive", "guided" or "disable". + + "active" + The driver is functional and in the ``active mode`` +diff --git a/arch/x86/include/asm/cpufeatures.h b/arch/x86/include/asm/cpufeatures.h +index 3c7434329661..6c128d463a14 100644 +--- a/arch/x86/include/asm/cpufeatures.h ++++ b/arch/x86/include/asm/cpufeatures.h +@@ -470,6 +470,7 @@ + #define X86_FEATURE_BHI_CTRL (21*32+ 2) /* "" BHI_DIS_S HW control available */ + #define X86_FEATURE_CLEAR_BHB_HW (21*32+ 3) /* "" BHI_DIS_S HW control enabled */ + #define X86_FEATURE_CLEAR_BHB_LOOP_ON_VMEXIT (21*32+ 4) /* "" Clear branch history at vmexit using SW loop */ ++#define X86_FEATURE_FAST_CPPC (21*32 + 5) /* "" AMD Fast CPPC */ + + /* + * BUG word(s) +diff --git a/arch/x86/include/asm/msr-index.h b/arch/x86/include/asm/msr-index.h +index e022e6eb766c..384739d592af 100644 +--- a/arch/x86/include/asm/msr-index.h ++++ b/arch/x86/include/asm/msr-index.h +@@ -781,6 +781,8 @@ + #define MSR_K7_HWCR_IRPERF_EN BIT_ULL(MSR_K7_HWCR_IRPERF_EN_BIT) + #define MSR_K7_FID_VID_CTL 0xc0010041 + #define MSR_K7_FID_VID_STATUS 0xc0010042 ++#define MSR_K7_HWCR_CPB_DIS_BIT 25 ++#define MSR_K7_HWCR_CPB_DIS BIT_ULL(MSR_K7_HWCR_CPB_DIS_BIT) + + /* K6 MSRs */ + #define MSR_K6_WHCR 0xc0000082 +diff --git a/arch/x86/kernel/cpu/scattered.c b/arch/x86/kernel/cpu/scattered.c +index 
af5aa2c754c2..c84c30188fdf 100644 +--- a/arch/x86/kernel/cpu/scattered.c ++++ b/arch/x86/kernel/cpu/scattered.c +@@ -45,6 +45,7 @@ static const struct cpuid_bit cpuid_bits[] = { + { X86_FEATURE_HW_PSTATE, CPUID_EDX, 7, 0x80000007, 0 }, + { X86_FEATURE_CPB, CPUID_EDX, 9, 0x80000007, 0 }, + { X86_FEATURE_PROC_FEEDBACK, CPUID_EDX, 11, 0x80000007, 0 }, ++ { X86_FEATURE_FAST_CPPC, CPUID_EDX, 15, 0x80000007, 0 }, + { X86_FEATURE_MBA, CPUID_EBX, 6, 0x80000008, 0 }, + { X86_FEATURE_SMBA, CPUID_EBX, 2, 0x80000020, 0 }, + { X86_FEATURE_BMEC, CPUID_EBX, 3, 0x80000020, 0 }, +diff --git a/drivers/cpufreq/Kconfig.x86 b/drivers/cpufreq/Kconfig.x86 +index 438c9e75a04d..97c2d4f15d76 100644 +--- a/drivers/cpufreq/Kconfig.x86 ++++ b/drivers/cpufreq/Kconfig.x86 +@@ -71,6 +71,7 @@ config X86_AMD_PSTATE_DEFAULT_MODE + config X86_AMD_PSTATE_UT + tristate "selftest for AMD Processor P-State driver" + depends on X86 && ACPI_PROCESSOR ++ depends on X86_AMD_PSTATE + default n + help + This kernel module is used for testing. It's safe to say M here. 
+diff --git a/drivers/cpufreq/acpi-cpufreq.c b/drivers/cpufreq/acpi-cpufreq.c +index 4ac3a35dcd98..f4f8587c4ea0 100644 +--- a/drivers/cpufreq/acpi-cpufreq.c ++++ b/drivers/cpufreq/acpi-cpufreq.c +@@ -50,8 +50,6 @@ enum { + #define AMD_MSR_RANGE (0x7) + #define HYGON_MSR_RANGE (0x7) + +-#define MSR_K7_HWCR_CPB_DIS (1ULL << 25) +- + struct acpi_cpufreq_data { + unsigned int resume; + unsigned int cpu_feature; +@@ -139,6 +137,7 @@ static int set_boost(struct cpufreq_policy *policy, int val) + (void *)(long)val, 1); + pr_debug("CPU %*pbl: Core Boosting %s.\n", + cpumask_pr_args(policy->cpus), str_enabled_disabled(val)); ++ policy->boost_enabled = val; + + return 0; + } +diff --git a/drivers/cpufreq/amd-pstate.c b/drivers/cpufreq/amd-pstate.c +index 67c4a6a0ef12..d7c4a7d6d993 100644 +--- a/drivers/cpufreq/amd-pstate.c ++++ b/drivers/cpufreq/amd-pstate.c +@@ -51,6 +51,7 @@ + + #define AMD_PSTATE_TRANSITION_LATENCY 20000 + #define AMD_PSTATE_TRANSITION_DELAY 1000 ++#define AMD_PSTATE_FAST_CPPC_TRANSITION_DELAY 600 + #define CPPC_HIGHEST_PERF_PERFORMANCE 196 + #define CPPC_HIGHEST_PERF_DEFAULT 166 + +@@ -85,15 +86,6 @@ struct quirk_entry { + u32 lowest_freq; + }; + +-/* +- * TODO: We need more time to fine tune processors with shared memory solution +- * with community together. +- * +- * There are some performance drops on the CPU benchmarks which reports from +- * Suse. We are co-working with them to fine tune the shared memory solution. So +- * we disable it by default to go acpi-cpufreq on these processors and add a +- * module parameter to be able to enable it manually for debugging. 
+- */ + static struct cpufreq_driver *current_pstate_driver; + static struct cpufreq_driver amd_pstate_driver; + static struct cpufreq_driver amd_pstate_epp_driver; +@@ -150,6 +142,11 @@ static struct quirk_entry quirk_amd_7k62 = { + .lowest_freq = 550, + }; + ++static struct quirk_entry quirk_amd_mts = { ++ .nominal_freq = 3600, ++ .lowest_freq = 550, ++}; ++ + static int __init dmi_matched_7k62_bios_bug(const struct dmi_system_id *dmi) + { + /** +@@ -157,7 +154,7 @@ static int __init dmi_matched_7k62_bios_bug(const struct dmi_system_id *dmi) + * broken BIOS lack of nominal_freq and lowest_freq capabilities + * definition in ACPI tables + */ +- if (boot_cpu_has(X86_FEATURE_ZEN2)) { ++ if (cpu_feature_enabled(X86_FEATURE_ZEN2)) { + quirks = dmi->driver_data; + pr_info("Overriding nominal and lowest frequencies for %s\n", dmi->ident); + return 1; +@@ -166,6 +163,21 @@ static int __init dmi_matched_7k62_bios_bug(const struct dmi_system_id *dmi) + return 0; + } + ++static int __init dmi_matched_mts_bios_bug(const struct dmi_system_id *dmi) ++{ ++ /** ++ * match the broken bios for ryzen 3000 series processor support CPPC V2 ++ * broken BIOS lack of nominal_freq and lowest_freq capabilities ++ * definition in ACPI tables ++ */ ++ if (cpu_feature_enabled(X86_FEATURE_ZEN2)) { ++ quirks = dmi->driver_data; ++ pr_info("Overriding nominal and lowest frequencies for %s\n", dmi->ident); ++ return 1; ++ } ++ ++ return 0; ++} + static const struct dmi_system_id amd_pstate_quirks_table[] __initconst = { + { + .callback = dmi_matched_7k62_bios_bug, +@@ -176,6 +188,16 @@ static const struct dmi_system_id amd_pstate_quirks_table[] __initconst = { + }, + .driver_data = &quirk_amd_7k62, + }, ++ { ++ .callback = dmi_matched_mts_bios_bug, ++ .ident = "AMD Ryzen 3000", ++ .matches = { ++ DMI_MATCH(DMI_PRODUCT_NAME, "B450M MORTAR MAX (MS-7B89)"), ++ DMI_MATCH(DMI_BIOS_RELEASE, "06/10/2020"), ++ DMI_MATCH(DMI_BIOS_VERSION, "5.14"), ++ }, ++ .driver_data = &quirk_amd_mts, ++ }, + {} + }; + 
MODULE_DEVICE_TABLE(dmi, amd_pstate_quirks_table); +@@ -199,7 +221,7 @@ static s16 amd_pstate_get_epp(struct amd_cpudata *cpudata, u64 cppc_req_cached) + u64 epp; + int ret; + +- if (boot_cpu_has(X86_FEATURE_CPPC)) { ++ if (cpu_feature_enabled(X86_FEATURE_CPPC)) { + if (!cppc_req_cached) { + epp = rdmsrl_on_cpu(cpudata->cpu, MSR_AMD_CPPC_REQ, + &cppc_req_cached); +@@ -272,7 +294,7 @@ static int amd_pstate_set_epp(struct amd_cpudata *cpudata, u32 epp) + int ret; + struct cppc_perf_ctrls perf_ctrls; + +- if (boot_cpu_has(X86_FEATURE_CPPC)) { ++ if (cpu_feature_enabled(X86_FEATURE_CPPC)) { + u64 value = READ_ONCE(cpudata->cppc_req_cached); + + value &= ~GENMASK_ULL(31, 24); +@@ -329,7 +351,7 @@ static inline int pstate_enable(bool enable) + return 0; + + for_each_present_cpu(cpu) { +- unsigned long logical_id = topology_logical_die_id(cpu); ++ unsigned long logical_id = topology_logical_package_id(cpu); + + if (test_bit(logical_id, &logical_proc_id_mask)) + continue; +@@ -522,7 +544,10 @@ static inline bool amd_pstate_sample(struct amd_cpudata *cpudata) + static void amd_pstate_update(struct amd_cpudata *cpudata, u32 min_perf, + u32 des_perf, u32 max_perf, bool fast_switch, int gov_flags) + { ++ unsigned long max_freq; ++ struct cpufreq_policy *policy = cpufreq_cpu_get(cpudata->cpu); + u64 prev = READ_ONCE(cpudata->cppc_req_cached); ++ u32 nominal_perf = READ_ONCE(cpudata->nominal_perf); + u64 value = prev; + + min_perf = clamp_t(unsigned long, min_perf, cpudata->min_limit_perf, +@@ -531,6 +556,9 @@ static void amd_pstate_update(struct amd_cpudata *cpudata, u32 min_perf, + cpudata->max_limit_perf); + des_perf = clamp_t(unsigned long, des_perf, min_perf, max_perf); + ++ max_freq = READ_ONCE(cpudata->max_limit_freq); ++ policy->cur = div_u64(des_perf * max_freq, max_perf); ++ + if ((cppc_state == AMD_PSTATE_GUIDED) && (gov_flags & CPUFREQ_GOV_DYNAMIC_SWITCHING)) { + min_perf = des_perf; + des_perf = 0; +@@ -542,6 +570,10 @@ static void amd_pstate_update(struct 
amd_cpudata *cpudata, u32 min_perf, + value &= ~AMD_CPPC_DES_PERF(~0L); + value |= AMD_CPPC_DES_PERF(des_perf); + ++ /* limit the max perf when core performance boost feature is disabled */ ++ if (!cpudata->boost_supported) ++ max_perf = min_t(unsigned long, nominal_perf, max_perf); ++ + value &= ~AMD_CPPC_MAX_PERF(~0L); + value |= AMD_CPPC_MAX_PERF(max_perf); + +@@ -652,10 +684,9 @@ static void amd_pstate_adjust_perf(unsigned int cpu, + unsigned long capacity) + { + unsigned long max_perf, min_perf, des_perf, +- cap_perf, lowest_nonlinear_perf, max_freq; ++ cap_perf, lowest_nonlinear_perf; + struct cpufreq_policy *policy = cpufreq_cpu_get(cpu); + struct amd_cpudata *cpudata = policy->driver_data; +- unsigned int target_freq; + + if (policy->min != cpudata->min_limit_freq || policy->max != cpudata->max_limit_freq) + amd_pstate_update_min_max_limit(policy); +@@ -663,7 +694,6 @@ static void amd_pstate_adjust_perf(unsigned int cpu, + + cap_perf = READ_ONCE(cpudata->highest_perf); + lowest_nonlinear_perf = READ_ONCE(cpudata->lowest_nonlinear_perf); +- max_freq = READ_ONCE(cpudata->max_freq); + + des_perf = cap_perf; + if (target_perf < capacity) +@@ -681,51 +711,111 @@ static void amd_pstate_adjust_perf(unsigned int cpu, + max_perf = min_perf; + + des_perf = clamp_t(unsigned long, des_perf, min_perf, max_perf); +- target_freq = div_u64(des_perf * max_freq, max_perf); +- policy->cur = target_freq; + + amd_pstate_update(cpudata, min_perf, des_perf, max_perf, true, + policy->governor->flags); + cpufreq_cpu_put(policy); + } + +-static int amd_pstate_set_boost(struct cpufreq_policy *policy, int state) ++static int amd_pstate_cpu_boost_update(struct cpufreq_policy *policy, bool on) + { + struct amd_cpudata *cpudata = policy->driver_data; +- int ret; ++ struct cppc_perf_ctrls perf_ctrls; ++ u32 highest_perf, nominal_perf, nominal_freq, max_freq; ++ int ret = 0; + +- if (!cpudata->boost_supported) { +- pr_err("Boost mode is not supported by this processor or SBIOS\n"); +- 
return -EINVAL; ++ highest_perf = READ_ONCE(cpudata->highest_perf); ++ nominal_perf = READ_ONCE(cpudata->nominal_perf); ++ nominal_freq = READ_ONCE(cpudata->nominal_freq); ++ max_freq = READ_ONCE(cpudata->max_freq); ++ ++ if (boot_cpu_has(X86_FEATURE_CPPC)) { ++ u64 value = READ_ONCE(cpudata->cppc_req_cached); ++ ++ value &= ~GENMASK_ULL(7, 0); ++ value |= on ? highest_perf : nominal_perf; ++ WRITE_ONCE(cpudata->cppc_req_cached, value); ++ ++ wrmsrl_on_cpu(cpudata->cpu, MSR_AMD_CPPC_REQ, value); ++ } else { ++ perf_ctrls.max_perf = on ? highest_perf : nominal_perf; ++ ret = cppc_set_perf(cpudata->cpu, &perf_ctrls); ++ if (ret) { ++ cpufreq_cpu_release(policy); ++ pr_debug("Failed to set max perf on CPU:%d. ret:%d\n", ++ cpudata->cpu, ret); ++ return ret; ++ } + } + +- if (state) +- policy->cpuinfo.max_freq = cpudata->max_freq; +- else +- policy->cpuinfo.max_freq = cpudata->nominal_freq * 1000; ++ if (on) ++ policy->cpuinfo.max_freq = max_freq; ++ else if (policy->cpuinfo.max_freq > nominal_freq * 1000) ++ policy->cpuinfo.max_freq = nominal_freq * 1000; + + policy->max = policy->cpuinfo.max_freq; + +- ret = freq_qos_update_request(&cpudata->req[1], +- policy->cpuinfo.max_freq); +- if (ret < 0) +- return ret; ++ if (cppc_state == AMD_PSTATE_PASSIVE) { ++ ret = freq_qos_update_request(&cpudata->req[1], policy->cpuinfo.max_freq); ++ if (ret < 0) ++ pr_debug("Failed to update freq constraint: CPU%d\n", cpudata->cpu); ++ } + +- return 0; ++ return ret < 0 ? 
ret : 0; + } + +-static void amd_pstate_boost_init(struct amd_cpudata *cpudata) ++static int amd_pstate_set_boost(struct cpufreq_policy *policy, int state) + { +- u32 highest_perf, nominal_perf; ++ struct amd_cpudata *cpudata = policy->driver_data; ++ int ret; + +- highest_perf = READ_ONCE(cpudata->highest_perf); +- nominal_perf = READ_ONCE(cpudata->nominal_perf); ++ if (!cpudata->boost_supported) { ++ pr_err("Boost mode is not supported by this processor or SBIOS\n"); ++ return -EOPNOTSUPP; ++ } ++ mutex_lock(&amd_pstate_driver_lock); ++ ret = amd_pstate_cpu_boost_update(policy, state); ++ WRITE_ONCE(cpudata->boost_state, !ret ? state : false); ++ policy->boost_enabled = !ret ? state : false; ++ refresh_frequency_limits(policy); ++ mutex_unlock(&amd_pstate_driver_lock); + +- if (highest_perf <= nominal_perf) +- return; ++ return ret; ++} + +- cpudata->boost_supported = true; ++static int amd_pstate_init_boost_support(struct amd_cpudata *cpudata) ++{ ++ u64 boost_val; ++ int ret = -1; ++ ++ /* ++ * If platform has no CPB support or disable it, initialize current driver ++ * boost_enabled state to be false, it is not an error for cpufreq core to handle. 
++ */ ++ if (!cpu_feature_enabled(X86_FEATURE_CPB)) { ++ pr_debug_once("Boost CPB capabilities not present in the processor\n"); ++ ret = 0; ++ goto exit_err; ++ } ++ ++ /* at least one CPU supports CPB, even if others fail later on to set up */ + current_pstate_driver->boost_enabled = true; ++ ++ ret = rdmsrl_on_cpu(cpudata->cpu, MSR_K7_HWCR, &boost_val); ++ if (ret) { ++ pr_err_once("failed to read initial CPU boost state!\n"); ++ ret = -EIO; ++ goto exit_err; ++ } ++ ++ if (!(boost_val & MSR_K7_HWCR_CPB_DIS)) ++ cpudata->boost_supported = true; ++ ++ return 0; ++ ++exit_err: ++ cpudata->boost_supported = false; ++ return ret; + } + + static void amd_perf_ctl_reset(unsigned int cpu) +@@ -754,7 +844,7 @@ static int amd_pstate_get_highest_perf(int cpu, u32 *highest_perf) + { + int ret; + +- if (boot_cpu_has(X86_FEATURE_CPPC)) { ++ if (cpu_feature_enabled(X86_FEATURE_CPPC)) { + u64 cap1; + + ret = rdmsrl_safe_on_cpu(cpu, MSR_AMD_CPPC_CAP1, &cap1); +@@ -850,8 +940,12 @@ static u32 amd_pstate_get_transition_delay_us(unsigned int cpu) + u32 transition_delay_ns; + + transition_delay_ns = cppc_get_transition_latency(cpu); +- if (transition_delay_ns == CPUFREQ_ETERNAL) +- return AMD_PSTATE_TRANSITION_DELAY; ++ if (transition_delay_ns == CPUFREQ_ETERNAL) { ++ if (cpu_feature_enabled(X86_FEATURE_FAST_CPPC)) ++ return AMD_PSTATE_FAST_CPPC_TRANSITION_DELAY; ++ else ++ return AMD_PSTATE_TRANSITION_DELAY; ++ } + + return transition_delay_ns / NSEC_PER_USEC; + } +@@ -922,12 +1016,30 @@ static int amd_pstate_init_freq(struct amd_cpudata *cpudata) + WRITE_ONCE(cpudata->nominal_freq, nominal_freq); + WRITE_ONCE(cpudata->max_freq, max_freq); + ++ /** ++ * Below values need to be initialized correctly, otherwise driver will fail to load ++ * max_freq is calculated according to (nominal_freq * highest_perf)/nominal_perf ++ * lowest_nonlinear_freq is a value between [min_freq, nominal_freq] ++ * Check _CPC in ACPI table objects if any values are incorrect ++ */ ++ if (min_freq <= 0 || 
max_freq <= 0 || nominal_freq <= 0 || min_freq > max_freq) { ++ pr_err("min_freq(%d) or max_freq(%d) or nominal_freq(%d) value is incorrect\n", ++ min_freq, max_freq, nominal_freq * 1000); ++ return -EINVAL; ++ } ++ ++ if (lowest_nonlinear_freq <= min_freq || lowest_nonlinear_freq > nominal_freq * 1000) { ++ pr_err("lowest_nonlinear_freq(%d) value is out of range [min_freq(%d), nominal_freq(%d)]\n", ++ lowest_nonlinear_freq, min_freq, nominal_freq * 1000); ++ return -EINVAL; ++ } ++ + return 0; + } + + static int amd_pstate_cpu_init(struct cpufreq_policy *policy) + { +- int min_freq, max_freq, nominal_freq, ret; ++ int min_freq, max_freq, ret; + struct device *dev; + struct amd_cpudata *cpudata; + +@@ -956,18 +1068,12 @@ static int amd_pstate_cpu_init(struct cpufreq_policy *policy) + if (ret) + goto free_cpudata1; + ++ ret = amd_pstate_init_boost_support(cpudata); ++ if (ret) ++ goto free_cpudata1; ++ + min_freq = READ_ONCE(cpudata->min_freq); + max_freq = READ_ONCE(cpudata->max_freq); +- nominal_freq = READ_ONCE(cpudata->nominal_freq); +- +- if (min_freq <= 0 || max_freq <= 0 || +- nominal_freq <= 0 || min_freq > max_freq) { +- dev_err(dev, +- "min_freq(%d) or max_freq(%d) or nominal_freq (%d) value is incorrect, check _CPC in ACPI tables\n", +- min_freq, max_freq, nominal_freq); +- ret = -EINVAL; +- goto free_cpudata1; +- } + + policy->cpuinfo.transition_latency = amd_pstate_get_transition_latency(policy->cpu); + policy->transition_delay_us = amd_pstate_get_transition_delay_us(policy->cpu); +@@ -978,10 +1084,12 @@ static int amd_pstate_cpu_init(struct cpufreq_policy *policy) + policy->cpuinfo.min_freq = min_freq; + policy->cpuinfo.max_freq = max_freq; + ++ policy->boost_enabled = READ_ONCE(cpudata->boost_supported); ++ + /* It will be updated by governor */ + policy->cur = policy->cpuinfo.min_freq; + +- if (boot_cpu_has(X86_FEATURE_CPPC)) ++ if (cpu_feature_enabled(X86_FEATURE_CPPC)) + policy->fast_switch_possible = true; + + ret = 
freq_qos_add_request(&policy->constraints, &cpudata->req[0], +@@ -1003,7 +1111,6 @@ static int amd_pstate_cpu_init(struct cpufreq_policy *policy) + + policy->driver_data = cpudata; + +- amd_pstate_boost_init(cpudata); + if (!current_pstate_driver->adjust_perf) + current_pstate_driver->adjust_perf = amd_pstate_adjust_perf; + +@@ -1214,7 +1321,7 @@ static int amd_pstate_change_mode_without_dvr_change(int mode) + + cppc_state = mode; + +- if (boot_cpu_has(X86_FEATURE_CPPC) || cppc_state == AMD_PSTATE_ACTIVE) ++ if (cpu_feature_enabled(X86_FEATURE_CPPC) || cppc_state == AMD_PSTATE_ACTIVE) + return 0; + + for_each_present_cpu(cpu) { +@@ -1387,7 +1494,7 @@ static bool amd_pstate_acpi_pm_profile_undefined(void) + + static int amd_pstate_epp_cpu_init(struct cpufreq_policy *policy) + { +- int min_freq, max_freq, nominal_freq, ret; ++ int min_freq, max_freq, ret; + struct amd_cpudata *cpudata; + struct device *dev; + u64 value; +@@ -1418,17 +1525,12 @@ static int amd_pstate_epp_cpu_init(struct cpufreq_policy *policy) + if (ret) + goto free_cpudata1; + ++ ret = amd_pstate_init_boost_support(cpudata); ++ if (ret) ++ goto free_cpudata1; ++ + min_freq = READ_ONCE(cpudata->min_freq); + max_freq = READ_ONCE(cpudata->max_freq); +- nominal_freq = READ_ONCE(cpudata->nominal_freq); +- if (min_freq <= 0 || max_freq <= 0 || +- nominal_freq <= 0 || min_freq > max_freq) { +- dev_err(dev, +- "min_freq(%d) or max_freq(%d) or nominal_freq(%d) value is incorrect, check _CPC in ACPI tables\n", +- min_freq, max_freq, nominal_freq); +- ret = -EINVAL; +- goto free_cpudata1; +- } + + policy->cpuinfo.min_freq = min_freq; + policy->cpuinfo.max_freq = max_freq; +@@ -1442,6 +1544,8 @@ static int amd_pstate_epp_cpu_init(struct cpufreq_policy *policy) + policy->min = policy->cpuinfo.min_freq; + policy->max = policy->cpuinfo.max_freq; + ++ policy->boost_enabled = READ_ONCE(cpudata->boost_supported); ++ + /* + * Set the policy to provide a valid fallback value in case + * the default cpufreq governor is 
neither powersave nor performance. +@@ -1452,7 +1556,7 @@ static int amd_pstate_epp_cpu_init(struct cpufreq_policy *policy) + else + policy->policy = CPUFREQ_POLICY_POWERSAVE; + +- if (boot_cpu_has(X86_FEATURE_CPPC)) { ++ if (cpu_feature_enabled(X86_FEATURE_CPPC)) { + ret = rdmsrl_on_cpu(cpudata->cpu, MSR_AMD_CPPC_REQ, &value); + if (ret) + return ret; +@@ -1463,7 +1567,6 @@ static int amd_pstate_epp_cpu_init(struct cpufreq_policy *policy) + return ret; + WRITE_ONCE(cpudata->cppc_cap1_cached, value); + } +- amd_pstate_boost_init(cpudata); + + return 0; + +@@ -1542,7 +1645,7 @@ static void amd_pstate_epp_update_limit(struct cpufreq_policy *policy) + epp = 0; + + /* Set initial EPP value */ +- if (boot_cpu_has(X86_FEATURE_CPPC)) { ++ if (cpu_feature_enabled(X86_FEATURE_CPPC)) { + value &= ~GENMASK_ULL(31, 24); + value |= (u64)epp << 24; + } +@@ -1565,6 +1668,12 @@ static int amd_pstate_epp_set_policy(struct cpufreq_policy *policy) + + amd_pstate_epp_update_limit(policy); + ++ /* ++ * policy->cur is never updated with the amd_pstate_epp driver, but it ++ * is used as a stale frequency value. So, keep it within limits. 
++ */ ++ policy->cur = policy->min; ++ + return 0; + } + +@@ -1581,7 +1690,7 @@ static void amd_pstate_epp_reenable(struct amd_cpudata *cpudata) + value = READ_ONCE(cpudata->cppc_req_cached); + max_perf = READ_ONCE(cpudata->highest_perf); + +- if (boot_cpu_has(X86_FEATURE_CPPC)) { ++ if (cpu_feature_enabled(X86_FEATURE_CPPC)) { + wrmsrl_on_cpu(cpudata->cpu, MSR_AMD_CPPC_REQ, value); + } else { + perf_ctrls.max_perf = max_perf; +@@ -1615,7 +1724,7 @@ static void amd_pstate_epp_offline(struct cpufreq_policy *policy) + value = READ_ONCE(cpudata->cppc_req_cached); + + mutex_lock(&amd_pstate_limits_lock); +- if (boot_cpu_has(X86_FEATURE_CPPC)) { ++ if (cpu_feature_enabled(X86_FEATURE_CPPC)) { + cpudata->epp_policy = CPUFREQ_POLICY_UNKNOWN; + + /* Set max perf same as min perf */ +@@ -1719,6 +1828,7 @@ static struct cpufreq_driver amd_pstate_epp_driver = { + .suspend = amd_pstate_epp_suspend, + .resume = amd_pstate_epp_resume, + .update_limits = amd_pstate_update_limits, ++ .set_boost = amd_pstate_set_boost, + .name = "amd-pstate-epp", + .attr = amd_pstate_epp_attr, + }; +@@ -1742,6 +1852,44 @@ static int __init amd_pstate_set_driver(int mode_idx) + return -EINVAL; + } + ++/** ++ * CPPC function is not supported for family ID 17H with model_ID ranging from 0x10 to 0x2F. ++ * show the debug message that helps to check if the CPU has CPPC support for loading issue. ++ */ ++static bool amd_cppc_supported(void) ++{ ++ struct cpuinfo_x86 *c = &cpu_data(0); ++ bool warn = false; ++ ++ if ((boot_cpu_data.x86 == 0x17) && (boot_cpu_data.x86_model < 0x30)) { ++ pr_debug_once("CPPC feature is not supported by the processor\n"); ++ return false; ++ } ++ ++ /* ++ * If the CPPC feature is disabled in the BIOS for processors that support MSR-based CPPC, ++ * the AMD Pstate driver may not function correctly. ++ * Check the CPPC flag and display a warning message if the platform supports CPPC. 
++ * Note: below checking code will not abort the driver registeration process because of ++ * the code is added for debugging purposes. ++ */ ++ if (!cpu_feature_enabled(X86_FEATURE_CPPC)) { ++ if (cpu_feature_enabled(X86_FEATURE_ZEN3) || ++ cpu_feature_enabled(X86_FEATURE_ZEN4)) { ++ if ((c->x86_model > 0x10 && c->x86_model < 0x1F) || ++ (c->x86_model > 0x40 && c->x86_model < 0xaf)) ++ warn = true; ++ } else if (cpu_feature_enabled(X86_FEATURE_ZEN5)) { ++ warn = true; ++ } ++ } ++ ++ if (warn) ++ pr_warn_once("The CPPC feature is supported but currently disabled by the BIOS.\n" ++ "Please enable it if your BIOS has the CPPC option.\n"); ++ return true; ++} ++ + static int __init amd_pstate_init(void) + { + struct device *dev_root; +@@ -1750,6 +1898,11 @@ static int __init amd_pstate_init(void) + if (boot_cpu_data.x86_vendor != X86_VENDOR_AMD) + return -ENODEV; + ++ /* show debug message only if CPPC is not supported */ ++ if (!amd_cppc_supported()) ++ return -EOPNOTSUPP; ++ ++ /* show warning message when BIOS broken or ACPI disabled */ + if (!acpi_cpc_valid()) { + pr_warn_once("the _CPC object is not present in SBIOS or ACPI disabled\n"); + return -ENODEV; +@@ -1774,11 +1927,9 @@ static int __init amd_pstate_init(void) + /* Disable on the following configs by default: + * 1. Undefined platforms + * 2. Server platforms +- * 3. 
Shared memory designs + */ + if (amd_pstate_acpi_pm_profile_undefined() || +- amd_pstate_acpi_pm_profile_server() || +- !boot_cpu_has(X86_FEATURE_CPPC)) { ++ amd_pstate_acpi_pm_profile_server()) { + pr_info("driver load is disabled, boot with specific mode to enable this\n"); + return -ENODEV; + } +@@ -1802,7 +1953,7 @@ static int __init amd_pstate_init(void) + } + + /* capability check */ +- if (boot_cpu_has(X86_FEATURE_CPPC)) { ++ if (cpu_feature_enabled(X86_FEATURE_CPPC)) { + pr_debug("AMD CPPC MSR based functionality is supported\n"); + if (cppc_state != AMD_PSTATE_ACTIVE) + current_pstate_driver->adjust_perf = amd_pstate_adjust_perf; +@@ -1821,8 +1972,10 @@ static int __init amd_pstate_init(void) + } + + ret = cpufreq_register_driver(current_pstate_driver); +- if (ret) ++ if (ret) { + pr_err("failed to register with return %d\n", ret); ++ goto disable_driver; ++ } + + dev_root = bus_get_dev_root(&cpu_subsys); + if (dev_root) { +@@ -1838,6 +1991,8 @@ static int __init amd_pstate_init(void) + + global_attr_free: + cpufreq_unregister_driver(current_pstate_driver); ++disable_driver: ++ amd_pstate_enable(false); + return ret; + } + device_initcall(amd_pstate_init); +diff --git a/drivers/cpufreq/amd-pstate.h b/drivers/cpufreq/amd-pstate.h +index f80b33fa5d43..cc8bb2bc325a 100644 +--- a/drivers/cpufreq/amd-pstate.h ++++ b/drivers/cpufreq/amd-pstate.h +@@ -100,6 +100,7 @@ struct amd_cpudata { + u64 cppc_cap1_cached; + bool suspended; + s16 epp_default; ++ bool boost_state; + }; + + #endif /* _LINUX_AMD_PSTATE_H */ +diff --git a/drivers/cpufreq/cpufreq.c b/drivers/cpufreq/cpufreq.c +index 9e5060b27864..270ea04fb616 100644 +--- a/drivers/cpufreq/cpufreq.c ++++ b/drivers/cpufreq/cpufreq.c +@@ -614,10 +614,9 @@ static ssize_t show_boost(struct kobject *kobj, + static ssize_t store_boost(struct kobject *kobj, struct kobj_attribute *attr, + const char *buf, size_t count) + { +- int ret, enable; ++ bool enable; + +- ret = sscanf(buf, "%d", &enable); +- if (ret != 1 || enable 
< 0 || enable > 1) ++ if (kstrtobool(buf, &enable)) + return -EINVAL; + + if (cpufreq_boost_trigger_state(enable)) { +@@ -641,10 +640,10 @@ static ssize_t show_local_boost(struct cpufreq_policy *policy, char *buf) + static ssize_t store_local_boost(struct cpufreq_policy *policy, + const char *buf, size_t count) + { +- int ret, enable; ++ int ret; ++ bool enable; + +- ret = kstrtoint(buf, 10, &enable); +- if (ret || enable < 0 || enable > 1) ++ if (kstrtobool(buf, &enable)) + return -EINVAL; + + if (!cpufreq_driver->boost_enabled) +-- +2.46.0 + +From 3f0f8331e3fd1ad093220484546330eef1ed36b9 Mon Sep 17 00:00:00 2001 +From: Peter Jung +Date: Wed, 14 Aug 2024 16:18:40 +0200 +Subject: [PATCH 02/12] bbr3 + +Signed-off-by: Peter Jung +--- + include/linux/tcp.h | 4 +- + include/net/inet_connection_sock.h | 4 +- + include/net/tcp.h | 72 +- + include/uapi/linux/inet_diag.h | 23 + + include/uapi/linux/rtnetlink.h | 4 +- + include/uapi/linux/tcp.h | 1 + + net/ipv4/Kconfig | 21 +- + net/ipv4/bpf_tcp_ca.c | 9 +- + net/ipv4/tcp.c | 3 + + net/ipv4/tcp_bbr.c | 2230 +++++++++++++++++++++------- + net/ipv4/tcp_cong.c | 1 + + net/ipv4/tcp_input.c | 40 +- + net/ipv4/tcp_minisocks.c | 2 + + net/ipv4/tcp_output.c | 48 +- + net/ipv4/tcp_rate.c | 30 +- + net/ipv4/tcp_timer.c | 1 + + 16 files changed, 1940 insertions(+), 553 deletions(-) + +diff --git a/include/linux/tcp.h b/include/linux/tcp.h +index 6a5e08b937b3..27aab715490e 100644 +--- a/include/linux/tcp.h ++++ b/include/linux/tcp.h +@@ -369,7 +369,9 @@ struct tcp_sock { + u8 compressed_ack; + u8 dup_ack_counter:2, + tlp_retrans:1, /* TLP is a retransmission */ +- unused:5; ++ fast_ack_mode:2, /* which fast ack mode ? */ ++ tlp_orig_data_app_limited:1, /* app-limited before TLP rtx? 
*/ ++ unused:2; + u8 thin_lto : 1,/* Use linear timeouts for thin streams */ + fastopen_connect:1, /* FASTOPEN_CONNECT sockopt */ + fastopen_no_cookie:1, /* Allow send/recv SYN+data without a cookie */ +diff --git a/include/net/inet_connection_sock.h b/include/net/inet_connection_sock.h +index c0deaafebfdc..d53f042d936e 100644 +--- a/include/net/inet_connection_sock.h ++++ b/include/net/inet_connection_sock.h +@@ -137,8 +137,8 @@ struct inet_connection_sock { + u32 icsk_probes_tstamp; + u32 icsk_user_timeout; + +- u64 icsk_ca_priv[104 / sizeof(u64)]; +-#define ICSK_CA_PRIV_SIZE sizeof_field(struct inet_connection_sock, icsk_ca_priv) ++#define ICSK_CA_PRIV_SIZE (144) ++ u64 icsk_ca_priv[ICSK_CA_PRIV_SIZE / sizeof(u64)]; + }; + + #define ICSK_TIME_RETRANS 1 /* Retransmit timer */ +diff --git a/include/net/tcp.h b/include/net/tcp.h +index 32815a40dea1..109b8d1ddc31 100644 +--- a/include/net/tcp.h ++++ b/include/net/tcp.h +@@ -375,6 +375,8 @@ static inline void tcp_dec_quickack_mode(struct sock *sk) + #define TCP_ECN_QUEUE_CWR 2 + #define TCP_ECN_DEMAND_CWR 4 + #define TCP_ECN_SEEN 8 ++#define TCP_ECN_LOW 16 ++#define TCP_ECN_ECT_PERMANENT 32 + + enum tcp_tw_status { + TCP_TW_SUCCESS = 0, +@@ -779,6 +781,15 @@ static inline void tcp_fast_path_check(struct sock *sk) + + u32 tcp_delack_max(const struct sock *sk); + ++static inline void tcp_set_ecn_low_from_dst(struct sock *sk, ++ const struct dst_entry *dst) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ ++ if (dst_feature(dst, RTAX_FEATURE_ECN_LOW)) ++ tp->ecn_flags |= TCP_ECN_LOW; ++} ++ + /* Compute the actual rto_min value */ + static inline u32 tcp_rto_min(const struct sock *sk) + { +@@ -884,6 +895,11 @@ static inline u32 tcp_stamp_us_delta(u64 t1, u64 t0) + return max_t(s64, t1 - t0, 0); + } + ++static inline u32 tcp_stamp32_us_delta(u32 t1, u32 t0) ++{ ++ return max_t(s32, t1 - t0, 0); ++} ++ + /* provide the departure time in us unit */ + static inline u64 tcp_skb_timestamp_us(const struct sk_buff *skb) + { +@@ 
-973,9 +989,14 @@ struct tcp_skb_cb { + /* pkts S/ACKed so far upon tx of skb, incl retrans: */ + __u32 delivered; + /* start of send pipeline phase */ +- u64 first_tx_mstamp; ++ u32 first_tx_mstamp; + /* when we reached the "delivered" count */ +- u64 delivered_mstamp; ++ u32 delivered_mstamp; ++#define TCPCB_IN_FLIGHT_BITS 20 ++#define TCPCB_IN_FLIGHT_MAX ((1U << TCPCB_IN_FLIGHT_BITS) - 1) ++ u32 in_flight:20, /* packets in flight at transmit */ ++ unused2:12; ++ u32 lost; /* packets lost so far upon tx of skb */ + } tx; /* only used for outgoing skbs */ + union { + struct inet_skb_parm h4; +@@ -1079,6 +1100,7 @@ enum tcp_ca_event { + CA_EVENT_LOSS, /* loss timeout */ + CA_EVENT_ECN_NO_CE, /* ECT set, but not CE marked */ + CA_EVENT_ECN_IS_CE, /* received CE marked IP packet */ ++ CA_EVENT_TLP_RECOVERY, /* a lost segment was repaired by TLP probe */ + }; + + /* Information about inbound ACK, passed to cong_ops->in_ack_event() */ +@@ -1101,7 +1123,11 @@ enum tcp_ca_ack_event_flags { + #define TCP_CONG_NON_RESTRICTED 0x1 + /* Requires ECN/ECT set on all packets */ + #define TCP_CONG_NEEDS_ECN 0x2 +-#define TCP_CONG_MASK (TCP_CONG_NON_RESTRICTED | TCP_CONG_NEEDS_ECN) ++/* Wants notification of CE events (CA_EVENT_ECN_IS_CE, CA_EVENT_ECN_NO_CE). 
*/ ++#define TCP_CONG_WANTS_CE_EVENTS 0x4 ++#define TCP_CONG_MASK (TCP_CONG_NON_RESTRICTED | \ ++ TCP_CONG_NEEDS_ECN | \ ++ TCP_CONG_WANTS_CE_EVENTS) + + union tcp_cc_info; + +@@ -1121,10 +1147,13 @@ struct ack_sample { + */ + struct rate_sample { + u64 prior_mstamp; /* starting timestamp for interval */ ++ u32 prior_lost; /* tp->lost at "prior_mstamp" */ + u32 prior_delivered; /* tp->delivered at "prior_mstamp" */ + u32 prior_delivered_ce;/* tp->delivered_ce at "prior_mstamp" */ ++ u32 tx_in_flight; /* packets in flight at starting timestamp */ ++ s32 lost; /* number of packets lost over interval */ + s32 delivered; /* number of packets delivered over interval */ +- s32 delivered_ce; /* number of packets delivered w/ CE marks*/ ++ s32 delivered_ce; /* packets delivered w/ CE mark over interval */ + long interval_us; /* time for tp->delivered to incr "delivered" */ + u32 snd_interval_us; /* snd interval for delivered packets */ + u32 rcv_interval_us; /* rcv interval for delivered packets */ +@@ -1135,7 +1164,9 @@ struct rate_sample { + u32 last_end_seq; /* end_seq of most recently ACKed packet */ + bool is_app_limited; /* is sample from packet with bubble in pipe? */ + bool is_retrans; /* is sample from retransmission? */ ++ bool is_acking_tlp_retrans_seq; /* ACKed a TLP retransmit sequence? */ + bool is_ack_delayed; /* is this (likely) a delayed ACK? */ ++ bool is_ece; /* did this ACK have ECN marked? 
*/ + }; + + struct tcp_congestion_ops { +@@ -1159,8 +1190,11 @@ struct tcp_congestion_ops { + /* hook for packet ack accounting (optional) */ + void (*pkts_acked)(struct sock *sk, const struct ack_sample *sample); + +- /* override sysctl_tcp_min_tso_segs */ +- u32 (*min_tso_segs)(struct sock *sk); ++ /* pick target number of segments per TSO/GSO skb (optional): */ ++ u32 (*tso_segs)(struct sock *sk, unsigned int mss_now); ++ ++ /* react to a specific lost skb (optional) */ ++ void (*skb_marked_lost)(struct sock *sk, const struct sk_buff *skb); + + /* call when packets are delivered to update cwnd and pacing rate, + * after all the ca_state processing. (optional) +@@ -1226,6 +1260,14 @@ static inline char *tcp_ca_get_name_by_key(u32 key, char *buffer) + } + #endif + ++static inline bool tcp_ca_wants_ce_events(const struct sock *sk) ++{ ++ const struct inet_connection_sock *icsk = inet_csk(sk); ++ ++ return icsk->icsk_ca_ops->flags & (TCP_CONG_NEEDS_ECN | ++ TCP_CONG_WANTS_CE_EVENTS); ++} ++ + static inline bool tcp_ca_needs_ecn(const struct sock *sk) + { + const struct inet_connection_sock *icsk = inet_csk(sk); +@@ -1245,6 +1287,7 @@ static inline void tcp_ca_event(struct sock *sk, const enum tcp_ca_event event) + void tcp_set_ca_state(struct sock *sk, const u8 ca_state); + + /* From tcp_rate.c */ ++void tcp_set_tx_in_flight(struct sock *sk, struct sk_buff *skb); + void tcp_rate_skb_sent(struct sock *sk, struct sk_buff *skb); + void tcp_rate_skb_delivered(struct sock *sk, struct sk_buff *skb, + struct rate_sample *rs); +@@ -1257,6 +1300,21 @@ static inline bool tcp_skb_sent_after(u64 t1, u64 t2, u32 seq1, u32 seq2) + return t1 > t2 || (t1 == t2 && after(seq1, seq2)); + } + ++/* If a retransmit failed due to local qdisc congestion or other local issues, ++ * then we may have called tcp_set_skb_tso_segs() to increase the number of ++ * segments in the skb without increasing the tx.in_flight. 
In all other cases, ++ * the tx.in_flight should be at least as big as the pcount of the sk_buff. We ++ * do not have the state to know whether a retransmit failed due to local qdisc ++ * congestion or other local issues, so to avoid spurious warnings we consider ++ * that any skb marked lost may have suffered that fate. ++ */ ++static inline bool tcp_skb_tx_in_flight_is_suspicious(u32 skb_pcount, ++ u32 skb_sacked_flags, ++ u32 tx_in_flight) ++{ ++ return (skb_pcount > tx_in_flight) && !(skb_sacked_flags & TCPCB_LOST); ++} ++ + /* These functions determine how the current flow behaves in respect of SACK + * handling. SACK is negotiated with the peer, and therefore it can vary + * between different flows. +@@ -2419,7 +2477,7 @@ struct tcp_plb_state { + u8 consec_cong_rounds:5, /* consecutive congested rounds */ + unused:3; + u32 pause_until; /* jiffies32 when PLB can resume rerouting */ +-}; ++} __attribute__ ((__packed__)); + + static inline void tcp_plb_init(const struct sock *sk, + struct tcp_plb_state *plb) +diff --git a/include/uapi/linux/inet_diag.h b/include/uapi/linux/inet_diag.h +index 50655de04c9b..82f8bd8f0d16 100644 +--- a/include/uapi/linux/inet_diag.h ++++ b/include/uapi/linux/inet_diag.h +@@ -229,6 +229,29 @@ struct tcp_bbr_info { + __u32 bbr_min_rtt; /* min-filtered RTT in uSec */ + __u32 bbr_pacing_gain; /* pacing gain shifted left 8 bits */ + __u32 bbr_cwnd_gain; /* cwnd gain shifted left 8 bits */ ++ __u32 bbr_bw_hi_lsb; /* lower 32 bits of bw_hi */ ++ __u32 bbr_bw_hi_msb; /* upper 32 bits of bw_hi */ ++ __u32 bbr_bw_lo_lsb; /* lower 32 bits of bw_lo */ ++ __u32 bbr_bw_lo_msb; /* upper 32 bits of bw_lo */ ++ __u8 bbr_mode; /* current bbr_mode in state machine */ ++ __u8 bbr_phase; /* current state machine phase */ ++ __u8 unused1; /* alignment padding; not used yet */ ++ __u8 bbr_version; /* BBR algorithm version */ ++ __u32 bbr_inflight_lo; /* lower short-term data volume bound */ ++ __u32 bbr_inflight_hi; /* higher long-term data volume bound 
*/ ++ __u32 bbr_extra_acked; /* max excess packets ACKed in epoch */ ++}; ++ ++/* TCP BBR congestion control bbr_phase as reported in netlink/ss stats. */ ++enum tcp_bbr_phase { ++ BBR_PHASE_INVALID = 0, ++ BBR_PHASE_STARTUP = 1, ++ BBR_PHASE_DRAIN = 2, ++ BBR_PHASE_PROBE_RTT = 3, ++ BBR_PHASE_PROBE_BW_UP = 4, ++ BBR_PHASE_PROBE_BW_DOWN = 5, ++ BBR_PHASE_PROBE_BW_CRUISE = 6, ++ BBR_PHASE_PROBE_BW_REFILL = 7, + }; + + union tcp_cc_info { +diff --git a/include/uapi/linux/rtnetlink.h b/include/uapi/linux/rtnetlink.h +index 3b687d20c9ed..a7c30c243b54 100644 +--- a/include/uapi/linux/rtnetlink.h ++++ b/include/uapi/linux/rtnetlink.h +@@ -507,12 +507,14 @@ enum { + #define RTAX_FEATURE_TIMESTAMP (1 << 2) /* unused */ + #define RTAX_FEATURE_ALLFRAG (1 << 3) /* unused */ + #define RTAX_FEATURE_TCP_USEC_TS (1 << 4) ++#define RTAX_FEATURE_ECN_LOW (1 << 5) + + #define RTAX_FEATURE_MASK (RTAX_FEATURE_ECN | \ + RTAX_FEATURE_SACK | \ + RTAX_FEATURE_TIMESTAMP | \ + RTAX_FEATURE_ALLFRAG | \ +- RTAX_FEATURE_TCP_USEC_TS) ++ RTAX_FEATURE_TCP_USEC_TS | \ ++ RTAX_FEATURE_ECN_LOW) + + struct rta_session { + __u8 proto; +diff --git a/include/uapi/linux/tcp.h b/include/uapi/linux/tcp.h +index dbf896f3146c..4702cd2f1ffc 100644 +--- a/include/uapi/linux/tcp.h ++++ b/include/uapi/linux/tcp.h +@@ -178,6 +178,7 @@ enum tcp_fastopen_client_fail { + #define TCPI_OPT_ECN_SEEN 16 /* we received at least one packet with ECT */ + #define TCPI_OPT_SYN_DATA 32 /* SYN-ACK acked data in SYN sent or rcvd */ + #define TCPI_OPT_USEC_TS 64 /* usec timestamps */ ++#define TCPI_OPT_ECN_LOW 128 /* Low-latency ECN configured at init */ + + /* + * Sender's congestion state indicating normal or abnormal situations +diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig +index 8e94ed7c56a0..50dc9970cad2 100644 +--- a/net/ipv4/Kconfig ++++ b/net/ipv4/Kconfig +@@ -668,15 +668,18 @@ config TCP_CONG_BBR + default n + help + +- BBR (Bottleneck Bandwidth and RTT) TCP congestion control aims to +- maximize network utilization 
and minimize queues. It builds an explicit +- model of the bottleneck delivery rate and path round-trip propagation +- delay. It tolerates packet loss and delay unrelated to congestion. It +- can operate over LAN, WAN, cellular, wifi, or cable modem links. It can +- coexist with flows that use loss-based congestion control, and can +- operate with shallow buffers, deep buffers, bufferbloat, policers, or +- AQM schemes that do not provide a delay signal. It requires the fq +- ("Fair Queue") pacing packet scheduler. ++ BBR (Bottleneck Bandwidth and RTT) TCP congestion control is a ++ model-based congestion control algorithm that aims to maximize ++ network utilization, keep queues and retransmit rates low, and to be ++ able to coexist with Reno/CUBIC in common scenarios. It builds an ++ explicit model of the network path. It tolerates a targeted degree ++ of random packet loss and delay. It can operate over LAN, WAN, ++ cellular, wifi, or cable modem links, and can use shallow-threshold ++ ECN signals. It can coexist to some degree with flows that use ++ loss-based congestion control, and can operate with shallow buffers, ++ deep buffers, bufferbloat, policers, or AQM schemes that do not ++ provide a delay signal. It requires pacing, using either TCP internal ++ pacing or the fq ("Fair Queue") pacing packet scheduler. 
+ + choice + prompt "Default TCP congestion control" +diff --git a/net/ipv4/bpf_tcp_ca.c b/net/ipv4/bpf_tcp_ca.c +index 18227757ec0c..f180befc28bd 100644 +--- a/net/ipv4/bpf_tcp_ca.c ++++ b/net/ipv4/bpf_tcp_ca.c +@@ -305,11 +305,15 @@ static void bpf_tcp_ca_pkts_acked(struct sock *sk, const struct ack_sample *samp + { + } + +-static u32 bpf_tcp_ca_min_tso_segs(struct sock *sk) ++static u32 bpf_tcp_ca_tso_segs(struct sock *sk, unsigned int mss_now) + { + return 0; + } + ++static void bpf_tcp_ca_skb_marked_lost(struct sock *sk, const struct sk_buff *skb) ++{ ++} ++ + static void bpf_tcp_ca_cong_control(struct sock *sk, u32 ack, int flag, + const struct rate_sample *rs) + { +@@ -340,7 +344,8 @@ static struct tcp_congestion_ops __bpf_ops_tcp_congestion_ops = { + .cwnd_event = bpf_tcp_ca_cwnd_event, + .in_ack_event = bpf_tcp_ca_in_ack_event, + .pkts_acked = bpf_tcp_ca_pkts_acked, +- .min_tso_segs = bpf_tcp_ca_min_tso_segs, ++ .tso_segs = bpf_tcp_ca_tso_segs, ++ .skb_marked_lost = bpf_tcp_ca_skb_marked_lost, + .cong_control = bpf_tcp_ca_cong_control, + .undo_cwnd = bpf_tcp_ca_undo_cwnd, + .sndbuf_expand = bpf_tcp_ca_sndbuf_expand, +diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c +index ec6911034138..df7731a30198 100644 +--- a/net/ipv4/tcp.c ++++ b/net/ipv4/tcp.c +@@ -3120,6 +3120,7 @@ int tcp_disconnect(struct sock *sk, int flags) + tp->rx_opt.dsack = 0; + tp->rx_opt.num_sacks = 0; + tp->rcv_ooopack = 0; ++ tp->fast_ack_mode = 0; + + + /* Clean up fastopen related fields */ +@@ -3846,6 +3847,8 @@ void tcp_get_info(struct sock *sk, struct tcp_info *info) + info->tcpi_options |= TCPI_OPT_ECN; + if (tp->ecn_flags & TCP_ECN_SEEN) + info->tcpi_options |= TCPI_OPT_ECN_SEEN; ++ if (tp->ecn_flags & TCP_ECN_LOW) ++ info->tcpi_options |= TCPI_OPT_ECN_LOW; + if (tp->syn_data_acked) + info->tcpi_options |= TCPI_OPT_SYN_DATA; + if (tp->tcp_usec_ts) +diff --git a/net/ipv4/tcp_bbr.c b/net/ipv4/tcp_bbr.c +index 760941e55153..a180fa648d5e 100644 +--- a/net/ipv4/tcp_bbr.c ++++ 
b/net/ipv4/tcp_bbr.c +@@ -1,18 +1,19 @@ +-/* Bottleneck Bandwidth and RTT (BBR) congestion control ++/* BBR (Bottleneck Bandwidth and RTT) congestion control + * +- * BBR congestion control computes the sending rate based on the delivery +- * rate (throughput) estimated from ACKs. In a nutshell: ++ * BBR is a model-based congestion control algorithm that aims for low queues, ++ * low loss, and (bounded) Reno/CUBIC coexistence. To maintain a model of the ++ * network path, it uses measurements of bandwidth and RTT, as well as (if they ++ * occur) packet loss and/or shallow-threshold ECN signals. Note that although ++ * it can use ECN or loss signals explicitly, it does not require either; it ++ * can bound its in-flight data based on its estimate of the BDP. + * +- * On each ACK, update our model of the network path: +- * bottleneck_bandwidth = windowed_max(delivered / elapsed, 10 round trips) +- * min_rtt = windowed_min(rtt, 10 seconds) +- * pacing_rate = pacing_gain * bottleneck_bandwidth +- * cwnd = max(cwnd_gain * bottleneck_bandwidth * min_rtt, 4) +- * +- * The core algorithm does not react directly to packet losses or delays, +- * although BBR may adjust the size of next send per ACK when loss is +- * observed, or adjust the sending rate if it estimates there is a +- * traffic policer, in order to keep the drop rate reasonable. ++ * The model has both higher and lower bounds for the operating range: ++ * lo: bw_lo, inflight_lo: conservative short-term lower bound ++ * hi: bw_hi, inflight_hi: robust long-term upper bound ++ * The bandwidth-probing time scale is (a) extended dynamically based on ++ * estimated BDP to improve coexistence with Reno/CUBIC; (b) bounded by ++ * an interactive wall-clock time-scale to be more scalable and responsive ++ * than Reno and CUBIC. 
+ * + * Here is a state transition diagram for BBR: + * +@@ -65,6 +66,13 @@ + #include + #include + ++#include ++#include "tcp_dctcp.h" ++ ++#define BBR_VERSION 3 ++ ++#define bbr_param(sk,name) (bbr_ ## name) ++ + /* Scale factor for rate in pkt/uSec unit to avoid truncation in bandwidth + * estimation. The rate unit ~= (1500 bytes / 1 usec / 2^24) ~= 715 bps. + * This handles bandwidths from 0.06pps (715bps) to 256Mpps (3Tbps) in a u32. +@@ -85,36 +93,41 @@ enum bbr_mode { + BBR_PROBE_RTT, /* cut inflight to min to probe min_rtt */ + }; + ++/* How does the incoming ACK stream relate to our bandwidth probing? */ ++enum bbr_ack_phase { ++ BBR_ACKS_INIT, /* not probing; not getting probe feedback */ ++ BBR_ACKS_REFILLING, /* sending at est. bw to fill pipe */ ++ BBR_ACKS_PROBE_STARTING, /* inflight rising to probe bw */ ++ BBR_ACKS_PROBE_FEEDBACK, /* getting feedback from bw probing */ ++ BBR_ACKS_PROBE_STOPPING, /* stopped probing; still getting feedback */ ++}; ++ + /* BBR congestion control block */ + struct bbr { + u32 min_rtt_us; /* min RTT in min_rtt_win_sec window */ + u32 min_rtt_stamp; /* timestamp of min_rtt_us */ + u32 probe_rtt_done_stamp; /* end time for BBR_PROBE_RTT mode */ +- struct minmax bw; /* Max recent delivery rate in pkts/uS << 24 */ +- u32 rtt_cnt; /* count of packet-timed rounds elapsed */ ++ u32 probe_rtt_min_us; /* min RTT in probe_rtt_win_ms win */ ++ u32 probe_rtt_min_stamp; /* timestamp of probe_rtt_min_us*/ + u32 next_rtt_delivered; /* scb->tx.delivered at end of round */ + u64 cycle_mstamp; /* time of this cycle phase start */ +- u32 mode:3, /* current bbr_mode in state machine */ ++ u32 mode:2, /* current bbr_mode in state machine */ + prev_ca_state:3, /* CA state on previous ACK */ +- packet_conservation:1, /* use packet conservation? */ + round_start:1, /* start of packet-timed tx->ack round? 
*/ ++ ce_state:1, /* If most recent data has CE bit set */ ++ bw_probe_up_rounds:5, /* cwnd-limited rounds in PROBE_UP */ ++ try_fast_path:1, /* can we take fast path? */ + idle_restart:1, /* restarting after idle? */ + probe_rtt_round_done:1, /* a BBR_PROBE_RTT round at 4 pkts? */ +- unused:13, +- lt_is_sampling:1, /* taking long-term ("LT") samples now? */ +- lt_rtt_cnt:7, /* round trips in long-term interval */ +- lt_use_bw:1; /* use lt_bw as our bw estimate? */ +- u32 lt_bw; /* LT est delivery rate in pkts/uS << 24 */ +- u32 lt_last_delivered; /* LT intvl start: tp->delivered */ +- u32 lt_last_stamp; /* LT intvl start: tp->delivered_mstamp */ +- u32 lt_last_lost; /* LT intvl start: tp->lost */ ++ init_cwnd:7, /* initial cwnd */ ++ unused_1:10; + u32 pacing_gain:10, /* current gain for setting pacing rate */ + cwnd_gain:10, /* current gain for setting cwnd */ + full_bw_reached:1, /* reached full bw in Startup? */ + full_bw_cnt:2, /* number of rounds without large bw gains */ +- cycle_idx:3, /* current index in pacing_gain cycle array */ ++ cycle_idx:2, /* current index in pacing_gain cycle array */ + has_seen_rtt:1, /* have we seen an RTT sample yet? */ +- unused_b:5; ++ unused_2:6; + u32 prior_cwnd; /* prior cwnd upon entering loss recovery */ + u32 full_bw; /* recent bw, to estimate if pipe is full */ + +@@ -124,19 +137,67 @@ struct bbr { + u32 ack_epoch_acked:20, /* packets (S)ACKed in sampling epoch */ + extra_acked_win_rtts:5, /* age of extra_acked, in round trips */ + extra_acked_win_idx:1, /* current index in extra_acked array */ +- unused_c:6; ++ /* BBR v3 state: */ ++ full_bw_now:1, /* recently reached full bw plateau? */ ++ startup_ecn_rounds:2, /* consecutive hi ECN STARTUP rounds */ ++ loss_in_cycle:1, /* packet loss in this cycle? */ ++ ecn_in_cycle:1, /* ECN in this cycle? 
*/ ++ unused_3:1; ++ u32 loss_round_delivered; /* scb->tx.delivered ending loss round */ ++ u32 undo_bw_lo; /* bw_lo before latest losses */ ++ u32 undo_inflight_lo; /* inflight_lo before latest losses */ ++ u32 undo_inflight_hi; /* inflight_hi before latest losses */ ++ u32 bw_latest; /* max delivered bw in last round trip */ ++ u32 bw_lo; /* lower bound on sending bandwidth */ ++ u32 bw_hi[2]; /* max recent measured bw sample */ ++ u32 inflight_latest; /* max delivered data in last round trip */ ++ u32 inflight_lo; /* lower bound of inflight data range */ ++ u32 inflight_hi; /* upper bound of inflight data range */ ++ u32 bw_probe_up_cnt; /* packets delivered per inflight_hi incr */ ++ u32 bw_probe_up_acks; /* packets (S)ACKed since inflight_hi incr */ ++ u32 probe_wait_us; /* PROBE_DOWN until next clock-driven probe */ ++ u32 prior_rcv_nxt; /* tp->rcv_nxt when CE state last changed */ ++ u32 ecn_eligible:1, /* sender can use ECN (RTT, handshake)? */ ++ ecn_alpha:9, /* EWMA delivered_ce/delivered; 0..256 */ ++ bw_probe_samples:1, /* rate samples reflect bw probing? */ ++ prev_probe_too_high:1, /* did last PROBE_UP go too high? */ ++ stopped_risky_probe:1, /* last PROBE_UP stopped due to risk? */ ++ rounds_since_probe:8, /* packet-timed rounds since probed bw */ ++ loss_round_start:1, /* loss_round_delivered round trip? */ ++ loss_in_round:1, /* loss marked in this round trip? */ ++ ecn_in_round:1, /* ECN marked in this round trip? */ ++ ack_phase:3, /* bbr_ack_phase: meaning of ACKs */ ++ loss_events_in_round:4,/* losses in STARTUP round */ ++ initialized:1; /* has bbr_init() been called? 
*/ ++ u32 alpha_last_delivered; /* tp->delivered at alpha update */ ++ u32 alpha_last_delivered_ce; /* tp->delivered_ce at alpha update */ ++ ++ u8 unused_4; /* to preserve alignment */ ++ struct tcp_plb_state plb; + }; + +-#define CYCLE_LEN 8 /* number of phases in a pacing gain cycle */ ++struct bbr_context { ++ u32 sample_bw; ++}; + +-/* Window length of bw filter (in rounds): */ +-static const int bbr_bw_rtts = CYCLE_LEN + 2; + /* Window length of min_rtt filter (in sec): */ + static const u32 bbr_min_rtt_win_sec = 10; + /* Minimum time (in ms) spent at bbr_cwnd_min_target in BBR_PROBE_RTT mode: */ + static const u32 bbr_probe_rtt_mode_ms = 200; +-/* Skip TSO below the following bandwidth (bits/sec): */ +-static const int bbr_min_tso_rate = 1200000; ++/* Window length of probe_rtt_min_us filter (in ms), and consequently the ++ * typical interval between PROBE_RTT mode entries. The default is 5000ms. ++ * Note that bbr_probe_rtt_win_ms must be <= bbr_min_rtt_win_sec * MSEC_PER_SEC ++ */ ++static const u32 bbr_probe_rtt_win_ms = 5000; ++/* Proportion of cwnd to estimated BDP in PROBE_RTT, in units of BBR_UNIT: */ ++static const u32 bbr_probe_rtt_cwnd_gain = BBR_UNIT * 1 / 2; ++ ++/* Use min_rtt to help adapt TSO burst size, with smaller min_rtt resulting ++ * in bigger TSO bursts. We cut the RTT-based allowance in half ++ * for every 2^9 usec (aka 512 us) of RTT, so that the RTT-based allowance ++ * is below 1500 bytes after 6 * ~500 usec = 3ms. ++ */ ++static const u32 bbr_tso_rtt_shift = 9; + + /* Pace at ~1% below estimated bw, on average, to reduce queue at bottleneck. 
+ * In order to help drive the network toward lower queues and low latency while +@@ -146,13 +207,15 @@ static const int bbr_min_tso_rate = 1200000; + */ + static const int bbr_pacing_margin_percent = 1; + +-/* We use a high_gain value of 2/ln(2) because it's the smallest pacing gain ++/* We use a startup_pacing_gain of 4*ln(2) because it's the smallest value + * that will allow a smoothly increasing pacing rate that will double each RTT + * and send the same number of packets per RTT that an un-paced, slow-starting + * Reno or CUBIC flow would: + */ +-static const int bbr_high_gain = BBR_UNIT * 2885 / 1000 + 1; +-/* The pacing gain of 1/high_gain in BBR_DRAIN is calculated to typically drain ++static const int bbr_startup_pacing_gain = BBR_UNIT * 277 / 100 + 1; ++/* The gain for deriving startup cwnd: */ ++static const int bbr_startup_cwnd_gain = BBR_UNIT * 2; ++/* The pacing gain in BBR_DRAIN is calculated to typically drain + * the queue created in BBR_STARTUP in a single round: + */ + static const int bbr_drain_gain = BBR_UNIT * 1000 / 2885; +@@ -160,13 +223,17 @@ static const int bbr_drain_gain = BBR_UNIT * 1000 / 2885; + static const int bbr_cwnd_gain = BBR_UNIT * 2; + /* The pacing_gain values for the PROBE_BW gain cycle, to discover/share bw: */ + static const int bbr_pacing_gain[] = { +- BBR_UNIT * 5 / 4, /* probe for more available bw */ +- BBR_UNIT * 3 / 4, /* drain queue and/or yield bw to other flows */ +- BBR_UNIT, BBR_UNIT, BBR_UNIT, /* cruise at 1.0*bw to utilize pipe, */ +- BBR_UNIT, BBR_UNIT, BBR_UNIT /* without creating excess queue... 
*/ ++ BBR_UNIT * 5 / 4, /* UP: probe for more available bw */ ++ BBR_UNIT * 91 / 100, /* DOWN: drain queue and/or yield bw */ ++ BBR_UNIT, /* CRUISE: try to use pipe w/ some headroom */ ++ BBR_UNIT, /* REFILL: refill pipe to estimated 100% */ ++}; ++enum bbr_pacing_gain_phase { ++ BBR_BW_PROBE_UP = 0, /* push up inflight to probe for bw/vol */ ++ BBR_BW_PROBE_DOWN = 1, /* drain excess inflight from the queue */ ++ BBR_BW_PROBE_CRUISE = 2, /* use pipe, w/ headroom in queue/pipe */ ++ BBR_BW_PROBE_REFILL = 3, /* v2: refill the pipe again to 100% */ + }; +-/* Randomize the starting gain cycling phase over N phases: */ +-static const u32 bbr_cycle_rand = 7; + + /* Try to keep at least this many packets in flight, if things go smoothly. For + * smooth functioning, a sliding window protocol ACKing every other packet +@@ -174,24 +241,12 @@ static const u32 bbr_cycle_rand = 7; + */ + static const u32 bbr_cwnd_min_target = 4; + +-/* To estimate if BBR_STARTUP mode (i.e. high_gain) has filled pipe... */ ++/* To estimate if BBR_STARTUP or BBR_BW_PROBE_UP has filled pipe... */ + /* If bw has increased significantly (1.25x), there may be more bw available: */ + static const u32 bbr_full_bw_thresh = BBR_UNIT * 5 / 4; + /* But after 3 rounds w/o significant bw growth, estimate pipe is full: */ + static const u32 bbr_full_bw_cnt = 3; + +-/* "long-term" ("LT") bandwidth estimator parameters... 
*/ +-/* The minimum number of rounds in an LT bw sampling interval: */ +-static const u32 bbr_lt_intvl_min_rtts = 4; +-/* If lost/delivered ratio > 20%, interval is "lossy" and we may be policed: */ +-static const u32 bbr_lt_loss_thresh = 50; +-/* If 2 intervals have a bw ratio <= 1/8, their bw is "consistent": */ +-static const u32 bbr_lt_bw_ratio = BBR_UNIT / 8; +-/* If 2 intervals have a bw diff <= 4 Kbit/sec their bw is "consistent": */ +-static const u32 bbr_lt_bw_diff = 4000 / 8; +-/* If we estimate we're policed, use lt_bw for this many round trips: */ +-static const u32 bbr_lt_bw_max_rtts = 48; +- + /* Gain factor for adding extra_acked to target cwnd: */ + static const int bbr_extra_acked_gain = BBR_UNIT; + /* Window length of extra_acked window. */ +@@ -201,8 +256,121 @@ static const u32 bbr_ack_epoch_acked_reset_thresh = 1U << 20; + /* Time period for clamping cwnd increment due to ack aggregation */ + static const u32 bbr_extra_acked_max_us = 100 * 1000; + ++/* Flags to control BBR ECN-related behavior... */ ++ ++/* Ensure ACKs only ACK packets with consistent ECN CE status? */ ++static const bool bbr_precise_ece_ack = true; ++ ++/* Max RTT (in usec) at which to use sender-side ECN logic. ++ * Disabled when 0 (ECN allowed at any RTT). ++ */ ++static const u32 bbr_ecn_max_rtt_us = 5000; ++ ++/* On losses, scale down inflight and pacing rate by beta scaled by BBR_SCALE. ++ * No loss response when 0. ++ */ ++static const u32 bbr_beta = BBR_UNIT * 30 / 100; ++ ++/* Gain factor for ECN mark ratio samples, scaled by BBR_SCALE (1/16 = 6.25%) */ ++static const u32 bbr_ecn_alpha_gain = BBR_UNIT * 1 / 16; ++ ++/* The initial value for ecn_alpha; 1.0 allows a flow to respond quickly ++ * to congestion if the bottleneck is congested when the flow starts up. ++ */ ++static const u32 bbr_ecn_alpha_init = BBR_UNIT; ++ ++/* On ECN, cut inflight_lo to (1 - ecn_factor * ecn_alpha) scaled by BBR_SCALE. ++ * No ECN based bounding when 0. 
++ */ ++static const u32 bbr_ecn_factor = BBR_UNIT * 1 / 3; /* 1/3 = 33% */ ++ ++/* Estimate bw probing has gone too far if CE ratio exceeds this threshold. ++ * Scaled by BBR_SCALE. Disabled when 0. ++ */ ++static const u32 bbr_ecn_thresh = BBR_UNIT * 1 / 2; /* 1/2 = 50% */ ++ ++/* If non-zero, if in a cycle with no losses but some ECN marks, after ECN ++ * clears then make the first round's increment to inflight_hi the following ++ * fraction of inflight_hi. ++ */ ++static const u32 bbr_ecn_reprobe_gain = BBR_UNIT * 1 / 2; ++ ++/* Estimate bw probing has gone too far if loss rate exceeds this level. */ ++static const u32 bbr_loss_thresh = BBR_UNIT * 2 / 100; /* 2% loss */ ++ ++/* Slow down for a packet loss recovered by TLP? */ ++static const bool bbr_loss_probe_recovery = true; ++ ++/* Exit STARTUP if number of loss marking events in a Recovery round is >= N, ++ * and loss rate is higher than bbr_loss_thresh. ++ * Disabled if 0. ++ */ ++static const u32 bbr_full_loss_cnt = 6; ++ ++/* Exit STARTUP if number of round trips with ECN mark rate above ecn_thresh ++ * meets this count. ++ */ ++static const u32 bbr_full_ecn_cnt = 2; ++ ++/* Fraction of unutilized headroom to try to leave in path upon high loss. */ ++static const u32 bbr_inflight_headroom = BBR_UNIT * 15 / 100; ++ ++/* How much do we increase cwnd_gain when probing for bandwidth in ++ * BBR_BW_PROBE_UP? This specifies the increment in units of ++ * BBR_UNIT/4. The default is 1, meaning 0.25. ++ * The min value is 0 (meaning 0.0); max is 3 (meaning 0.75). ++ */ ++static const u32 bbr_bw_probe_cwnd_gain = 1; ++ ++/* Max number of packet-timed rounds to wait before probing for bandwidth. If ++ * we want to tolerate 1% random loss per round, and not have this cut our ++ * inflight too much, we must probe for bw periodically on roughly this scale. ++ * If low, limits Reno/CUBIC coexistence; if high, limits loss tolerance. 
++ * We aim to be fair with Reno/CUBIC up to a BDP of at least: ++ * BDP = 25Mbps * .030sec /(1514bytes) = 61.9 packets ++ */ ++static const u32 bbr_bw_probe_max_rounds = 63; ++ ++/* Max amount of randomness to inject in round counting for Reno-coexistence. ++ */ ++static const u32 bbr_bw_probe_rand_rounds = 2; ++ ++/* Use BBR-native probe time scale starting at this many usec. ++ * We aim to be fair with Reno/CUBIC up to an inter-loss time epoch of at least: ++ * BDP*RTT = 25Mbps * .030sec /(1514bytes) * 0.030sec = 1.9 secs ++ */ ++static const u32 bbr_bw_probe_base_us = 2 * USEC_PER_SEC; /* 2 secs */ ++ ++/* Use BBR-native probes spread over this many usec: */ ++static const u32 bbr_bw_probe_rand_us = 1 * USEC_PER_SEC; /* 1 secs */ ++ ++/* Use fast path if app-limited, no loss/ECN, and target cwnd was reached? */ ++static const bool bbr_fast_path = true; ++ ++/* Use fast ack mode? */ ++static const bool bbr_fast_ack_mode = true; ++ ++static u32 bbr_max_bw(const struct sock *sk); ++static u32 bbr_bw(const struct sock *sk); ++static void bbr_exit_probe_rtt(struct sock *sk); ++static void bbr_reset_congestion_signals(struct sock *sk); ++static void bbr_run_loss_probe_recovery(struct sock *sk); ++ + static void bbr_check_probe_rtt_done(struct sock *sk); + ++/* This connection can use ECN if both endpoints have signaled ECN support in ++ * the handshake and the per-route settings indicated this is a ++ * shallow-threshold ECN environment, meaning both: ++ * (a) ECN CE marks indicate low-latency/shallow-threshold congestion, and ++ * (b) TCP endpoints provide precise ACKs that only ACK data segments ++ * with consistent ECN CE status ++ */ ++static bool bbr_can_use_ecn(const struct sock *sk) ++{ ++ return (tcp_sk(sk)->ecn_flags & TCP_ECN_OK) && ++ (tcp_sk(sk)->ecn_flags & TCP_ECN_LOW); ++} ++ + /* Do we estimate that STARTUP filled the pipe? 
*/ + static bool bbr_full_bw_reached(const struct sock *sk) + { +@@ -214,17 +382,17 @@ static bool bbr_full_bw_reached(const struct sock *sk) + /* Return the windowed max recent bandwidth sample, in pkts/uS << BW_SCALE. */ + static u32 bbr_max_bw(const struct sock *sk) + { +- struct bbr *bbr = inet_csk_ca(sk); ++ const struct bbr *bbr = inet_csk_ca(sk); + +- return minmax_get(&bbr->bw); ++ return max(bbr->bw_hi[0], bbr->bw_hi[1]); + } + + /* Return the estimated bandwidth of the path, in pkts/uS << BW_SCALE. */ + static u32 bbr_bw(const struct sock *sk) + { +- struct bbr *bbr = inet_csk_ca(sk); ++ const struct bbr *bbr = inet_csk_ca(sk); + +- return bbr->lt_use_bw ? bbr->lt_bw : bbr_max_bw(sk); ++ return min(bbr_max_bw(sk), bbr->bw_lo); + } + + /* Return maximum extra acked in past k-2k round trips, +@@ -241,15 +409,23 @@ static u16 bbr_extra_acked(const struct sock *sk) + * The order here is chosen carefully to avoid overflow of u64. This should + * work for input rates of up to 2.9Tbit/sec and gain of 2.89x. + */ +-static u64 bbr_rate_bytes_per_sec(struct sock *sk, u64 rate, int gain) ++static u64 bbr_rate_bytes_per_sec(struct sock *sk, u64 rate, int gain, ++ int margin) + { + unsigned int mss = tcp_sk(sk)->mss_cache; + + rate *= mss; + rate *= gain; + rate >>= BBR_SCALE; +- rate *= USEC_PER_SEC / 100 * (100 - bbr_pacing_margin_percent); +- return rate >> BW_SCALE; ++ rate *= USEC_PER_SEC / 100 * (100 - margin); ++ rate >>= BW_SCALE; ++ rate = max(rate, 1ULL); ++ return rate; ++} ++ ++static u64 bbr_bw_bytes_per_sec(struct sock *sk, u64 rate) ++{ ++ return bbr_rate_bytes_per_sec(sk, rate, BBR_UNIT, 0); + } + + /* Convert a BBR bw and gain factor to a pacing rate in bytes per second. 
*/ +@@ -257,12 +433,13 @@ static unsigned long bbr_bw_to_pacing_rate(struct sock *sk, u32 bw, int gain) + { + u64 rate = bw; + +- rate = bbr_rate_bytes_per_sec(sk, rate, gain); ++ rate = bbr_rate_bytes_per_sec(sk, rate, gain, ++ bbr_pacing_margin_percent); + rate = min_t(u64, rate, READ_ONCE(sk->sk_max_pacing_rate)); + return rate; + } + +-/* Initialize pacing rate to: high_gain * init_cwnd / RTT. */ ++/* Initialize pacing rate to: startup_pacing_gain * init_cwnd / RTT. */ + static void bbr_init_pacing_rate_from_rtt(struct sock *sk) + { + struct tcp_sock *tp = tcp_sk(sk); +@@ -279,7 +456,7 @@ static void bbr_init_pacing_rate_from_rtt(struct sock *sk) + bw = (u64)tcp_snd_cwnd(tp) * BW_UNIT; + do_div(bw, rtt_us); + WRITE_ONCE(sk->sk_pacing_rate, +- bbr_bw_to_pacing_rate(sk, bw, bbr_high_gain)); ++ bbr_bw_to_pacing_rate(sk, bw, bbr_param(sk, startup_pacing_gain))); + } + + /* Pace using current bw estimate and a gain factor. */ +@@ -295,26 +472,48 @@ static void bbr_set_pacing_rate(struct sock *sk, u32 bw, int gain) + WRITE_ONCE(sk->sk_pacing_rate, rate); + } + +-/* override sysctl_tcp_min_tso_segs */ +-__bpf_kfunc static u32 bbr_min_tso_segs(struct sock *sk) ++/* Return the number of segments BBR would like in a TSO/GSO skb, given a ++ * particular max gso size as a constraint. TODO: make this simpler and more ++ * consistent by switching bbr to just call tcp_tso_autosize(). ++ */ ++static u32 bbr_tso_segs_generic(struct sock *sk, unsigned int mss_now, ++ u32 gso_max_size) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ u32 segs, r; ++ u64 bytes; ++ ++ /* Budget a TSO/GSO burst size allowance based on bw (pacing_rate). */ ++ bytes = READ_ONCE(sk->sk_pacing_rate) >> READ_ONCE(sk->sk_pacing_shift); ++ ++ /* Budget a TSO/GSO burst size allowance based on min_rtt. For every ++ * K = 2^tso_rtt_shift microseconds of min_rtt, halve the burst. 
++ * The min_rtt-based burst allowance is: 64 KBytes / 2^(min_rtt/K) ++ */ ++ if (bbr_param(sk, tso_rtt_shift)) { ++ r = bbr->min_rtt_us >> bbr_param(sk, tso_rtt_shift); ++ if (r < BITS_PER_TYPE(u32)) /* prevent undefined behavior */ ++ bytes += GSO_LEGACY_MAX_SIZE >> r; ++ } ++ ++ bytes = min_t(u32, bytes, gso_max_size - 1 - MAX_TCP_HEADER); ++ segs = max_t(u32, bytes / mss_now, ++ sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs); ++ return segs; ++} ++ ++/* Custom tcp_tso_autosize() for BBR, used at transmit time to cap skb size. */ ++__bpf_kfunc static u32 bbr_tso_segs(struct sock *sk, unsigned int mss_now) + { +- return READ_ONCE(sk->sk_pacing_rate) < (bbr_min_tso_rate >> 3) ? 1 : 2; ++ return bbr_tso_segs_generic(sk, mss_now, sk->sk_gso_max_size); + } + ++/* Like bbr_tso_segs(), using mss_cache, ignoring driver's sk_gso_max_size. */ + static u32 bbr_tso_segs_goal(struct sock *sk) + { + struct tcp_sock *tp = tcp_sk(sk); +- u32 segs, bytes; +- +- /* Sort of tcp_tso_autosize() but ignoring +- * driver provided sk_gso_max_size. 
+- */ +- bytes = min_t(unsigned long, +- READ_ONCE(sk->sk_pacing_rate) >> READ_ONCE(sk->sk_pacing_shift), +- GSO_LEGACY_MAX_SIZE - 1 - MAX_TCP_HEADER); +- segs = max_t(u32, bytes / tp->mss_cache, bbr_min_tso_segs(sk)); + +- return min(segs, 0x7FU); ++ return bbr_tso_segs_generic(sk, tp->mss_cache, GSO_LEGACY_MAX_SIZE); + } + + /* Save "last known good" cwnd so we can restore it after losses or PROBE_RTT */ +@@ -334,7 +533,9 @@ __bpf_kfunc static void bbr_cwnd_event(struct sock *sk, enum tcp_ca_event event) + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + +- if (event == CA_EVENT_TX_START && tp->app_limited) { ++ if (event == CA_EVENT_TX_START) { ++ if (!tp->app_limited) ++ return; + bbr->idle_restart = 1; + bbr->ack_epoch_mstamp = tp->tcp_mstamp; + bbr->ack_epoch_acked = 0; +@@ -345,6 +546,16 @@ __bpf_kfunc static void bbr_cwnd_event(struct sock *sk, enum tcp_ca_event event) + bbr_set_pacing_rate(sk, bbr_bw(sk), BBR_UNIT); + else if (bbr->mode == BBR_PROBE_RTT) + bbr_check_probe_rtt_done(sk); ++ } else if ((event == CA_EVENT_ECN_IS_CE || ++ event == CA_EVENT_ECN_NO_CE) && ++ bbr_can_use_ecn(sk) && ++ bbr_param(sk, precise_ece_ack)) { ++ u32 state = bbr->ce_state; ++ dctcp_ece_ack_update(sk, event, &bbr->prior_rcv_nxt, &state); ++ bbr->ce_state = state; ++ } else if (event == CA_EVENT_TLP_RECOVERY && ++ bbr_param(sk, loss_probe_recovery)) { ++ bbr_run_loss_probe_recovery(sk); + } + } + +@@ -367,10 +578,10 @@ static u32 bbr_bdp(struct sock *sk, u32 bw, int gain) + * default. This should only happen when the connection is not using TCP + * timestamps and has retransmitted all of the SYN/SYNACK/data packets + * ACKed so far. In this case, an RTO can cut cwnd to 1, in which +- * case we need to slow-start up toward something safe: TCP_INIT_CWND. ++ * case we need to slow-start up toward something safe: initial cwnd. + */ + if (unlikely(bbr->min_rtt_us == ~0U)) /* no valid RTT samples yet? 
*/ +- return TCP_INIT_CWND; /* be safe: cap at default initial cwnd*/ ++ return bbr->init_cwnd; /* be safe: cap at initial cwnd */ + + w = (u64)bw * bbr->min_rtt_us; + +@@ -387,23 +598,23 @@ static u32 bbr_bdp(struct sock *sk, u32 bw, int gain) + * - one skb in sending host Qdisc, + * - one skb in sending host TSO/GSO engine + * - one skb being received by receiver host LRO/GRO/delayed-ACK engine +- * Don't worry, at low rates (bbr_min_tso_rate) this won't bloat cwnd because +- * in such cases tso_segs_goal is 1. The minimum cwnd is 4 packets, ++ * Don't worry, at low rates this won't bloat cwnd because ++ * in such cases tso_segs_goal is small. The minimum cwnd is 4 packets, + * which allows 2 outstanding 2-packet sequences, to try to keep pipe + * full even with ACK-every-other-packet delayed ACKs. + */ + static u32 bbr_quantization_budget(struct sock *sk, u32 cwnd) + { + struct bbr *bbr = inet_csk_ca(sk); ++ u32 tso_segs_goal; + +- /* Allow enough full-sized skbs in flight to utilize end systems. */ +- cwnd += 3 * bbr_tso_segs_goal(sk); +- +- /* Reduce delayed ACKs by rounding up cwnd to the next even number. */ +- cwnd = (cwnd + 1) & ~1U; ++ tso_segs_goal = 3 * bbr_tso_segs_goal(sk); + ++ /* Allow enough full-sized skbs in flight to utilize end systems. */ ++ cwnd = max_t(u32, cwnd, tso_segs_goal); ++ cwnd = max_t(u32, cwnd, bbr_param(sk, cwnd_min_target)); + /* Ensure gain cycling gets inflight above BDP even for small BDPs. 
*/ +- if (bbr->mode == BBR_PROBE_BW && bbr->cycle_idx == 0) ++ if (bbr->mode == BBR_PROBE_BW && bbr->cycle_idx == BBR_BW_PROBE_UP) + cwnd += 2; + + return cwnd; +@@ -458,10 +669,10 @@ static u32 bbr_ack_aggregation_cwnd(struct sock *sk) + { + u32 max_aggr_cwnd, aggr_cwnd = 0; + +- if (bbr_extra_acked_gain && bbr_full_bw_reached(sk)) { ++ if (bbr_param(sk, extra_acked_gain)) { + max_aggr_cwnd = ((u64)bbr_bw(sk) * bbr_extra_acked_max_us) + / BW_UNIT; +- aggr_cwnd = (bbr_extra_acked_gain * bbr_extra_acked(sk)) ++ aggr_cwnd = (bbr_param(sk, extra_acked_gain) * bbr_extra_acked(sk)) + >> BBR_SCALE; + aggr_cwnd = min(aggr_cwnd, max_aggr_cwnd); + } +@@ -469,66 +680,27 @@ static u32 bbr_ack_aggregation_cwnd(struct sock *sk) + return aggr_cwnd; + } + +-/* An optimization in BBR to reduce losses: On the first round of recovery, we +- * follow the packet conservation principle: send P packets per P packets acked. +- * After that, we slow-start and send at most 2*P packets per P packets acked. +- * After recovery finishes, or upon undo, we restore the cwnd we had when +- * recovery started (capped by the target cwnd based on estimated BDP). +- * +- * TODO(ycheng/ncardwell): implement a rate-based approach. +- */ +-static bool bbr_set_cwnd_to_recover_or_restore( +- struct sock *sk, const struct rate_sample *rs, u32 acked, u32 *new_cwnd) ++/* Returns the cwnd for PROBE_RTT mode. */ ++static u32 bbr_probe_rtt_cwnd(struct sock *sk) + { +- struct tcp_sock *tp = tcp_sk(sk); +- struct bbr *bbr = inet_csk_ca(sk); +- u8 prev_state = bbr->prev_ca_state, state = inet_csk(sk)->icsk_ca_state; +- u32 cwnd = tcp_snd_cwnd(tp); +- +- /* An ACK for P pkts should release at most 2*P packets. We do this +- * in two steps. First, here we deduct the number of lost packets. +- * Then, in bbr_set_cwnd() we slow start up toward the target cwnd. 
+- */ +- if (rs->losses > 0) +- cwnd = max_t(s32, cwnd - rs->losses, 1); +- +- if (state == TCP_CA_Recovery && prev_state != TCP_CA_Recovery) { +- /* Starting 1st round of Recovery, so do packet conservation. */ +- bbr->packet_conservation = 1; +- bbr->next_rtt_delivered = tp->delivered; /* start round now */ +- /* Cut unused cwnd from app behavior, TSQ, or TSO deferral: */ +- cwnd = tcp_packets_in_flight(tp) + acked; +- } else if (prev_state >= TCP_CA_Recovery && state < TCP_CA_Recovery) { +- /* Exiting loss recovery; restore cwnd saved before recovery. */ +- cwnd = max(cwnd, bbr->prior_cwnd); +- bbr->packet_conservation = 0; +- } +- bbr->prev_ca_state = state; +- +- if (bbr->packet_conservation) { +- *new_cwnd = max(cwnd, tcp_packets_in_flight(tp) + acked); +- return true; /* yes, using packet conservation */ +- } +- *new_cwnd = cwnd; +- return false; ++ return max_t(u32, bbr_param(sk, cwnd_min_target), ++ bbr_bdp(sk, bbr_bw(sk), bbr_param(sk, probe_rtt_cwnd_gain))); + } + + /* Slow-start up toward target cwnd (if bw estimate is growing, or packet loss + * has drawn us down below target), or snap down to target if we're above it. 
+ */ + static void bbr_set_cwnd(struct sock *sk, const struct rate_sample *rs, +- u32 acked, u32 bw, int gain) ++ u32 acked, u32 bw, int gain, u32 cwnd, ++ struct bbr_context *ctx) + { + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); +- u32 cwnd = tcp_snd_cwnd(tp), target_cwnd = 0; ++ u32 target_cwnd = 0; + + if (!acked) + goto done; /* no packet fully ACKed; just apply caps */ + +- if (bbr_set_cwnd_to_recover_or_restore(sk, rs, acked, &cwnd)) +- goto done; +- + target_cwnd = bbr_bdp(sk, bw, gain); + + /* Increment the cwnd to account for excess ACKed data that seems +@@ -537,74 +709,26 @@ static void bbr_set_cwnd(struct sock *sk, const struct rate_sample *rs, + target_cwnd += bbr_ack_aggregation_cwnd(sk); + target_cwnd = bbr_quantization_budget(sk, target_cwnd); + +- /* If we're below target cwnd, slow start cwnd toward target cwnd. */ +- if (bbr_full_bw_reached(sk)) /* only cut cwnd if we filled the pipe */ +- cwnd = min(cwnd + acked, target_cwnd); +- else if (cwnd < target_cwnd || tp->delivered < TCP_INIT_CWND) +- cwnd = cwnd + acked; +- cwnd = max(cwnd, bbr_cwnd_min_target); ++ /* Update cwnd and enable fast path if cwnd reaches target_cwnd. */ ++ bbr->try_fast_path = 0; ++ if (bbr_full_bw_reached(sk)) { /* only cut cwnd if we filled the pipe */ ++ cwnd += acked; ++ if (cwnd >= target_cwnd) { ++ cwnd = target_cwnd; ++ bbr->try_fast_path = 1; ++ } ++ } else if (cwnd < target_cwnd || cwnd < 2 * bbr->init_cwnd) { ++ cwnd += acked; ++ } else { ++ bbr->try_fast_path = 1; ++ } + ++ cwnd = max_t(u32, cwnd, bbr_param(sk, cwnd_min_target)); + done: +- tcp_snd_cwnd_set(tp, min(cwnd, tp->snd_cwnd_clamp)); /* apply global cap */ ++ tcp_snd_cwnd_set(tp, min(cwnd, tp->snd_cwnd_clamp)); /* global cap */ + if (bbr->mode == BBR_PROBE_RTT) /* drain queue, refresh min_rtt */ +- tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp), bbr_cwnd_min_target)); +-} +- +-/* End cycle phase if it's time and/or we hit the phase's in-flight target. 
*/ +-static bool bbr_is_next_cycle_phase(struct sock *sk, +- const struct rate_sample *rs) +-{ +- struct tcp_sock *tp = tcp_sk(sk); +- struct bbr *bbr = inet_csk_ca(sk); +- bool is_full_length = +- tcp_stamp_us_delta(tp->delivered_mstamp, bbr->cycle_mstamp) > +- bbr->min_rtt_us; +- u32 inflight, bw; +- +- /* The pacing_gain of 1.0 paces at the estimated bw to try to fully +- * use the pipe without increasing the queue. +- */ +- if (bbr->pacing_gain == BBR_UNIT) +- return is_full_length; /* just use wall clock time */ +- +- inflight = bbr_packets_in_net_at_edt(sk, rs->prior_in_flight); +- bw = bbr_max_bw(sk); +- +- /* A pacing_gain > 1.0 probes for bw by trying to raise inflight to at +- * least pacing_gain*BDP; this may take more than min_rtt if min_rtt is +- * small (e.g. on a LAN). We do not persist if packets are lost, since +- * a path with small buffers may not hold that much. +- */ +- if (bbr->pacing_gain > BBR_UNIT) +- return is_full_length && +- (rs->losses || /* perhaps pacing_gain*BDP won't fit */ +- inflight >= bbr_inflight(sk, bw, bbr->pacing_gain)); +- +- /* A pacing_gain < 1.0 tries to drain extra queue we added if bw +- * probing didn't find more bw. If inflight falls to match BDP then we +- * estimate queue is drained; persisting would underutilize the pipe. +- */ +- return is_full_length || +- inflight <= bbr_inflight(sk, bw, BBR_UNIT); +-} +- +-static void bbr_advance_cycle_phase(struct sock *sk) +-{ +- struct tcp_sock *tp = tcp_sk(sk); +- struct bbr *bbr = inet_csk_ca(sk); +- +- bbr->cycle_idx = (bbr->cycle_idx + 1) & (CYCLE_LEN - 1); +- bbr->cycle_mstamp = tp->delivered_mstamp; +-} +- +-/* Gain cycling: cycle pacing gain to converge to fair share of available bw. 
*/ +-static void bbr_update_cycle_phase(struct sock *sk, +- const struct rate_sample *rs) +-{ +- struct bbr *bbr = inet_csk_ca(sk); +- +- if (bbr->mode == BBR_PROBE_BW && bbr_is_next_cycle_phase(sk, rs)) +- bbr_advance_cycle_phase(sk); ++ tcp_snd_cwnd_set(tp, min_t(u32, tcp_snd_cwnd(tp), ++ bbr_probe_rtt_cwnd(sk))); + } + + static void bbr_reset_startup_mode(struct sock *sk) +@@ -614,191 +738,49 @@ static void bbr_reset_startup_mode(struct sock *sk) + bbr->mode = BBR_STARTUP; + } + +-static void bbr_reset_probe_bw_mode(struct sock *sk) +-{ +- struct bbr *bbr = inet_csk_ca(sk); +- +- bbr->mode = BBR_PROBE_BW; +- bbr->cycle_idx = CYCLE_LEN - 1 - get_random_u32_below(bbr_cycle_rand); +- bbr_advance_cycle_phase(sk); /* flip to next phase of gain cycle */ +-} +- +-static void bbr_reset_mode(struct sock *sk) +-{ +- if (!bbr_full_bw_reached(sk)) +- bbr_reset_startup_mode(sk); +- else +- bbr_reset_probe_bw_mode(sk); +-} +- +-/* Start a new long-term sampling interval. */ +-static void bbr_reset_lt_bw_sampling_interval(struct sock *sk) +-{ +- struct tcp_sock *tp = tcp_sk(sk); +- struct bbr *bbr = inet_csk_ca(sk); +- +- bbr->lt_last_stamp = div_u64(tp->delivered_mstamp, USEC_PER_MSEC); +- bbr->lt_last_delivered = tp->delivered; +- bbr->lt_last_lost = tp->lost; +- bbr->lt_rtt_cnt = 0; +-} +- +-/* Completely reset long-term bandwidth sampling. */ +-static void bbr_reset_lt_bw_sampling(struct sock *sk) +-{ +- struct bbr *bbr = inet_csk_ca(sk); +- +- bbr->lt_bw = 0; +- bbr->lt_use_bw = 0; +- bbr->lt_is_sampling = false; +- bbr_reset_lt_bw_sampling_interval(sk); +-} +- +-/* Long-term bw sampling interval is done. Estimate whether we're policed. */ +-static void bbr_lt_bw_interval_done(struct sock *sk, u32 bw) +-{ +- struct bbr *bbr = inet_csk_ca(sk); +- u32 diff; +- +- if (bbr->lt_bw) { /* do we have bw from a previous interval? */ +- /* Is new bw close to the lt_bw from the previous interval? 
*/ +- diff = abs(bw - bbr->lt_bw); +- if ((diff * BBR_UNIT <= bbr_lt_bw_ratio * bbr->lt_bw) || +- (bbr_rate_bytes_per_sec(sk, diff, BBR_UNIT) <= +- bbr_lt_bw_diff)) { +- /* All criteria are met; estimate we're policed. */ +- bbr->lt_bw = (bw + bbr->lt_bw) >> 1; /* avg 2 intvls */ +- bbr->lt_use_bw = 1; +- bbr->pacing_gain = BBR_UNIT; /* try to avoid drops */ +- bbr->lt_rtt_cnt = 0; +- return; +- } +- } +- bbr->lt_bw = bw; +- bbr_reset_lt_bw_sampling_interval(sk); +-} +- +-/* Token-bucket traffic policers are common (see "An Internet-Wide Analysis of +- * Traffic Policing", SIGCOMM 2016). BBR detects token-bucket policers and +- * explicitly models their policed rate, to reduce unnecessary losses. We +- * estimate that we're policed if we see 2 consecutive sampling intervals with +- * consistent throughput and high packet loss. If we think we're being policed, +- * set lt_bw to the "long-term" average delivery rate from those 2 intervals. ++/* See if we have reached next round trip. Upon start of the new round, ++ * returns packets delivered since previous round start plus this ACK. + */ +-static void bbr_lt_bw_sampling(struct sock *sk, const struct rate_sample *rs) +-{ +- struct tcp_sock *tp = tcp_sk(sk); +- struct bbr *bbr = inet_csk_ca(sk); +- u32 lost, delivered; +- u64 bw; +- u32 t; +- +- if (bbr->lt_use_bw) { /* already using long-term rate, lt_bw? */ +- if (bbr->mode == BBR_PROBE_BW && bbr->round_start && +- ++bbr->lt_rtt_cnt >= bbr_lt_bw_max_rtts) { +- bbr_reset_lt_bw_sampling(sk); /* stop using lt_bw */ +- bbr_reset_probe_bw_mode(sk); /* restart gain cycling */ +- } +- return; +- } +- +- /* Wait for the first loss before sampling, to let the policer exhaust +- * its tokens and estimate the steady-state rate allowed by the policer. +- * Starting samples earlier includes bursts that over-estimate the bw. 
+- */ +- if (!bbr->lt_is_sampling) { +- if (!rs->losses) +- return; +- bbr_reset_lt_bw_sampling_interval(sk); +- bbr->lt_is_sampling = true; +- } +- +- /* To avoid underestimates, reset sampling if we run out of data. */ +- if (rs->is_app_limited) { +- bbr_reset_lt_bw_sampling(sk); +- return; +- } +- +- if (bbr->round_start) +- bbr->lt_rtt_cnt++; /* count round trips in this interval */ +- if (bbr->lt_rtt_cnt < bbr_lt_intvl_min_rtts) +- return; /* sampling interval needs to be longer */ +- if (bbr->lt_rtt_cnt > 4 * bbr_lt_intvl_min_rtts) { +- bbr_reset_lt_bw_sampling(sk); /* interval is too long */ +- return; +- } +- +- /* End sampling interval when a packet is lost, so we estimate the +- * policer tokens were exhausted. Stopping the sampling before the +- * tokens are exhausted under-estimates the policed rate. +- */ +- if (!rs->losses) +- return; +- +- /* Calculate packets lost and delivered in sampling interval. */ +- lost = tp->lost - bbr->lt_last_lost; +- delivered = tp->delivered - bbr->lt_last_delivered; +- /* Is loss rate (lost/delivered) >= lt_loss_thresh? If not, wait. */ +- if (!delivered || (lost << BBR_SCALE) < bbr_lt_loss_thresh * delivered) +- return; +- +- /* Find average delivery rate in this sampling interval. 
*/ +- t = div_u64(tp->delivered_mstamp, USEC_PER_MSEC) - bbr->lt_last_stamp; +- if ((s32)t < 1) +- return; /* interval is less than one ms, so wait */ +- /* Check if can multiply without overflow */ +- if (t >= ~0U / USEC_PER_MSEC) { +- bbr_reset_lt_bw_sampling(sk); /* interval too long; reset */ +- return; +- } +- t *= USEC_PER_MSEC; +- bw = (u64)delivered * BW_UNIT; +- do_div(bw, t); +- bbr_lt_bw_interval_done(sk, bw); +-} +- +-/* Estimate the bandwidth based on how fast packets are delivered */ +-static void bbr_update_bw(struct sock *sk, const struct rate_sample *rs) ++static u32 bbr_update_round_start(struct sock *sk, ++ const struct rate_sample *rs, struct bbr_context *ctx) + { + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); +- u64 bw; ++ u32 round_delivered = 0; + + bbr->round_start = 0; +- if (rs->delivered < 0 || rs->interval_us <= 0) +- return; /* Not a valid observation */ + + /* See if we've reached the next RTT */ +- if (!before(rs->prior_delivered, bbr->next_rtt_delivered)) { ++ if (rs->interval_us > 0 && ++ !before(rs->prior_delivered, bbr->next_rtt_delivered)) { ++ round_delivered = tp->delivered - bbr->next_rtt_delivered; + bbr->next_rtt_delivered = tp->delivered; +- bbr->rtt_cnt++; + bbr->round_start = 1; +- bbr->packet_conservation = 0; + } ++ return round_delivered; ++} + +- bbr_lt_bw_sampling(sk, rs); ++/* Calculate the bandwidth based on how fast packets are delivered */ ++static void bbr_calculate_bw_sample(struct sock *sk, ++ const struct rate_sample *rs, struct bbr_context *ctx) ++{ ++ u64 bw = 0; + + /* Divide delivered by the interval to find a (lower bound) bottleneck + * bandwidth sample. Delivered is in packets and interval_us in uS and + * ratio will be <<1 for most connections. So delivered is first scaled. ++ * Round up to allow growth at low rates, even with integer division. 
+ */ +- bw = div64_long((u64)rs->delivered * BW_UNIT, rs->interval_us); +- +- /* If this sample is application-limited, it is likely to have a very +- * low delivered count that represents application behavior rather than +- * the available network rate. Such a sample could drag down estimated +- * bw, causing needless slow-down. Thus, to continue to send at the +- * last measured network rate, we filter out app-limited samples unless +- * they describe the path bw at least as well as our bw model. +- * +- * So the goal during app-limited phase is to proceed with the best +- * network rate no matter how long. We automatically leave this +- * phase when app writes faster than the network can deliver :) +- */ +- if (!rs->is_app_limited || bw >= bbr_max_bw(sk)) { +- /* Incorporate new sample into our max bw filter. */ +- minmax_running_max(&bbr->bw, bbr_bw_rtts, bbr->rtt_cnt, bw); ++ if (rs->interval_us > 0) { ++ if (WARN_ONCE(rs->delivered < 0, ++ "negative delivered: %d interval_us: %ld\n", ++ rs->delivered, rs->interval_us)) ++ return; ++ ++ bw = DIV_ROUND_UP_ULL((u64)rs->delivered * BW_UNIT, rs->interval_us); + } ++ ++ ctx->sample_bw = bw; + } + + /* Estimates the windowed max degree of ack aggregation. +@@ -812,7 +794,7 @@ static void bbr_update_bw(struct sock *sk, const struct rate_sample *rs) + * + * Max extra_acked is clamped by cwnd and bw * bbr_extra_acked_max_us (100 ms). + * Max filter is an approximate sliding window of 5-10 (packet timed) round +- * trips. ++ * trips for non-startup phase, and 1-2 round trips for startup. 
+ */ + static void bbr_update_ack_aggregation(struct sock *sk, + const struct rate_sample *rs) +@@ -820,15 +802,19 @@ static void bbr_update_ack_aggregation(struct sock *sk, + u32 epoch_us, expected_acked, extra_acked; + struct bbr *bbr = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); ++ u32 extra_acked_win_rtts_thresh = bbr_param(sk, extra_acked_win_rtts); + +- if (!bbr_extra_acked_gain || rs->acked_sacked <= 0 || ++ if (!bbr_param(sk, extra_acked_gain) || rs->acked_sacked <= 0 || + rs->delivered < 0 || rs->interval_us <= 0) + return; + + if (bbr->round_start) { + bbr->extra_acked_win_rtts = min(0x1F, + bbr->extra_acked_win_rtts + 1); +- if (bbr->extra_acked_win_rtts >= bbr_extra_acked_win_rtts) { ++ if (!bbr_full_bw_reached(sk)) ++ extra_acked_win_rtts_thresh = 1; ++ if (bbr->extra_acked_win_rtts >= ++ extra_acked_win_rtts_thresh) { + bbr->extra_acked_win_rtts = 0; + bbr->extra_acked_win_idx = bbr->extra_acked_win_idx ? + 0 : 1; +@@ -862,49 +848,6 @@ static void bbr_update_ack_aggregation(struct sock *sk, + bbr->extra_acked[bbr->extra_acked_win_idx] = extra_acked; + } + +-/* Estimate when the pipe is full, using the change in delivery rate: BBR +- * estimates that STARTUP filled the pipe if the estimated bw hasn't changed by +- * at least bbr_full_bw_thresh (25%) after bbr_full_bw_cnt (3) non-app-limited +- * rounds. Why 3 rounds: 1: rwin autotuning grows the rwin, 2: we fill the +- * higher rwin, 3: we get higher delivery rate samples. Or transient +- * cross-traffic or radio noise can go away. CUBIC Hystart shares a similar +- * design goal, but uses delay and inter-ACK spacing instead of bandwidth. 
+- */ +-static void bbr_check_full_bw_reached(struct sock *sk, +- const struct rate_sample *rs) +-{ +- struct bbr *bbr = inet_csk_ca(sk); +- u32 bw_thresh; +- +- if (bbr_full_bw_reached(sk) || !bbr->round_start || rs->is_app_limited) +- return; +- +- bw_thresh = (u64)bbr->full_bw * bbr_full_bw_thresh >> BBR_SCALE; +- if (bbr_max_bw(sk) >= bw_thresh) { +- bbr->full_bw = bbr_max_bw(sk); +- bbr->full_bw_cnt = 0; +- return; +- } +- ++bbr->full_bw_cnt; +- bbr->full_bw_reached = bbr->full_bw_cnt >= bbr_full_bw_cnt; +-} +- +-/* If pipe is probably full, drain the queue and then enter steady-state. */ +-static void bbr_check_drain(struct sock *sk, const struct rate_sample *rs) +-{ +- struct bbr *bbr = inet_csk_ca(sk); +- +- if (bbr->mode == BBR_STARTUP && bbr_full_bw_reached(sk)) { +- bbr->mode = BBR_DRAIN; /* drain queue we created */ +- tcp_sk(sk)->snd_ssthresh = +- bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT); +- } /* fall through to check if in-flight is already small: */ +- if (bbr->mode == BBR_DRAIN && +- bbr_packets_in_net_at_edt(sk, tcp_packets_in_flight(tcp_sk(sk))) <= +- bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT)) +- bbr_reset_probe_bw_mode(sk); /* we estimate queue is drained */ +-} +- + static void bbr_check_probe_rtt_done(struct sock *sk) + { + struct tcp_sock *tp = tcp_sk(sk); +@@ -914,9 +857,9 @@ static void bbr_check_probe_rtt_done(struct sock *sk) + after(tcp_jiffies32, bbr->probe_rtt_done_stamp))) + return; + +- bbr->min_rtt_stamp = tcp_jiffies32; /* wait a while until PROBE_RTT */ ++ bbr->probe_rtt_min_stamp = tcp_jiffies32; /* schedule next PROBE_RTT */ + tcp_snd_cwnd_set(tp, max(tcp_snd_cwnd(tp), bbr->prior_cwnd)); +- bbr_reset_mode(sk); ++ bbr_exit_probe_rtt(sk); + } + + /* The goal of PROBE_RTT mode is to have BBR flows cooperatively and +@@ -942,23 +885,35 @@ static void bbr_update_min_rtt(struct sock *sk, const struct rate_sample *rs) + { + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); +- bool filter_expired; ++ bool 
probe_rtt_expired, min_rtt_expired; ++ u32 expire; + +- /* Track min RTT seen in the min_rtt_win_sec filter window: */ +- filter_expired = after(tcp_jiffies32, +- bbr->min_rtt_stamp + bbr_min_rtt_win_sec * HZ); ++ /* Track min RTT in probe_rtt_win_ms to time next PROBE_RTT state. */ ++ expire = bbr->probe_rtt_min_stamp + ++ msecs_to_jiffies(bbr_param(sk, probe_rtt_win_ms)); ++ probe_rtt_expired = after(tcp_jiffies32, expire); + if (rs->rtt_us >= 0 && +- (rs->rtt_us < bbr->min_rtt_us || +- (filter_expired && !rs->is_ack_delayed))) { +- bbr->min_rtt_us = rs->rtt_us; +- bbr->min_rtt_stamp = tcp_jiffies32; ++ (rs->rtt_us < bbr->probe_rtt_min_us || ++ (probe_rtt_expired && !rs->is_ack_delayed))) { ++ bbr->probe_rtt_min_us = rs->rtt_us; ++ bbr->probe_rtt_min_stamp = tcp_jiffies32; ++ } ++ /* Track min RTT seen in the min_rtt_win_sec filter window: */ ++ expire = bbr->min_rtt_stamp + bbr_param(sk, min_rtt_win_sec) * HZ; ++ min_rtt_expired = after(tcp_jiffies32, expire); ++ if (bbr->probe_rtt_min_us <= bbr->min_rtt_us || ++ min_rtt_expired) { ++ bbr->min_rtt_us = bbr->probe_rtt_min_us; ++ bbr->min_rtt_stamp = bbr->probe_rtt_min_stamp; + } + +- if (bbr_probe_rtt_mode_ms > 0 && filter_expired && ++ if (bbr_param(sk, probe_rtt_mode_ms) > 0 && probe_rtt_expired && + !bbr->idle_restart && bbr->mode != BBR_PROBE_RTT) { + bbr->mode = BBR_PROBE_RTT; /* dip, drain queue */ + bbr_save_cwnd(sk); /* note cwnd so we can restore it */ + bbr->probe_rtt_done_stamp = 0; ++ bbr->ack_phase = BBR_ACKS_PROBE_STOPPING; ++ bbr->next_rtt_delivered = tp->delivered; + } + + if (bbr->mode == BBR_PROBE_RTT) { +@@ -967,9 +922,9 @@ static void bbr_update_min_rtt(struct sock *sk, const struct rate_sample *rs) + (tp->delivered + tcp_packets_in_flight(tp)) ? : 1; + /* Maintain min packets in flight for max(200 ms, 1 round). 
*/ + if (!bbr->probe_rtt_done_stamp && +- tcp_packets_in_flight(tp) <= bbr_cwnd_min_target) { ++ tcp_packets_in_flight(tp) <= bbr_probe_rtt_cwnd(sk)) { + bbr->probe_rtt_done_stamp = tcp_jiffies32 + +- msecs_to_jiffies(bbr_probe_rtt_mode_ms); ++ msecs_to_jiffies(bbr_param(sk, probe_rtt_mode_ms)); + bbr->probe_rtt_round_done = 0; + bbr->next_rtt_delivered = tp->delivered; + } else if (bbr->probe_rtt_done_stamp) { +@@ -990,18 +945,20 @@ static void bbr_update_gains(struct sock *sk) + + switch (bbr->mode) { + case BBR_STARTUP: +- bbr->pacing_gain = bbr_high_gain; +- bbr->cwnd_gain = bbr_high_gain; ++ bbr->pacing_gain = bbr_param(sk, startup_pacing_gain); ++ bbr->cwnd_gain = bbr_param(sk, startup_cwnd_gain); + break; + case BBR_DRAIN: +- bbr->pacing_gain = bbr_drain_gain; /* slow, to drain */ +- bbr->cwnd_gain = bbr_high_gain; /* keep cwnd */ ++ bbr->pacing_gain = bbr_param(sk, drain_gain); /* slow, to drain */ ++ bbr->cwnd_gain = bbr_param(sk, startup_cwnd_gain); /* keep cwnd */ + break; + case BBR_PROBE_BW: +- bbr->pacing_gain = (bbr->lt_use_bw ? +- BBR_UNIT : +- bbr_pacing_gain[bbr->cycle_idx]); +- bbr->cwnd_gain = bbr_cwnd_gain; ++ bbr->pacing_gain = bbr_pacing_gain[bbr->cycle_idx]; ++ bbr->cwnd_gain = bbr_param(sk, cwnd_gain); ++ if (bbr_param(sk, bw_probe_cwnd_gain) && ++ bbr->cycle_idx == BBR_BW_PROBE_UP) ++ bbr->cwnd_gain += ++ BBR_UNIT * bbr_param(sk, bw_probe_cwnd_gain) / 4; + break; + case BBR_PROBE_RTT: + bbr->pacing_gain = BBR_UNIT; +@@ -1013,144 +970,1387 @@ static void bbr_update_gains(struct sock *sk) + } + } + +-static void bbr_update_model(struct sock *sk, const struct rate_sample *rs) ++__bpf_kfunc static u32 bbr_sndbuf_expand(struct sock *sk) + { +- bbr_update_bw(sk, rs); +- bbr_update_ack_aggregation(sk, rs); +- bbr_update_cycle_phase(sk, rs); +- bbr_check_full_bw_reached(sk, rs); +- bbr_check_drain(sk, rs); +- bbr_update_min_rtt(sk, rs); +- bbr_update_gains(sk); ++ /* Provision 3 * cwnd since BBR may slow-start even during recovery. 
*/ ++ return 3; + } + +-__bpf_kfunc static void bbr_main(struct sock *sk, u32 ack, int flag, const struct rate_sample *rs) ++/* Incorporate a new bw sample into the current window of our max filter. */ ++static void bbr_take_max_bw_sample(struct sock *sk, u32 bw) + { + struct bbr *bbr = inet_csk_ca(sk); +- u32 bw; +- +- bbr_update_model(sk, rs); + +- bw = bbr_bw(sk); +- bbr_set_pacing_rate(sk, bw, bbr->pacing_gain); +- bbr_set_cwnd(sk, rs, rs->acked_sacked, bw, bbr->cwnd_gain); ++ bbr->bw_hi[1] = max(bw, bbr->bw_hi[1]); + } + +-__bpf_kfunc static void bbr_init(struct sock *sk) ++/* Keep max of last 1-2 cycles. Each PROBE_BW cycle, flip filter window. */ ++static void bbr_advance_max_bw_filter(struct sock *sk) + { +- struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + +- bbr->prior_cwnd = 0; +- tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; +- bbr->rtt_cnt = 0; +- bbr->next_rtt_delivered = tp->delivered; +- bbr->prev_ca_state = TCP_CA_Open; +- bbr->packet_conservation = 0; +- +- bbr->probe_rtt_done_stamp = 0; +- bbr->probe_rtt_round_done = 0; +- bbr->min_rtt_us = tcp_min_rtt(tp); +- bbr->min_rtt_stamp = tcp_jiffies32; +- +- minmax_reset(&bbr->bw, bbr->rtt_cnt, 0); /* init max bw to 0 */ ++ if (!bbr->bw_hi[1]) ++ return; /* no samples in this window; remember old window */ ++ bbr->bw_hi[0] = bbr->bw_hi[1]; ++ bbr->bw_hi[1] = 0; ++} + +- bbr->has_seen_rtt = 0; +- bbr_init_pacing_rate_from_rtt(sk); ++/* Reset the estimator for reaching full bandwidth based on bw plateau. 
*/ ++static void bbr_reset_full_bw(struct sock *sk) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); + +- bbr->round_start = 0; +- bbr->idle_restart = 0; +- bbr->full_bw_reached = 0; + bbr->full_bw = 0; + bbr->full_bw_cnt = 0; +- bbr->cycle_mstamp = 0; +- bbr->cycle_idx = 0; +- bbr_reset_lt_bw_sampling(sk); +- bbr_reset_startup_mode(sk); ++ bbr->full_bw_now = 0; ++} + +- bbr->ack_epoch_mstamp = tp->tcp_mstamp; +- bbr->ack_epoch_acked = 0; +- bbr->extra_acked_win_rtts = 0; +- bbr->extra_acked_win_idx = 0; +- bbr->extra_acked[0] = 0; +- bbr->extra_acked[1] = 0; ++/* How much do we want in flight? Our BDP, unless congestion cut cwnd. */ ++static u32 bbr_target_inflight(struct sock *sk) ++{ ++ u32 bdp = bbr_inflight(sk, bbr_bw(sk), BBR_UNIT); + +- cmpxchg(&sk->sk_pacing_status, SK_PACING_NONE, SK_PACING_NEEDED); ++ return min(bdp, tcp_sk(sk)->snd_cwnd); + } + +-__bpf_kfunc static u32 bbr_sndbuf_expand(struct sock *sk) ++static bool bbr_is_probing_bandwidth(struct sock *sk) + { +- /* Provision 3 * cwnd since BBR may slow-start even during recovery. */ +- return 3; ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ return (bbr->mode == BBR_STARTUP) || ++ (bbr->mode == BBR_PROBE_BW && ++ (bbr->cycle_idx == BBR_BW_PROBE_REFILL || ++ bbr->cycle_idx == BBR_BW_PROBE_UP)); ++} ++ ++/* Has the given amount of time elapsed since we marked the phase start? 
*/ ++static bool bbr_has_elapsed_in_phase(const struct sock *sk, u32 interval_us) ++{ ++ const struct tcp_sock *tp = tcp_sk(sk); ++ const struct bbr *bbr = inet_csk_ca(sk); ++ ++ return tcp_stamp_us_delta(tp->tcp_mstamp, ++ bbr->cycle_mstamp + interval_us) > 0; ++} ++ ++static void bbr_handle_queue_too_high_in_startup(struct sock *sk) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ u32 bdp; /* estimated BDP in packets, with quantization budget */ ++ ++ bbr->full_bw_reached = 1; ++ ++ bdp = bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT); ++ bbr->inflight_hi = max(bdp, bbr->inflight_latest); ++} ++ ++/* Exit STARTUP upon N consecutive rounds with ECN mark rate > ecn_thresh. */ ++static void bbr_check_ecn_too_high_in_startup(struct sock *sk, u32 ce_ratio) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ if (bbr_full_bw_reached(sk) || !bbr->ecn_eligible || ++ !bbr_param(sk, full_ecn_cnt) || !bbr_param(sk, ecn_thresh)) ++ return; ++ ++ if (ce_ratio >= bbr_param(sk, ecn_thresh)) ++ bbr->startup_ecn_rounds++; ++ else ++ bbr->startup_ecn_rounds = 0; ++ ++ if (bbr->startup_ecn_rounds >= bbr_param(sk, full_ecn_cnt)) { ++ bbr_handle_queue_too_high_in_startup(sk); ++ return; ++ } ++} ++ ++/* Updates ecn_alpha and returns ce_ratio. -1 if not available. */ ++static int bbr_update_ecn_alpha(struct sock *sk) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct net *net = sock_net(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ s32 delivered, delivered_ce; ++ u64 alpha, ce_ratio; ++ u32 gain; ++ bool want_ecn_alpha; ++ ++ /* See if we should use ECN sender logic for this connection. */ ++ if (!bbr->ecn_eligible && bbr_can_use_ecn(sk) && ++ bbr_param(sk, ecn_factor) && ++ (bbr->min_rtt_us <= bbr_ecn_max_rtt_us || ++ !bbr_ecn_max_rtt_us)) ++ bbr->ecn_eligible = 1; ++ ++ /* Skip updating alpha only if not ECN-eligible and PLB is disabled. 
*/ ++ want_ecn_alpha = (bbr->ecn_eligible || ++ (bbr_can_use_ecn(sk) && ++ READ_ONCE(net->ipv4.sysctl_tcp_plb_enabled))); ++ if (!want_ecn_alpha) ++ return -1; ++ ++ delivered = tp->delivered - bbr->alpha_last_delivered; ++ delivered_ce = tp->delivered_ce - bbr->alpha_last_delivered_ce; ++ ++ if (delivered == 0 || /* avoid divide by zero */ ++ WARN_ON_ONCE(delivered < 0 || delivered_ce < 0)) /* backwards? */ ++ return -1; ++ ++ BUILD_BUG_ON(BBR_SCALE != TCP_PLB_SCALE); ++ ce_ratio = (u64)delivered_ce << BBR_SCALE; ++ do_div(ce_ratio, delivered); ++ ++ gain = bbr_param(sk, ecn_alpha_gain); ++ alpha = ((BBR_UNIT - gain) * bbr->ecn_alpha) >> BBR_SCALE; ++ alpha += (gain * ce_ratio) >> BBR_SCALE; ++ bbr->ecn_alpha = min_t(u32, alpha, BBR_UNIT); ++ ++ bbr->alpha_last_delivered = tp->delivered; ++ bbr->alpha_last_delivered_ce = tp->delivered_ce; ++ ++ bbr_check_ecn_too_high_in_startup(sk, ce_ratio); ++ return (int)ce_ratio; + } + +-/* In theory BBR does not need to undo the cwnd since it does not +- * always reduce cwnd on losses (see bbr_main()). Keep it for now. ++/* Protective Load Balancing (PLB). PLB rehashes outgoing data (to a new IPv6 ++ * flow label) if it encounters sustained congestion in the form of ECN marks. + */ +-__bpf_kfunc static u32 bbr_undo_cwnd(struct sock *sk) ++static void bbr_plb(struct sock *sk, const struct rate_sample *rs, int ce_ratio) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ if (bbr->round_start && ce_ratio >= 0) ++ tcp_plb_update_state(sk, &bbr->plb, ce_ratio); ++ ++ tcp_plb_check_rehash(sk, &bbr->plb); ++} ++ ++/* Each round trip of BBR_BW_PROBE_UP, double volume of probing data. */ ++static void bbr_raise_inflight_hi_slope(struct sock *sk) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ u32 growth_this_round, cnt; ++ ++ /* Calculate "slope": packets S/Acked per inflight_hi increment. 
*/ ++ growth_this_round = 1 << bbr->bw_probe_up_rounds; ++ bbr->bw_probe_up_rounds = min(bbr->bw_probe_up_rounds + 1, 30); ++ cnt = tcp_snd_cwnd(tp) / growth_this_round; ++ cnt = max(cnt, 1U); ++ bbr->bw_probe_up_cnt = cnt; ++} ++ ++/* In BBR_BW_PROBE_UP, not seeing high loss/ECN/queue, so raise inflight_hi. */ ++static void bbr_probe_inflight_hi_upward(struct sock *sk, ++ const struct rate_sample *rs) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ u32 delta; ++ ++ if (!tp->is_cwnd_limited || tcp_snd_cwnd(tp) < bbr->inflight_hi) ++ return; /* not fully using inflight_hi, so don't grow it */ ++ ++ /* For each bw_probe_up_cnt packets ACKed, increase inflight_hi by 1. */ ++ bbr->bw_probe_up_acks += rs->acked_sacked; ++ if (bbr->bw_probe_up_acks >= bbr->bw_probe_up_cnt) { ++ delta = bbr->bw_probe_up_acks / bbr->bw_probe_up_cnt; ++ bbr->bw_probe_up_acks -= delta * bbr->bw_probe_up_cnt; ++ bbr->inflight_hi += delta; ++ bbr->try_fast_path = 0; /* Need to update cwnd */ ++ } ++ ++ if (bbr->round_start) ++ bbr_raise_inflight_hi_slope(sk); ++} ++ ++/* Does loss/ECN rate for this sample say inflight is "too high"? ++ * This is used by both the bbr_check_loss_too_high_in_startup() function, ++ * which can be used in either v1 or v2, and the PROBE_UP phase of v2, which ++ * uses it to notice when loss/ECN rates suggest inflight is too high. 
++ */ ++static bool bbr_is_inflight_too_high(const struct sock *sk, ++ const struct rate_sample *rs) ++{ ++ const struct bbr *bbr = inet_csk_ca(sk); ++ u32 loss_thresh, ecn_thresh; ++ ++ if (rs->lost > 0 && rs->tx_in_flight) { ++ loss_thresh = (u64)rs->tx_in_flight * bbr_param(sk, loss_thresh) >> ++ BBR_SCALE; ++ if (rs->lost > loss_thresh) { ++ return true; ++ } ++ } ++ ++ if (rs->delivered_ce > 0 && rs->delivered > 0 && ++ bbr->ecn_eligible && bbr_param(sk, ecn_thresh)) { ++ ecn_thresh = (u64)rs->delivered * bbr_param(sk, ecn_thresh) >> ++ BBR_SCALE; ++ if (rs->delivered_ce > ecn_thresh) { ++ return true; ++ } ++ } ++ ++ return false; ++} ++ ++/* Calculate the tx_in_flight level that corresponded to excessive loss. ++ * We find "lost_prefix" segs of the skb where loss rate went too high, ++ * by solving for "lost_prefix" in the following equation: ++ * lost / inflight >= loss_thresh ++ * (lost_prev + lost_prefix) / (inflight_prev + lost_prefix) >= loss_thresh ++ * Then we take that equation, convert it to fixed point, and ++ * round up to the nearest packet. ++ */ ++static u32 bbr_inflight_hi_from_lost_skb(const struct sock *sk, ++ const struct rate_sample *rs, ++ const struct sk_buff *skb) ++{ ++ const struct tcp_sock *tp = tcp_sk(sk); ++ u32 loss_thresh = bbr_param(sk, loss_thresh); ++ u32 pcount, divisor, inflight_hi; ++ s32 inflight_prev, lost_prev; ++ u64 loss_budget, lost_prefix; ++ ++ pcount = tcp_skb_pcount(skb); ++ ++ /* How much data was in flight before this skb? */ ++ inflight_prev = rs->tx_in_flight - pcount; ++ if (inflight_prev < 0) { ++ WARN_ONCE(tcp_skb_tx_in_flight_is_suspicious( ++ pcount, ++ TCP_SKB_CB(skb)->sacked, ++ rs->tx_in_flight), ++ "tx_in_flight: %u pcount: %u reneg: %u", ++ rs->tx_in_flight, pcount, tcp_sk(sk)->is_sack_reneg); ++ return ~0U; ++ } ++ ++ /* How much inflight data was marked lost before this skb? 
*/ ++ lost_prev = rs->lost - pcount; ++ if (WARN_ONCE(lost_prev < 0, ++ "cwnd: %u ca: %d out: %u lost: %u pif: %u " ++ "tx_in_flight: %u tx.lost: %u tp->lost: %u rs->lost: %d " ++ "lost_prev: %d pcount: %d seq: %u end_seq: %u reneg: %u", ++ tcp_snd_cwnd(tp), inet_csk(sk)->icsk_ca_state, ++ tp->packets_out, tp->lost_out, tcp_packets_in_flight(tp), ++ rs->tx_in_flight, TCP_SKB_CB(skb)->tx.lost, tp->lost, ++ rs->lost, lost_prev, pcount, ++ TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq, ++ tp->is_sack_reneg)) ++ return ~0U; ++ ++ /* At what prefix of this lost skb did losss rate exceed loss_thresh? */ ++ loss_budget = (u64)inflight_prev * loss_thresh + BBR_UNIT - 1; ++ loss_budget >>= BBR_SCALE; ++ if (lost_prev >= loss_budget) { ++ lost_prefix = 0; /* previous losses crossed loss_thresh */ ++ } else { ++ lost_prefix = loss_budget - lost_prev; ++ lost_prefix <<= BBR_SCALE; ++ divisor = BBR_UNIT - loss_thresh; ++ if (WARN_ON_ONCE(!divisor)) /* loss_thresh is 8 bits */ ++ return ~0U; ++ do_div(lost_prefix, divisor); ++ } ++ ++ inflight_hi = inflight_prev + lost_prefix; ++ return inflight_hi; ++} ++ ++/* If loss/ECN rates during probing indicated we may have overfilled a ++ * buffer, return an operating point that tries to leave unutilized headroom in ++ * the path for other flows, for fairness convergence and lower RTTs and loss. ++ */ ++static u32 bbr_inflight_with_headroom(const struct sock *sk) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ u32 headroom, headroom_fraction; ++ ++ if (bbr->inflight_hi == ~0U) ++ return ~0U; ++ ++ headroom_fraction = bbr_param(sk, inflight_headroom); ++ headroom = ((u64)bbr->inflight_hi * headroom_fraction) >> BBR_SCALE; ++ headroom = max(headroom, 1U); ++ return max_t(s32, bbr->inflight_hi - headroom, ++ bbr_param(sk, cwnd_min_target)); ++} ++ ++/* Bound cwnd to a sensible level, based on our current probing state ++ * machine phase and model of a good inflight level (inflight_lo, inflight_hi). 
++ */ ++static void bbr_bound_cwnd_for_inflight_model(struct sock *sk) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ u32 cap; ++ ++ /* tcp_rcv_synsent_state_process() currently calls tcp_ack() ++ * and thus cong_control() without first initializing us(!). ++ */ ++ if (!bbr->initialized) ++ return; ++ ++ cap = ~0U; ++ if (bbr->mode == BBR_PROBE_BW && ++ bbr->cycle_idx != BBR_BW_PROBE_CRUISE) { ++ /* Probe to see if more packets fit in the path. */ ++ cap = bbr->inflight_hi; ++ } else { ++ if (bbr->mode == BBR_PROBE_RTT || ++ (bbr->mode == BBR_PROBE_BW && ++ bbr->cycle_idx == BBR_BW_PROBE_CRUISE)) ++ cap = bbr_inflight_with_headroom(sk); ++ } ++ /* Adapt to any loss/ECN since our last bw probe. */ ++ cap = min(cap, bbr->inflight_lo); ++ ++ cap = max_t(u32, cap, bbr_param(sk, cwnd_min_target)); ++ tcp_snd_cwnd_set(tp, min(cap, tcp_snd_cwnd(tp))); ++} ++ ++/* How should we multiplicatively cut bw or inflight limits based on ECN? */ ++static u32 bbr_ecn_cut(struct sock *sk) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ return BBR_UNIT - ++ ((bbr->ecn_alpha * bbr_param(sk, ecn_factor)) >> BBR_SCALE); ++} ++ ++/* Init lower bounds if have not inited yet. */ ++static void bbr_init_lower_bounds(struct sock *sk, bool init_bw) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ if (init_bw && bbr->bw_lo == ~0U) ++ bbr->bw_lo = bbr_max_bw(sk); ++ if (bbr->inflight_lo == ~0U) ++ bbr->inflight_lo = tcp_snd_cwnd(tp); ++} ++ ++/* Reduce bw and inflight to (1 - beta). */ ++static void bbr_loss_lower_bounds(struct sock *sk, u32 *bw, u32 *inflight) ++{ ++ struct bbr* bbr = inet_csk_ca(sk); ++ u32 loss_cut = BBR_UNIT - bbr_param(sk, beta); ++ ++ *bw = max_t(u32, bbr->bw_latest, ++ (u64)bbr->bw_lo * loss_cut >> BBR_SCALE); ++ *inflight = max_t(u32, bbr->inflight_latest, ++ (u64)bbr->inflight_lo * loss_cut >> BBR_SCALE); ++} ++ ++/* Reduce inflight to (1 - alpha*ecn_factor). 
*/ ++static void bbr_ecn_lower_bounds(struct sock *sk, u32 *inflight) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ u32 ecn_cut = bbr_ecn_cut(sk); ++ ++ *inflight = (u64)bbr->inflight_lo * ecn_cut >> BBR_SCALE; ++} ++ ++/* Estimate a short-term lower bound on the capacity available now, based ++ * on measurements of the current delivery process and recent history. When we ++ * are seeing loss/ECN at times when we are not probing bw, then conservatively ++ * move toward flow balance by multiplicatively cutting our short-term ++ * estimated safe rate and volume of data (bw_lo and inflight_lo). We use a ++ * multiplicative decrease in order to converge to a lower capacity in time ++ * logarithmic in the magnitude of the decrease. ++ * ++ * However, we do not cut our short-term estimates lower than the current rate ++ * and volume of delivered data from this round trip, since from the current ++ * delivery process we can estimate the measured capacity available now. ++ * ++ * Anything faster than that approach would knowingly risk high loss, which can ++ * cause low bw for Reno/CUBIC and high loss recovery latency for ++ * request/response flows using any congestion control. ++ */ ++static void bbr_adapt_lower_bounds(struct sock *sk, ++ const struct rate_sample *rs) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ u32 ecn_inflight_lo = ~0U; ++ ++ /* We only use lower-bound estimates when not probing bw. ++ * When probing we need to push inflight higher to probe bw. ++ */ ++ if (bbr_is_probing_bandwidth(sk)) ++ return; ++ ++ /* ECN response. */ ++ if (bbr->ecn_in_round && bbr_param(sk, ecn_factor)) { ++ bbr_init_lower_bounds(sk, false); ++ bbr_ecn_lower_bounds(sk, &ecn_inflight_lo); ++ } ++ ++ /* Loss response. */ ++ if (bbr->loss_in_round) { ++ bbr_init_lower_bounds(sk, true); ++ bbr_loss_lower_bounds(sk, &bbr->bw_lo, &bbr->inflight_lo); ++ } ++ ++ /* Adjust to the lower of the levels implied by loss/ECN. 
*/ ++ bbr->inflight_lo = min(bbr->inflight_lo, ecn_inflight_lo); ++ bbr->bw_lo = max(1U, bbr->bw_lo); ++} ++ ++/* Reset any short-term lower-bound adaptation to congestion, so that we can ++ * push our inflight up. ++ */ ++static void bbr_reset_lower_bounds(struct sock *sk) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ bbr->bw_lo = ~0U; ++ bbr->inflight_lo = ~0U; ++} ++ ++/* After bw probing (STARTUP/PROBE_UP), reset signals before entering a state ++ * machine phase where we adapt our lower bound based on congestion signals. ++ */ ++static void bbr_reset_congestion_signals(struct sock *sk) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ bbr->loss_in_round = 0; ++ bbr->ecn_in_round = 0; ++ bbr->loss_in_cycle = 0; ++ bbr->ecn_in_cycle = 0; ++ bbr->bw_latest = 0; ++ bbr->inflight_latest = 0; ++} ++ ++static void bbr_exit_loss_recovery(struct sock *sk) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ tcp_snd_cwnd_set(tp, max(tcp_snd_cwnd(tp), bbr->prior_cwnd)); ++ bbr->try_fast_path = 0; /* bound cwnd using latest model */ ++} ++ ++/* Update rate and volume of delivered data from latest round trip. */ ++static void bbr_update_latest_delivery_signals( ++ struct sock *sk, const struct rate_sample *rs, struct bbr_context *ctx) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ bbr->loss_round_start = 0; ++ if (rs->interval_us <= 0 || !rs->acked_sacked) ++ return; /* Not a valid observation */ ++ ++ bbr->bw_latest = max_t(u32, bbr->bw_latest, ctx->sample_bw); ++ bbr->inflight_latest = max_t(u32, bbr->inflight_latest, rs->delivered); ++ ++ if (!before(rs->prior_delivered, bbr->loss_round_delivered)) { ++ bbr->loss_round_delivered = tp->delivered; ++ bbr->loss_round_start = 1; /* mark start of new round trip */ ++ } ++} ++ ++/* Once per round, reset filter for latest rate and volume of delivered data. 
*/ ++static void bbr_advance_latest_delivery_signals( ++ struct sock *sk, const struct rate_sample *rs, struct bbr_context *ctx) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ /* If ACK matches a TLP retransmit, persist the filter. If we detect ++ * that a TLP retransmit plugged a tail loss, we'll want to remember ++ * how much data the path delivered before the tail loss. ++ */ ++ if (bbr->loss_round_start && !rs->is_acking_tlp_retrans_seq) { ++ bbr->bw_latest = ctx->sample_bw; ++ bbr->inflight_latest = rs->delivered; ++ } ++} ++ ++/* Update (most of) our congestion signals: track the recent rate and volume of ++ * delivered data, presence of loss, and EWMA degree of ECN marking. ++ */ ++static void bbr_update_congestion_signals( ++ struct sock *sk, const struct rate_sample *rs, struct bbr_context *ctx) + { + struct bbr *bbr = inet_csk_ca(sk); ++ u64 bw; ++ ++ if (rs->interval_us <= 0 || !rs->acked_sacked) ++ return; /* Not a valid observation */ ++ bw = ctx->sample_bw; + +- bbr->full_bw = 0; /* spurious slow-down; reset full pipe detection */ ++ if (!rs->is_app_limited || bw >= bbr_max_bw(sk)) ++ bbr_take_max_bw_sample(sk, bw); ++ ++ bbr->loss_in_round |= (rs->losses > 0); ++ ++ if (!bbr->loss_round_start) ++ return; /* skip the per-round-trip updates */ ++ /* Now do per-round-trip updates. */ ++ bbr_adapt_lower_bounds(sk, rs); ++ ++ bbr->loss_in_round = 0; ++ bbr->ecn_in_round = 0; ++} ++ ++/* Bandwidth probing can cause loss. To help coexistence with loss-based ++ * congestion control we spread out our probing in a Reno-conscious way. Due to ++ * the shape of the Reno sawtooth, the time required between loss epochs for an ++ * idealized Reno flow is a number of round trips that is the BDP of that ++ * flow. We count packet-timed round trips directly, since measured RTT can ++ * vary widely, and Reno is driven by packet-timed round trips. 
++ */ ++static bool bbr_is_reno_coexistence_probe_time(struct sock *sk) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ u32 rounds; ++ ++ /* Random loss can shave some small percentage off of our inflight ++ * in each round. To survive this, flows need robust periodic probes. ++ */ ++ rounds = min_t(u32, bbr_param(sk, bw_probe_max_rounds), bbr_target_inflight(sk)); ++ return bbr->rounds_since_probe >= rounds; ++} ++ ++/* How long do we want to wait before probing for bandwidth (and risking ++ * loss)? We randomize the wait, for better mixing and fairness convergence. ++ * ++ * We bound the Reno-coexistence inter-bw-probe time to be 62-63 round trips. ++ * This is calculated to allow fairness with a 25Mbps, 30ms Reno flow, ++ * (eg 4K video to a broadband user): ++ * BDP = 25Mbps * .030sec /(1514bytes) = 61.9 packets ++ * ++ * We bound the BBR-native inter-bw-probe wall clock time to be: ++ * (a) higher than 2 sec: to try to avoid causing loss for a long enough time ++ * to allow Reno at 30ms to get 4K video bw, the inter-bw-probe time must ++ * be at least: 25Mbps * .030sec / (1514bytes) * 0.030sec = 1.9secs ++ * (b) lower than 3 sec: to ensure flows can start probing in a reasonable ++ * amount of time to discover unutilized bw on human-scale interactive ++ * time-scales (e.g. perhaps traffic from a web page download that we ++ * were competing with is now complete). 
++ */ ++static void bbr_pick_probe_wait(struct sock *sk) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ /* Decide the random round-trip bound for wait until probe: */ ++ bbr->rounds_since_probe = ++ get_random_u32_below(bbr_param(sk, bw_probe_rand_rounds)); ++ /* Decide the random wall clock bound for wait until probe: */ ++ bbr->probe_wait_us = bbr_param(sk, bw_probe_base_us) + ++ get_random_u32_below(bbr_param(sk, bw_probe_rand_us)); ++} ++ ++static void bbr_set_cycle_idx(struct sock *sk, int cycle_idx) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ bbr->cycle_idx = cycle_idx; ++ /* New phase, so need to update cwnd and pacing rate. */ ++ bbr->try_fast_path = 0; ++} ++ ++/* Send at estimated bw to fill the pipe, but not queue. We need this phase ++ * before PROBE_UP, because as soon as we send faster than the available bw ++ * we will start building a queue, and if the buffer is shallow we can cause ++ * loss. If we do not fill the pipe before we cause this loss, our bw_hi and ++ * inflight_hi estimates will underestimate. ++ */ ++static void bbr_start_bw_probe_refill(struct sock *sk, u32 bw_probe_up_rounds) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ bbr_reset_lower_bounds(sk); ++ bbr->bw_probe_up_rounds = bw_probe_up_rounds; ++ bbr->bw_probe_up_acks = 0; ++ bbr->stopped_risky_probe = 0; ++ bbr->ack_phase = BBR_ACKS_REFILLING; ++ bbr->next_rtt_delivered = tp->delivered; ++ bbr_set_cycle_idx(sk, BBR_BW_PROBE_REFILL); ++} ++ ++/* Now probe max deliverable data rate and volume. 
*/ ++static void bbr_start_bw_probe_up(struct sock *sk, struct bbr_context *ctx) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ bbr->ack_phase = BBR_ACKS_PROBE_STARTING; ++ bbr->next_rtt_delivered = tp->delivered; ++ bbr->cycle_mstamp = tp->tcp_mstamp; ++ bbr_reset_full_bw(sk); ++ bbr->full_bw = ctx->sample_bw; ++ bbr_set_cycle_idx(sk, BBR_BW_PROBE_UP); ++ bbr_raise_inflight_hi_slope(sk); ++} ++ ++/* Start a new PROBE_BW probing cycle of some wall clock length. Pick a wall ++ * clock time at which to probe beyond an inflight that we think to be ++ * safe. This will knowingly risk packet loss, so we want to do this rarely, to ++ * keep packet loss rates low. Also start a round-trip counter, to probe faster ++ * if we estimate a Reno flow at our BDP would probe faster. ++ */ ++static void bbr_start_bw_probe_down(struct sock *sk) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ bbr_reset_congestion_signals(sk); ++ bbr->bw_probe_up_cnt = ~0U; /* not growing inflight_hi any more */ ++ bbr_pick_probe_wait(sk); ++ bbr->cycle_mstamp = tp->tcp_mstamp; /* start wall clock */ ++ bbr->ack_phase = BBR_ACKS_PROBE_STOPPING; ++ bbr->next_rtt_delivered = tp->delivered; ++ bbr_set_cycle_idx(sk, BBR_BW_PROBE_DOWN); ++} ++ ++/* Cruise: maintain what we estimate to be a neutral, conservative ++ * operating point, without attempting to probe up for bandwidth or down for ++ * RTT, and only reducing inflight in response to loss/ECN signals. ++ */ ++static void bbr_start_bw_probe_cruise(struct sock *sk) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ if (bbr->inflight_lo != ~0U) ++ bbr->inflight_lo = min(bbr->inflight_lo, bbr->inflight_hi); ++ ++ bbr_set_cycle_idx(sk, BBR_BW_PROBE_CRUISE); ++} ++ ++/* Loss and/or ECN rate is too high while probing. ++ * Adapt (once per bw probe) by cutting inflight_hi and then restarting cycle. 
++ */ ++static void bbr_handle_inflight_too_high(struct sock *sk, ++ const struct rate_sample *rs) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ const u32 beta = bbr_param(sk, beta); ++ ++ bbr->prev_probe_too_high = 1; ++ bbr->bw_probe_samples = 0; /* only react once per probe */ ++ /* If we are app-limited then we are not robustly ++ * probing the max volume of inflight data we think ++ * might be safe (analogous to how app-limited bw ++ * samples are not known to be robustly probing bw). ++ */ ++ if (!rs->is_app_limited) { ++ bbr->inflight_hi = max_t(u32, rs->tx_in_flight, ++ (u64)bbr_target_inflight(sk) * ++ (BBR_UNIT - beta) >> BBR_SCALE); ++ } ++ if (bbr->mode == BBR_PROBE_BW && bbr->cycle_idx == BBR_BW_PROBE_UP) ++ bbr_start_bw_probe_down(sk); ++} ++ ++/* If we're seeing bw and loss samples reflecting our bw probing, adapt ++ * using the signals we see. If loss or ECN mark rate gets too high, then adapt ++ * inflight_hi downward. If we're able to push inflight higher without such ++ * signals, push higher: adapt inflight_hi upward. ++ */ ++static bool bbr_adapt_upper_bounds(struct sock *sk, ++ const struct rate_sample *rs, ++ struct bbr_context *ctx) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ /* Track when we'll see bw/loss samples resulting from our bw probes. */ ++ if (bbr->ack_phase == BBR_ACKS_PROBE_STARTING && bbr->round_start) ++ bbr->ack_phase = BBR_ACKS_PROBE_FEEDBACK; ++ if (bbr->ack_phase == BBR_ACKS_PROBE_STOPPING && bbr->round_start) { ++ /* End of samples from bw probing phase. */ ++ bbr->bw_probe_samples = 0; ++ bbr->ack_phase = BBR_ACKS_INIT; ++ /* At this point in the cycle, our current bw sample is also ++ * our best recent chance at finding the highest available bw ++ * for this flow. So now is the best time to forget the bw ++ * samples from the previous cycle, by advancing the window. 
++ */ ++ if (bbr->mode == BBR_PROBE_BW && !rs->is_app_limited) ++ bbr_advance_max_bw_filter(sk); ++ /* If we had an inflight_hi, then probed and pushed inflight all ++ * the way up to hit that inflight_hi without seeing any ++ * high loss/ECN in all the resulting ACKs from that probing, ++ * then probe up again, this time letting inflight persist at ++ * inflight_hi for a round trip, then accelerating beyond. ++ */ ++ if (bbr->mode == BBR_PROBE_BW && ++ bbr->stopped_risky_probe && !bbr->prev_probe_too_high) { ++ bbr_start_bw_probe_refill(sk, 0); ++ return true; /* yes, decided state transition */ ++ } ++ } ++ if (bbr_is_inflight_too_high(sk, rs)) { ++ if (bbr->bw_probe_samples) /* sample is from bw probing? */ ++ bbr_handle_inflight_too_high(sk, rs); ++ } else { ++ /* Loss/ECN rate is declared safe. Adjust upper bound upward. */ ++ ++ if (bbr->inflight_hi == ~0U) ++ return false; /* no excess queue signals yet */ ++ ++ /* To be resilient to random loss, we must raise bw/inflight_hi ++ * if we observe in any phase that a higher level is safe. ++ */ ++ if (rs->tx_in_flight > bbr->inflight_hi) { ++ bbr->inflight_hi = rs->tx_in_flight; ++ } ++ ++ if (bbr->mode == BBR_PROBE_BW && ++ bbr->cycle_idx == BBR_BW_PROBE_UP) ++ bbr_probe_inflight_hi_upward(sk, rs); ++ } ++ ++ return false; ++} ++ ++/* Check if it's time to probe for bandwidth now, and if so, kick it off. */ ++static bool bbr_check_time_to_probe_bw(struct sock *sk, ++ const struct rate_sample *rs) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ u32 n; ++ ++ /* If we seem to be at an operating point where we are not seeing loss ++ * but we are seeing ECN marks, then when the ECN marks cease we reprobe ++ * quickly (in case cross-traffic has ceased and freed up bw). 
++ */ ++ if (bbr_param(sk, ecn_reprobe_gain) && bbr->ecn_eligible && ++ bbr->ecn_in_cycle && !bbr->loss_in_cycle && ++ inet_csk(sk)->icsk_ca_state == TCP_CA_Open) { ++ /* Calculate n so that when bbr_raise_inflight_hi_slope() ++ * computes growth_this_round as 2^n it will be roughly the ++ * desired volume of data (inflight_hi*ecn_reprobe_gain). ++ */ ++ n = ilog2((((u64)bbr->inflight_hi * ++ bbr_param(sk, ecn_reprobe_gain)) >> BBR_SCALE)); ++ bbr_start_bw_probe_refill(sk, n); ++ return true; ++ } ++ ++ if (bbr_has_elapsed_in_phase(sk, bbr->probe_wait_us) || ++ bbr_is_reno_coexistence_probe_time(sk)) { ++ bbr_start_bw_probe_refill(sk, 0); ++ return true; ++ } ++ return false; ++} ++ ++/* Is it time to transition from PROBE_DOWN to PROBE_CRUISE? */ ++static bool bbr_check_time_to_cruise(struct sock *sk, u32 inflight, u32 bw) ++{ ++ /* Always need to pull inflight down to leave headroom in queue. */ ++ if (inflight > bbr_inflight_with_headroom(sk)) ++ return false; ++ ++ return inflight <= bbr_inflight(sk, bw, BBR_UNIT); ++} ++ ++/* PROBE_BW state machine: cruise, refill, probe for bw, or drain? */ ++static void bbr_update_cycle_phase(struct sock *sk, ++ const struct rate_sample *rs, ++ struct bbr_context *ctx) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ bool is_bw_probe_done = false; ++ u32 inflight, bw; ++ ++ if (!bbr_full_bw_reached(sk)) ++ return; ++ ++ /* In DRAIN, PROBE_BW, or PROBE_RTT, adjust upper bounds. */ ++ if (bbr_adapt_upper_bounds(sk, rs, ctx)) ++ return; /* already decided state transition */ ++ ++ if (bbr->mode != BBR_PROBE_BW) ++ return; ++ ++ inflight = bbr_packets_in_net_at_edt(sk, rs->prior_in_flight); ++ bw = bbr_max_bw(sk); ++ ++ switch (bbr->cycle_idx) { ++ /* First we spend most of our time cruising with a pacing_gain of 1.0, ++ * which paces at the estimated bw, to try to fully use the pipe ++ * without building queue. If we encounter loss/ECN marks, we adapt ++ * by slowing down. 
++ */ ++ case BBR_BW_PROBE_CRUISE: ++ if (bbr_check_time_to_probe_bw(sk, rs)) ++ return; /* already decided state transition */ ++ break; ++ ++ /* After cruising, when it's time to probe, we first "refill": we send ++ * at the estimated bw to fill the pipe, before probing higher and ++ * knowingly risking overflowing the bottleneck buffer (causing loss). ++ */ ++ case BBR_BW_PROBE_REFILL: ++ if (bbr->round_start) { ++ /* After one full round trip of sending in REFILL, we ++ * start to see bw samples reflecting our REFILL, which ++ * may be putting too much data in flight. ++ */ ++ bbr->bw_probe_samples = 1; ++ bbr_start_bw_probe_up(sk, ctx); ++ } ++ break; ++ ++ /* After we refill the pipe, we probe by using a pacing_gain > 1.0, to ++ * probe for bw. If we have not seen loss/ECN, we try to raise inflight ++ * to at least pacing_gain*BDP; note that this may take more than ++ * min_rtt if min_rtt is small (e.g. on a LAN). ++ * ++ * We terminate PROBE_UP bandwidth probing upon any of the following: ++ * ++ * (1) We've pushed inflight up to hit the inflight_hi target set in the ++ * most recent previous bw probe phase. Thus we want to start ++ * draining the queue immediately because it's very likely the most ++ * recently sent packets will fill the queue and cause drops. ++ * (2) If inflight_hi has not limited bandwidth growth recently, and ++ * yet delivered bandwidth has not increased much recently ++ * (bbr->full_bw_now). ++ * (3) Loss filter says loss rate is "too high". ++ * (4) ECN filter says ECN mark rate is "too high". 
++ * ++ * (1) (2) checked here, (3) (4) checked in bbr_is_inflight_too_high() ++ */ ++ case BBR_BW_PROBE_UP: ++ if (bbr->prev_probe_too_high && ++ inflight >= bbr->inflight_hi) { ++ bbr->stopped_risky_probe = 1; ++ is_bw_probe_done = true; ++ } else { ++ if (tp->is_cwnd_limited && ++ tcp_snd_cwnd(tp) >= bbr->inflight_hi) { ++ /* inflight_hi is limiting bw growth */ ++ bbr_reset_full_bw(sk); ++ bbr->full_bw = ctx->sample_bw; ++ } else if (bbr->full_bw_now) { ++ /* Plateau in estimated bw. Pipe looks full. */ ++ is_bw_probe_done = true; ++ } ++ } ++ if (is_bw_probe_done) { ++ bbr->prev_probe_too_high = 0; /* no loss/ECN (yet) */ ++ bbr_start_bw_probe_down(sk); /* restart w/ down */ ++ } ++ break; ++ ++ /* After probing in PROBE_UP, we have usually accumulated some data in ++ * the bottleneck buffer (if bw probing didn't find more bw). We next ++ * enter PROBE_DOWN to try to drain any excess data from the queue. To ++ * do this, we use a pacing_gain < 1.0. We hold this pacing gain until ++ * our inflight is less then that target cruising point, which is the ++ * minimum of (a) the amount needed to leave headroom, and (b) the ++ * estimated BDP. Once inflight falls to match the target, we estimate ++ * the queue is drained; persisting would underutilize the pipe. ++ */ ++ case BBR_BW_PROBE_DOWN: ++ if (bbr_check_time_to_probe_bw(sk, rs)) ++ return; /* already decided state transition */ ++ if (bbr_check_time_to_cruise(sk, inflight, bw)) ++ bbr_start_bw_probe_cruise(sk); ++ break; ++ ++ default: ++ WARN_ONCE(1, "BBR invalid cycle index %u\n", bbr->cycle_idx); ++ } ++} ++ ++/* Exiting PROBE_RTT, so return to bandwidth probing in STARTUP or PROBE_BW. 
*/ ++static void bbr_exit_probe_rtt(struct sock *sk) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ bbr_reset_lower_bounds(sk); ++ if (bbr_full_bw_reached(sk)) { ++ bbr->mode = BBR_PROBE_BW; ++ /* Raising inflight after PROBE_RTT may cause loss, so reset ++ * the PROBE_BW clock and schedule the next bandwidth probe for ++ * a friendly and randomized future point in time. ++ */ ++ bbr_start_bw_probe_down(sk); ++ /* Since we are exiting PROBE_RTT, we know inflight is ++ * below our estimated BDP, so it is reasonable to cruise. ++ */ ++ bbr_start_bw_probe_cruise(sk); ++ } else { ++ bbr->mode = BBR_STARTUP; ++ } ++} ++ ++/* Exit STARTUP based on loss rate > 1% and loss gaps in round >= N. Wait until ++ * the end of the round in recovery to get a good estimate of how many packets ++ * have been lost, and how many we need to drain with a low pacing rate. ++ */ ++static void bbr_check_loss_too_high_in_startup(struct sock *sk, ++ const struct rate_sample *rs) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ if (bbr_full_bw_reached(sk)) ++ return; ++ ++ /* For STARTUP exit, check the loss rate at the end of each round trip ++ * of Recovery episodes in STARTUP. We check the loss rate at the end ++ * of the round trip to filter out noisy/low loss and have a better ++ * sense of inflight (extent of loss), so we can drain more accurately. 
++ */ ++ if (rs->losses && bbr->loss_events_in_round < 0xf) ++ bbr->loss_events_in_round++; /* update saturating counter */ ++ if (bbr_param(sk, full_loss_cnt) && bbr->loss_round_start && ++ inet_csk(sk)->icsk_ca_state == TCP_CA_Recovery && ++ bbr->loss_events_in_round >= bbr_param(sk, full_loss_cnt) && ++ bbr_is_inflight_too_high(sk, rs)) { ++ bbr_handle_queue_too_high_in_startup(sk); ++ return; ++ } ++ if (bbr->loss_round_start) ++ bbr->loss_events_in_round = 0; ++} ++ ++/* Estimate when the pipe is full, using the change in delivery rate: BBR ++ * estimates bw probing filled the pipe if the estimated bw hasn't changed by ++ * at least bbr_full_bw_thresh (25%) after bbr_full_bw_cnt (3) non-app-limited ++ * rounds. Why 3 rounds: 1: rwin autotuning grows the rwin, 2: we fill the ++ * higher rwin, 3: we get higher delivery rate samples. Or transient ++ * cross-traffic or radio noise can go away. CUBIC Hystart shares a similar ++ * design goal, but uses delay and inter-ACK spacing instead of bandwidth. ++ */ ++static void bbr_check_full_bw_reached(struct sock *sk, ++ const struct rate_sample *rs, ++ struct bbr_context *ctx) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ u32 bw_thresh, full_cnt, thresh; ++ ++ if (bbr->full_bw_now || rs->is_app_limited) ++ return; ++ ++ thresh = bbr_param(sk, full_bw_thresh); ++ full_cnt = bbr_param(sk, full_bw_cnt); ++ bw_thresh = (u64)bbr->full_bw * thresh >> BBR_SCALE; ++ if (ctx->sample_bw >= bw_thresh) { ++ bbr_reset_full_bw(sk); ++ bbr->full_bw = ctx->sample_bw; ++ return; ++ } ++ if (!bbr->round_start) ++ return; ++ ++bbr->full_bw_cnt; ++ bbr->full_bw_now = bbr->full_bw_cnt >= full_cnt; ++ bbr->full_bw_reached |= bbr->full_bw_now; ++} ++ ++/* If pipe is probably full, drain the queue and then enter steady-state. 
*/ ++static void bbr_check_drain(struct sock *sk, const struct rate_sample *rs, ++ struct bbr_context *ctx) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ if (bbr->mode == BBR_STARTUP && bbr_full_bw_reached(sk)) { ++ bbr->mode = BBR_DRAIN; /* drain queue we created */ ++ /* Set ssthresh to export purely for monitoring, to signal ++ * completion of initial STARTUP by setting to a non- ++ * TCP_INFINITE_SSTHRESH value (ssthresh is not used by BBR). ++ */ ++ tcp_sk(sk)->snd_ssthresh = ++ bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT); ++ bbr_reset_congestion_signals(sk); ++ } /* fall through to check if in-flight is already small: */ ++ if (bbr->mode == BBR_DRAIN && ++ bbr_packets_in_net_at_edt(sk, tcp_packets_in_flight(tcp_sk(sk))) <= ++ bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT)) { ++ bbr->mode = BBR_PROBE_BW; ++ bbr_start_bw_probe_down(sk); ++ } ++} ++ ++static void bbr_update_model(struct sock *sk, const struct rate_sample *rs, ++ struct bbr_context *ctx) ++{ ++ bbr_update_congestion_signals(sk, rs, ctx); ++ bbr_update_ack_aggregation(sk, rs); ++ bbr_check_loss_too_high_in_startup(sk, rs); ++ bbr_check_full_bw_reached(sk, rs, ctx); ++ bbr_check_drain(sk, rs, ctx); ++ bbr_update_cycle_phase(sk, rs, ctx); ++ bbr_update_min_rtt(sk, rs); ++} ++ ++/* Fast path for app-limited case. ++ * ++ * On each ack, we execute bbr state machine, which primarily consists of: ++ * 1) update model based on new rate sample, and ++ * 2) update control based on updated model or state change. ++ * ++ * There are certain workload/scenarios, e.g. app-limited case, where ++ * either we can skip updating model or we can skip update of both model ++ * as well as control. This provides signifcant softirq cpu savings for ++ * processing incoming acks. 
++ * ++ * In case of app-limited, if there is no congestion (loss/ecn) and ++ * if observed bw sample is less than current estimated bw, then we can ++ * skip some of the computation in bbr state processing: ++ * ++ * - if there is no rtt/mode/phase change: In this case, since all the ++ * parameters of the network model are constant, we can skip model ++ * as well control update. ++ * ++ * - else we can skip rest of the model update. But we still need to ++ * update the control to account for the new rtt/mode/phase. ++ * ++ * Returns whether we can take fast path or not. ++ */ ++static bool bbr_run_fast_path(struct sock *sk, bool *update_model, ++ const struct rate_sample *rs, struct bbr_context *ctx) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ u32 prev_min_rtt_us, prev_mode; ++ ++ if (bbr_param(sk, fast_path) && bbr->try_fast_path && ++ rs->is_app_limited && ctx->sample_bw < bbr_max_bw(sk) && ++ !bbr->loss_in_round && !bbr->ecn_in_round ) { ++ prev_mode = bbr->mode; ++ prev_min_rtt_us = bbr->min_rtt_us; ++ bbr_check_drain(sk, rs, ctx); ++ bbr_update_cycle_phase(sk, rs, ctx); ++ bbr_update_min_rtt(sk, rs); ++ ++ if (bbr->mode == prev_mode && ++ bbr->min_rtt_us == prev_min_rtt_us && ++ bbr->try_fast_path) { ++ return true; ++ } ++ ++ /* Skip model update, but control still needs to be updated */ ++ *update_model = false; ++ } ++ return false; ++} ++ ++__bpf_kfunc static void bbr_main(struct sock *sk, u32 ack, int flag, const struct rate_sample *rs) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ struct bbr_context ctx = { 0 }; ++ bool update_model = true; ++ u32 bw, round_delivered; ++ int ce_ratio = -1; ++ ++ round_delivered = bbr_update_round_start(sk, rs, &ctx); ++ if (bbr->round_start) { ++ bbr->rounds_since_probe = ++ min_t(s32, bbr->rounds_since_probe + 1, 0xFF); ++ ce_ratio = bbr_update_ecn_alpha(sk); ++ } ++ bbr_plb(sk, rs, ce_ratio); ++ ++ bbr->ecn_in_round |= (bbr->ecn_eligible && rs->is_ece); ++ 
bbr_calculate_bw_sample(sk, rs, &ctx); ++ bbr_update_latest_delivery_signals(sk, rs, &ctx); ++ ++ if (bbr_run_fast_path(sk, &update_model, rs, &ctx)) ++ goto out; ++ ++ if (update_model) ++ bbr_update_model(sk, rs, &ctx); ++ ++ bbr_update_gains(sk); ++ bw = bbr_bw(sk); ++ bbr_set_pacing_rate(sk, bw, bbr->pacing_gain); ++ bbr_set_cwnd(sk, rs, rs->acked_sacked, bw, bbr->cwnd_gain, ++ tcp_snd_cwnd(tp), &ctx); ++ bbr_bound_cwnd_for_inflight_model(sk); ++ ++out: ++ bbr_advance_latest_delivery_signals(sk, rs, &ctx); ++ bbr->prev_ca_state = inet_csk(sk)->icsk_ca_state; ++ bbr->loss_in_cycle |= rs->lost > 0; ++ bbr->ecn_in_cycle |= rs->delivered_ce > 0; ++} ++ ++__bpf_kfunc static void bbr_init(struct sock *sk) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ bbr->initialized = 1; ++ ++ bbr->init_cwnd = min(0x7FU, tcp_snd_cwnd(tp)); ++ bbr->prior_cwnd = tp->prior_cwnd; ++ tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; ++ bbr->next_rtt_delivered = tp->delivered; ++ bbr->prev_ca_state = TCP_CA_Open; ++ ++ bbr->probe_rtt_done_stamp = 0; ++ bbr->probe_rtt_round_done = 0; ++ bbr->probe_rtt_min_us = tcp_min_rtt(tp); ++ bbr->probe_rtt_min_stamp = tcp_jiffies32; ++ bbr->min_rtt_us = tcp_min_rtt(tp); ++ bbr->min_rtt_stamp = tcp_jiffies32; ++ ++ bbr->has_seen_rtt = 0; ++ bbr_init_pacing_rate_from_rtt(sk); ++ ++ bbr->round_start = 0; ++ bbr->idle_restart = 0; ++ bbr->full_bw_reached = 0; ++ bbr->full_bw = 0; + bbr->full_bw_cnt = 0; +- bbr_reset_lt_bw_sampling(sk); +- return tcp_snd_cwnd(tcp_sk(sk)); ++ bbr->cycle_mstamp = 0; ++ bbr->cycle_idx = 0; ++ ++ bbr_reset_startup_mode(sk); ++ ++ bbr->ack_epoch_mstamp = tp->tcp_mstamp; ++ bbr->ack_epoch_acked = 0; ++ bbr->extra_acked_win_rtts = 0; ++ bbr->extra_acked_win_idx = 0; ++ bbr->extra_acked[0] = 0; ++ bbr->extra_acked[1] = 0; ++ ++ bbr->ce_state = 0; ++ bbr->prior_rcv_nxt = tp->rcv_nxt; ++ bbr->try_fast_path = 0; ++ ++ cmpxchg(&sk->sk_pacing_status, SK_PACING_NONE, SK_PACING_NEEDED); ++ ++ /* Start 
sampling ECN mark rate after first full flight is ACKed: */ ++ bbr->loss_round_delivered = tp->delivered + 1; ++ bbr->loss_round_start = 0; ++ bbr->undo_bw_lo = 0; ++ bbr->undo_inflight_lo = 0; ++ bbr->undo_inflight_hi = 0; ++ bbr->loss_events_in_round = 0; ++ bbr->startup_ecn_rounds = 0; ++ bbr_reset_congestion_signals(sk); ++ bbr->bw_lo = ~0U; ++ bbr->bw_hi[0] = 0; ++ bbr->bw_hi[1] = 0; ++ bbr->inflight_lo = ~0U; ++ bbr->inflight_hi = ~0U; ++ bbr_reset_full_bw(sk); ++ bbr->bw_probe_up_cnt = ~0U; ++ bbr->bw_probe_up_acks = 0; ++ bbr->bw_probe_up_rounds = 0; ++ bbr->probe_wait_us = 0; ++ bbr->stopped_risky_probe = 0; ++ bbr->ack_phase = BBR_ACKS_INIT; ++ bbr->rounds_since_probe = 0; ++ bbr->bw_probe_samples = 0; ++ bbr->prev_probe_too_high = 0; ++ bbr->ecn_eligible = 0; ++ bbr->ecn_alpha = bbr_param(sk, ecn_alpha_init); ++ bbr->alpha_last_delivered = 0; ++ bbr->alpha_last_delivered_ce = 0; ++ bbr->plb.pause_until = 0; ++ ++ tp->fast_ack_mode = bbr_fast_ack_mode ? 1 : 0; ++ ++ if (bbr_can_use_ecn(sk)) ++ tp->ecn_flags |= TCP_ECN_ECT_PERMANENT; ++} ++ ++/* BBR marks the current round trip as a loss round. */ ++static void bbr_note_loss(struct sock *sk) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ /* Capture "current" data over the full round trip of loss, to ++ * have a better chance of observing the full capacity of the path. ++ */ ++ if (!bbr->loss_in_round) /* first loss in this round trip? */ ++ bbr->loss_round_delivered = tp->delivered; /* set round trip */ ++ bbr->loss_in_round = 1; ++ bbr->loss_in_cycle = 1; + } + +-/* Entering loss recovery, so save cwnd for when we exit or undo recovery. */ ++/* Core TCP stack informs us that the given skb was just marked lost. 
*/ ++__bpf_kfunc static void bbr_skb_marked_lost(struct sock *sk, ++ const struct sk_buff *skb) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ struct tcp_skb_cb *scb = TCP_SKB_CB(skb); ++ struct rate_sample rs = {}; ++ ++ bbr_note_loss(sk); ++ ++ if (!bbr->bw_probe_samples) ++ return; /* not an skb sent while probing for bandwidth */ ++ if (unlikely(!scb->tx.delivered_mstamp)) ++ return; /* skb was SACKed, reneged, marked lost; ignore it */ ++ /* We are probing for bandwidth. Construct a rate sample that ++ * estimates what happened in the flight leading up to this lost skb, ++ * then see if the loss rate went too high, and if so at which packet. ++ */ ++ rs.tx_in_flight = scb->tx.in_flight; ++ rs.lost = tp->lost - scb->tx.lost; ++ rs.is_app_limited = scb->tx.is_app_limited; ++ if (bbr_is_inflight_too_high(sk, &rs)) { ++ rs.tx_in_flight = bbr_inflight_hi_from_lost_skb(sk, &rs, skb); ++ bbr_handle_inflight_too_high(sk, &rs); ++ } ++} ++ ++static void bbr_run_loss_probe_recovery(struct sock *sk) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ struct bbr *bbr = inet_csk_ca(sk); ++ struct rate_sample rs = {0}; ++ ++ bbr_note_loss(sk); ++ ++ if (!bbr->bw_probe_samples) ++ return; /* not sent while probing for bandwidth */ ++ /* We are probing for bandwidth. Construct a rate sample that ++ * estimates what happened in the flight leading up to this ++ * loss, then see if the loss rate went too high. ++ */ ++ rs.lost = 1; /* TLP probe repaired loss of a single segment */ ++ rs.tx_in_flight = bbr->inflight_latest + rs.lost; ++ rs.is_app_limited = tp->tlp_orig_data_app_limited; ++ if (bbr_is_inflight_too_high(sk, &rs)) ++ bbr_handle_inflight_too_high(sk, &rs); ++} ++ ++/* Revert short-term model if current loss recovery event was spurious. 
*/ ++__bpf_kfunc static u32 bbr_undo_cwnd(struct sock *sk) ++{ ++ struct bbr *bbr = inet_csk_ca(sk); ++ ++ bbr_reset_full_bw(sk); /* spurious slow-down; reset full bw detector */ ++ bbr->loss_in_round = 0; ++ ++ /* Revert to cwnd and other state saved before loss episode. */ ++ bbr->bw_lo = max(bbr->bw_lo, bbr->undo_bw_lo); ++ bbr->inflight_lo = max(bbr->inflight_lo, bbr->undo_inflight_lo); ++ bbr->inflight_hi = max(bbr->inflight_hi, bbr->undo_inflight_hi); ++ bbr->try_fast_path = 0; /* take slow path to set proper cwnd, pacing */ ++ return bbr->prior_cwnd; ++} ++ ++/* Entering loss recovery, so save state for when we undo recovery. */ + __bpf_kfunc static u32 bbr_ssthresh(struct sock *sk) + { ++ struct bbr *bbr = inet_csk_ca(sk); ++ + bbr_save_cwnd(sk); ++ /* For undo, save state that adapts based on loss signal. */ ++ bbr->undo_bw_lo = bbr->bw_lo; ++ bbr->undo_inflight_lo = bbr->inflight_lo; ++ bbr->undo_inflight_hi = bbr->inflight_hi; + return tcp_sk(sk)->snd_ssthresh; + } + ++static enum tcp_bbr_phase bbr_get_phase(struct bbr *bbr) ++{ ++ switch (bbr->mode) { ++ case BBR_STARTUP: ++ return BBR_PHASE_STARTUP; ++ case BBR_DRAIN: ++ return BBR_PHASE_DRAIN; ++ case BBR_PROBE_BW: ++ break; ++ case BBR_PROBE_RTT: ++ return BBR_PHASE_PROBE_RTT; ++ default: ++ return BBR_PHASE_INVALID; ++ } ++ switch (bbr->cycle_idx) { ++ case BBR_BW_PROBE_UP: ++ return BBR_PHASE_PROBE_BW_UP; ++ case BBR_BW_PROBE_DOWN: ++ return BBR_PHASE_PROBE_BW_DOWN; ++ case BBR_BW_PROBE_CRUISE: ++ return BBR_PHASE_PROBE_BW_CRUISE; ++ case BBR_BW_PROBE_REFILL: ++ return BBR_PHASE_PROBE_BW_REFILL; ++ default: ++ return BBR_PHASE_INVALID; ++ } ++} ++ + static size_t bbr_get_info(struct sock *sk, u32 ext, int *attr, +- union tcp_cc_info *info) ++ union tcp_cc_info *info) + { + if (ext & (1 << (INET_DIAG_BBRINFO - 1)) || + ext & (1 << (INET_DIAG_VEGASINFO - 1))) { +- struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); +- u64 bw = bbr_bw(sk); +- +- bw = bw * tp->mss_cache * 
USEC_PER_SEC >> BW_SCALE; +- memset(&info->bbr, 0, sizeof(info->bbr)); +- info->bbr.bbr_bw_lo = (u32)bw; +- info->bbr.bbr_bw_hi = (u32)(bw >> 32); +- info->bbr.bbr_min_rtt = bbr->min_rtt_us; +- info->bbr.bbr_pacing_gain = bbr->pacing_gain; +- info->bbr.bbr_cwnd_gain = bbr->cwnd_gain; ++ u64 bw = bbr_bw_bytes_per_sec(sk, bbr_bw(sk)); ++ u64 bw_hi = bbr_bw_bytes_per_sec(sk, bbr_max_bw(sk)); ++ u64 bw_lo = bbr->bw_lo == ~0U ? ++ ~0ULL : bbr_bw_bytes_per_sec(sk, bbr->bw_lo); ++ struct tcp_bbr_info *bbr_info = &info->bbr; ++ ++ memset(bbr_info, 0, sizeof(*bbr_info)); ++ bbr_info->bbr_bw_lo = (u32)bw; ++ bbr_info->bbr_bw_hi = (u32)(bw >> 32); ++ bbr_info->bbr_min_rtt = bbr->min_rtt_us; ++ bbr_info->bbr_pacing_gain = bbr->pacing_gain; ++ bbr_info->bbr_cwnd_gain = bbr->cwnd_gain; ++ bbr_info->bbr_bw_hi_lsb = (u32)bw_hi; ++ bbr_info->bbr_bw_hi_msb = (u32)(bw_hi >> 32); ++ bbr_info->bbr_bw_lo_lsb = (u32)bw_lo; ++ bbr_info->bbr_bw_lo_msb = (u32)(bw_lo >> 32); ++ bbr_info->bbr_mode = bbr->mode; ++ bbr_info->bbr_phase = (__u8)bbr_get_phase(bbr); ++ bbr_info->bbr_version = (__u8)BBR_VERSION; ++ bbr_info->bbr_inflight_lo = bbr->inflight_lo; ++ bbr_info->bbr_inflight_hi = bbr->inflight_hi; ++ bbr_info->bbr_extra_acked = bbr_extra_acked(sk); + *attr = INET_DIAG_BBRINFO; +- return sizeof(info->bbr); ++ return sizeof(*bbr_info); + } + return 0; + } + + __bpf_kfunc static void bbr_set_state(struct sock *sk, u8 new_state) + { ++ struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + if (new_state == TCP_CA_Loss) { +- struct rate_sample rs = { .losses = 1 }; + + bbr->prev_ca_state = TCP_CA_Loss; +- bbr->full_bw = 0; +- bbr->round_start = 1; /* treat RTO like end of a round */ +- bbr_lt_bw_sampling(sk, &rs); ++ tcp_plb_update_state_upon_rto(sk, &bbr->plb); ++ /* The tcp_write_timeout() call to sk_rethink_txhash() likely ++ * repathed this flow, so re-learn the min network RTT on the ++ * new path: ++ */ ++ bbr_reset_full_bw(sk); ++ if (!bbr_is_probing_bandwidth(sk) && 
bbr->inflight_lo == ~0U) { ++ /* bbr_adapt_lower_bounds() needs cwnd before ++ * we suffered an RTO, to update inflight_lo: ++ */ ++ bbr->inflight_lo = ++ max(tcp_snd_cwnd(tp), bbr->prior_cwnd); ++ } ++ } else if (bbr->prev_ca_state == TCP_CA_Loss && ++ new_state != TCP_CA_Loss) { ++ bbr_exit_loss_recovery(sk); + } + } + ++ + static struct tcp_congestion_ops tcp_bbr_cong_ops __read_mostly = { +- .flags = TCP_CONG_NON_RESTRICTED, ++ .flags = TCP_CONG_NON_RESTRICTED | TCP_CONG_WANTS_CE_EVENTS, + .name = "bbr", + .owner = THIS_MODULE, + .init = bbr_init, + .cong_control = bbr_main, + .sndbuf_expand = bbr_sndbuf_expand, ++ .skb_marked_lost = bbr_skb_marked_lost, + .undo_cwnd = bbr_undo_cwnd, + .cwnd_event = bbr_cwnd_event, + .ssthresh = bbr_ssthresh, +- .min_tso_segs = bbr_min_tso_segs, ++ .tso_segs = bbr_tso_segs, + .get_info = bbr_get_info, + .set_state = bbr_set_state, + }; +@@ -1159,10 +2359,11 @@ BTF_KFUNCS_START(tcp_bbr_check_kfunc_ids) + BTF_ID_FLAGS(func, bbr_init) + BTF_ID_FLAGS(func, bbr_main) + BTF_ID_FLAGS(func, bbr_sndbuf_expand) ++BTF_ID_FLAGS(func, bbr_skb_marked_lost) + BTF_ID_FLAGS(func, bbr_undo_cwnd) + BTF_ID_FLAGS(func, bbr_cwnd_event) + BTF_ID_FLAGS(func, bbr_ssthresh) +-BTF_ID_FLAGS(func, bbr_min_tso_segs) ++BTF_ID_FLAGS(func, bbr_tso_segs) + BTF_ID_FLAGS(func, bbr_set_state) + BTF_KFUNCS_END(tcp_bbr_check_kfunc_ids) + +@@ -1195,5 +2396,12 @@ MODULE_AUTHOR("Van Jacobson "); + MODULE_AUTHOR("Neal Cardwell "); + MODULE_AUTHOR("Yuchung Cheng "); + MODULE_AUTHOR("Soheil Hassas Yeganeh "); ++MODULE_AUTHOR("Priyaranjan Jha "); ++MODULE_AUTHOR("Yousuk Seung "); ++MODULE_AUTHOR("Kevin Yang "); ++MODULE_AUTHOR("Arjun Roy "); ++MODULE_AUTHOR("David Morley "); ++ + MODULE_LICENSE("Dual BSD/GPL"); + MODULE_DESCRIPTION("TCP BBR (Bottleneck Bandwidth and RTT)"); ++MODULE_VERSION(__stringify(BBR_VERSION)); +diff --git a/net/ipv4/tcp_cong.c b/net/ipv4/tcp_cong.c +index 28ffcfbeef14..7b13915ba288 100644 +--- a/net/ipv4/tcp_cong.c ++++ b/net/ipv4/tcp_cong.c +@@ 
-237,6 +237,7 @@ void tcp_init_congestion_control(struct sock *sk) + struct inet_connection_sock *icsk = inet_csk(sk); + + tcp_sk(sk)->prior_ssthresh = 0; ++ tcp_sk(sk)->fast_ack_mode = 0; + if (icsk->icsk_ca_ops->init) + icsk->icsk_ca_ops->init(sk); + if (tcp_ca_needs_ecn(sk)) +diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c +index ecd521108559..83b4928bc014 100644 +--- a/net/ipv4/tcp_input.c ++++ b/net/ipv4/tcp_input.c +@@ -365,7 +365,7 @@ static void __tcp_ecn_check_ce(struct sock *sk, const struct sk_buff *skb) + tcp_enter_quickack_mode(sk, 2); + break; + case INET_ECN_CE: +- if (tcp_ca_needs_ecn(sk)) ++ if (tcp_ca_wants_ce_events(sk)) + tcp_ca_event(sk, CA_EVENT_ECN_IS_CE); + + if (!(tp->ecn_flags & TCP_ECN_DEMAND_CWR)) { +@@ -376,7 +376,7 @@ static void __tcp_ecn_check_ce(struct sock *sk, const struct sk_buff *skb) + tp->ecn_flags |= TCP_ECN_SEEN; + break; + default: +- if (tcp_ca_needs_ecn(sk)) ++ if (tcp_ca_wants_ce_events(sk)) + tcp_ca_event(sk, CA_EVENT_ECN_NO_CE); + tp->ecn_flags |= TCP_ECN_SEEN; + break; +@@ -1124,7 +1124,12 @@ static void tcp_verify_retransmit_hint(struct tcp_sock *tp, struct sk_buff *skb) + */ + static void tcp_notify_skb_loss_event(struct tcp_sock *tp, const struct sk_buff *skb) + { ++ struct sock *sk = (struct sock *)tp; ++ const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops; ++ + tp->lost += tcp_skb_pcount(skb); ++ if (ca_ops->skb_marked_lost) ++ ca_ops->skb_marked_lost(sk, skb); + } + + void tcp_mark_skb_lost(struct sock *sk, struct sk_buff *skb) +@@ -1505,6 +1510,17 @@ static bool tcp_shifted_skb(struct sock *sk, struct sk_buff *prev, + WARN_ON_ONCE(tcp_skb_pcount(skb) < pcount); + tcp_skb_pcount_add(skb, -pcount); + ++ /* Adjust tx.in_flight as pcount is shifted from skb to prev. 
*/ ++ if (WARN_ONCE(TCP_SKB_CB(skb)->tx.in_flight < pcount, ++ "prev in_flight: %u skb in_flight: %u pcount: %u", ++ TCP_SKB_CB(prev)->tx.in_flight, ++ TCP_SKB_CB(skb)->tx.in_flight, ++ pcount)) ++ TCP_SKB_CB(skb)->tx.in_flight = 0; ++ else ++ TCP_SKB_CB(skb)->tx.in_flight -= pcount; ++ TCP_SKB_CB(prev)->tx.in_flight += pcount; ++ + /* When we're adding to gso_segs == 1, gso_size will be zero, + * in theory this shouldn't be necessary but as long as DSACK + * code can come after this skb later on it's better to keep +@@ -3799,7 +3815,8 @@ static void tcp_replace_ts_recent(struct tcp_sock *tp, u32 seq) + /* This routine deals with acks during a TLP episode and ends an episode by + * resetting tlp_high_seq. Ref: TLP algorithm in draft-ietf-tcpm-rack + */ +-static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag) ++static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag, ++ struct rate_sample *rs) + { + struct tcp_sock *tp = tcp_sk(sk); + +@@ -3816,6 +3833,7 @@ static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag) + /* ACK advances: there was a loss, so reduce cwnd. Reset + * tlp_high_seq in tcp_init_cwnd_reduction() + */ ++ tcp_ca_event(sk, CA_EVENT_TLP_RECOVERY); + tcp_init_cwnd_reduction(sk); + tcp_set_ca_state(sk, TCP_CA_CWR); + tcp_end_cwnd_reduction(sk); +@@ -3826,6 +3844,11 @@ static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag) + FLAG_NOT_DUP | FLAG_DATA_SACKED))) { + /* Pure dupack: original and TLP probe arrived; no loss */ + tp->tlp_high_seq = 0; ++ } else { ++ /* This ACK matches a TLP retransmit. We cannot yet tell if ++ * this ACK is for the original or the TLP retransmit. ++ */ ++ rs->is_acking_tlp_retrans_seq = 1; + } + } + +@@ -3934,6 +3957,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) + + prior_fack = tcp_is_sack(tp) ? 
tcp_highest_sack_seq(tp) : tp->snd_una; + rs.prior_in_flight = tcp_packets_in_flight(tp); ++ tcp_rate_check_app_limited(sk); + + /* ts_recent update must be made after we are sure that the packet + * is in window. +@@ -4008,7 +4032,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) + tcp_rack_update_reo_wnd(sk, &rs); + + if (tp->tlp_high_seq) +- tcp_process_tlp_ack(sk, ack, flag); ++ tcp_process_tlp_ack(sk, ack, flag, &rs); + + if (tcp_ack_is_dubious(sk, flag)) { + if (!(flag & (FLAG_SND_UNA_ADVANCED | +@@ -4032,6 +4056,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) + delivered = tcp_newly_delivered(sk, delivered, flag); + lost = tp->lost - lost; /* freshly marked lost */ + rs.is_ack_delayed = !!(flag & FLAG_ACK_MAYBE_DELAYED); ++ rs.is_ece = !!(flag & FLAG_ECE); + tcp_rate_gen(sk, delivered, lost, is_sack_reneg, sack_state.rate); + tcp_cong_control(sk, ack, delivered, flag, sack_state.rate); + tcp_xmit_recovery(sk, rexmit); +@@ -4051,7 +4076,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) + tcp_ack_probe(sk); + + if (tp->tlp_high_seq) +- tcp_process_tlp_ack(sk, ack, flag); ++ tcp_process_tlp_ack(sk, ack, flag, &rs); + return 1; + + old_ack: +@@ -5723,13 +5748,14 @@ static void __tcp_ack_snd_check(struct sock *sk, int ofo_possible) + + /* More than one full frame received... */ + if (((tp->rcv_nxt - tp->rcv_wup) > inet_csk(sk)->icsk_ack.rcv_mss && ++ (tp->fast_ack_mode == 1 || + /* ... and right edge of window advances far enough. + * (tcp_recvmsg() will send ACK otherwise). + * If application uses SO_RCVLOWAT, we want send ack now if + * we have not received enough bytes to satisfy the condition. + */ +- (tp->rcv_nxt - tp->copied_seq < sk->sk_rcvlowat || +- __tcp_select_window(sk) >= tp->rcv_wnd)) || ++ (tp->rcv_nxt - tp->copied_seq < sk->sk_rcvlowat || ++ __tcp_select_window(sk) >= tp->rcv_wnd))) || + /* We ACK each frame or... 
*/ + tcp_in_quickack_mode(sk) || + /* Protocol state mandates a one-time immediate ACK */ +diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c +index 0fbebf6266e9..6eb1d369c584 100644 +--- a/net/ipv4/tcp_minisocks.c ++++ b/net/ipv4/tcp_minisocks.c +@@ -460,6 +460,8 @@ void tcp_ca_openreq_child(struct sock *sk, const struct dst_entry *dst) + u32 ca_key = dst_metric(dst, RTAX_CC_ALGO); + bool ca_got_dst = false; + ++ tcp_set_ecn_low_from_dst(sk, dst); ++ + if (ca_key != TCP_CA_UNSPEC) { + const struct tcp_congestion_ops *ca; + +diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c +index 95618d0e78e4..3f4bdd2b6476 100644 +--- a/net/ipv4/tcp_output.c ++++ b/net/ipv4/tcp_output.c +@@ -336,10 +336,9 @@ static void tcp_ecn_send_syn(struct sock *sk, struct sk_buff *skb) + bool bpf_needs_ecn = tcp_bpf_ca_needs_ecn(sk); + bool use_ecn = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_ecn) == 1 || + tcp_ca_needs_ecn(sk) || bpf_needs_ecn; ++ const struct dst_entry *dst = __sk_dst_get(sk); + + if (!use_ecn) { +- const struct dst_entry *dst = __sk_dst_get(sk); +- + if (dst && dst_feature(dst, RTAX_FEATURE_ECN)) + use_ecn = true; + } +@@ -351,6 +350,9 @@ static void tcp_ecn_send_syn(struct sock *sk, struct sk_buff *skb) + tp->ecn_flags = TCP_ECN_OK; + if (tcp_ca_needs_ecn(sk) || bpf_needs_ecn) + INET_ECN_xmit(sk); ++ ++ if (dst) ++ tcp_set_ecn_low_from_dst(sk, dst); + } + } + +@@ -388,7 +390,8 @@ static void tcp_ecn_send(struct sock *sk, struct sk_buff *skb, + th->cwr = 1; + skb_shinfo(skb)->gso_type |= SKB_GSO_TCP_ECN; + } +- } else if (!tcp_ca_needs_ecn(sk)) { ++ } else if (!(tp->ecn_flags & TCP_ECN_ECT_PERMANENT) && ++ !tcp_ca_needs_ecn(sk)) { + /* ACK or retransmitted segment: clear ECT|CE */ + INET_ECN_dontxmit(sk); + } +@@ -1601,7 +1604,7 @@ int tcp_fragment(struct sock *sk, enum tcp_queue tcp_queue, + { + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *buff; +- int old_factor; ++ int old_factor, inflight_prev; + long limit; + int nlen; + u8 flags; +@@ 
-1676,6 +1679,30 @@ int tcp_fragment(struct sock *sk, enum tcp_queue tcp_queue, + + if (diff) + tcp_adjust_pcount(sk, skb, diff); ++ ++ inflight_prev = TCP_SKB_CB(skb)->tx.in_flight - old_factor; ++ if (inflight_prev < 0) { ++ WARN_ONCE(tcp_skb_tx_in_flight_is_suspicious( ++ old_factor, ++ TCP_SKB_CB(skb)->sacked, ++ TCP_SKB_CB(skb)->tx.in_flight), ++ "inconsistent: tx.in_flight: %u " ++ "old_factor: %d mss: %u sacked: %u " ++ "1st pcount: %d 2nd pcount: %d " ++ "1st len: %u 2nd len: %u ", ++ TCP_SKB_CB(skb)->tx.in_flight, old_factor, ++ mss_now, TCP_SKB_CB(skb)->sacked, ++ tcp_skb_pcount(skb), tcp_skb_pcount(buff), ++ skb->len, buff->len); ++ inflight_prev = 0; ++ } ++ /* Set 1st tx.in_flight as if 1st were sent by itself: */ ++ TCP_SKB_CB(skb)->tx.in_flight = inflight_prev + ++ tcp_skb_pcount(skb); ++ /* Set 2nd tx.in_flight with new 1st and 2nd pcounts: */ ++ TCP_SKB_CB(buff)->tx.in_flight = inflight_prev + ++ tcp_skb_pcount(skb) + ++ tcp_skb_pcount(buff); + } + + /* Link BUFF into the send queue. */ +@@ -2033,13 +2060,12 @@ static u32 tcp_tso_autosize(const struct sock *sk, unsigned int mss_now, + static u32 tcp_tso_segs(struct sock *sk, unsigned int mss_now) + { + const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops; +- u32 min_tso, tso_segs; +- +- min_tso = ca_ops->min_tso_segs ? +- ca_ops->min_tso_segs(sk) : +- READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs); ++ u32 tso_segs; + +- tso_segs = tcp_tso_autosize(sk, mss_now, min_tso); ++ tso_segs = ca_ops->tso_segs ? 
++ ca_ops->tso_segs(sk, mss_now) : ++ tcp_tso_autosize(sk, mss_now, ++ sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs); + return min_t(u32, tso_segs, sk->sk_gso_max_segs); + } + +@@ -2767,6 +2793,7 @@ static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, + skb_set_delivery_time(skb, tp->tcp_wstamp_ns, true); + list_move_tail(&skb->tcp_tsorted_anchor, &tp->tsorted_sent_queue); + tcp_init_tso_segs(skb, mss_now); ++ tcp_set_tx_in_flight(sk, skb); + goto repair; /* Skip network transmission */ + } + +@@ -2981,6 +3008,7 @@ void tcp_send_loss_probe(struct sock *sk) + if (WARN_ON(!skb || !tcp_skb_pcount(skb))) + goto rearm_timer; + ++ tp->tlp_orig_data_app_limited = TCP_SKB_CB(skb)->tx.is_app_limited; + if (__tcp_retransmit_skb(sk, skb, 1)) + goto rearm_timer; + +diff --git a/net/ipv4/tcp_rate.c b/net/ipv4/tcp_rate.c +index a8f6d9d06f2e..8737f2134648 100644 +--- a/net/ipv4/tcp_rate.c ++++ b/net/ipv4/tcp_rate.c +@@ -34,6 +34,24 @@ + * ready to send in the write queue. + */ + ++void tcp_set_tx_in_flight(struct sock *sk, struct sk_buff *skb) ++{ ++ struct tcp_sock *tp = tcp_sk(sk); ++ u32 in_flight; ++ ++ /* Check, sanitize, and record packets in flight after skb was sent. */ ++ in_flight = tcp_packets_in_flight(tp) + tcp_skb_pcount(skb); ++ if (WARN_ONCE(in_flight > TCPCB_IN_FLIGHT_MAX, ++ "insane in_flight %u cc %s mss %u " ++ "cwnd %u pif %u %u %u %u\n", ++ in_flight, inet_csk(sk)->icsk_ca_ops->name, ++ tp->mss_cache, tp->snd_cwnd, ++ tp->packets_out, tp->retrans_out, ++ tp->sacked_out, tp->lost_out)) ++ in_flight = TCPCB_IN_FLIGHT_MAX; ++ TCP_SKB_CB(skb)->tx.in_flight = in_flight; ++} ++ + /* Snapshot the current delivery information in the skb, to generate + * a rate sample later when the skb is (s)acked in tcp_rate_skb_delivered(). 
+ */ +@@ -66,7 +84,9 @@ void tcp_rate_skb_sent(struct sock *sk, struct sk_buff *skb) + TCP_SKB_CB(skb)->tx.delivered_mstamp = tp->delivered_mstamp; + TCP_SKB_CB(skb)->tx.delivered = tp->delivered; + TCP_SKB_CB(skb)->tx.delivered_ce = tp->delivered_ce; ++ TCP_SKB_CB(skb)->tx.lost = tp->lost; + TCP_SKB_CB(skb)->tx.is_app_limited = tp->app_limited ? 1 : 0; ++ tcp_set_tx_in_flight(sk, skb); + } + + /* When an skb is sacked or acked, we fill in the rate sample with the (prior) +@@ -91,18 +111,21 @@ void tcp_rate_skb_delivered(struct sock *sk, struct sk_buff *skb, + if (!rs->prior_delivered || + tcp_skb_sent_after(tx_tstamp, tp->first_tx_mstamp, + scb->end_seq, rs->last_end_seq)) { ++ rs->prior_lost = scb->tx.lost; + rs->prior_delivered_ce = scb->tx.delivered_ce; + rs->prior_delivered = scb->tx.delivered; + rs->prior_mstamp = scb->tx.delivered_mstamp; + rs->is_app_limited = scb->tx.is_app_limited; + rs->is_retrans = scb->sacked & TCPCB_RETRANS; ++ rs->tx_in_flight = scb->tx.in_flight; + rs->last_end_seq = scb->end_seq; + + /* Record send time of most recently ACKed packet: */ + tp->first_tx_mstamp = tx_tstamp; + /* Find the duration of the "send phase" of this window: */ +- rs->interval_us = tcp_stamp_us_delta(tp->first_tx_mstamp, +- scb->tx.first_tx_mstamp); ++ rs->interval_us = tcp_stamp32_us_delta( ++ tp->first_tx_mstamp, ++ scb->tx.first_tx_mstamp); + + } + /* Mark off the skb delivered once it's sacked to avoid being +@@ -144,6 +167,7 @@ void tcp_rate_gen(struct sock *sk, u32 delivered, u32 lost, + return; + } + rs->delivered = tp->delivered - rs->prior_delivered; ++ rs->lost = tp->lost - rs->prior_lost; + + rs->delivered_ce = tp->delivered_ce - rs->prior_delivered_ce; + /* delivered_ce occupies less than 32 bits in the skb control block */ +@@ -155,7 +179,7 @@ void tcp_rate_gen(struct sock *sk, u32 delivered, u32 lost, + * longer phase. 
+ */ + snd_us = rs->interval_us; /* send phase */ +- ack_us = tcp_stamp_us_delta(tp->tcp_mstamp, ++ ack_us = tcp_stamp32_us_delta(tp->tcp_mstamp, + rs->prior_mstamp); /* ack phase */ + rs->interval_us = max(snd_us, ack_us); + +diff --git a/net/ipv4/tcp_timer.c b/net/ipv4/tcp_timer.c +index 4d40615dc8fc..f27941201ef2 100644 +--- a/net/ipv4/tcp_timer.c ++++ b/net/ipv4/tcp_timer.c +@@ -689,6 +689,7 @@ void tcp_write_timer_handler(struct sock *sk) + return; + } + ++ tcp_rate_check_app_limited(sk); + tcp_mstamp_refresh(tcp_sk(sk)); + event = icsk->icsk_pending; + +-- +2.46.0 + +From 6f3d18e8b87fe3d732050b597e42200fc22e0d01 Mon Sep 17 00:00:00 2001 +From: Peter Jung +Date: Wed, 14 Aug 2024 16:18:51 +0200 +Subject: [PATCH 03/12] block + +Signed-off-by: Peter Jung +--- + block/bfq-iosched.c | 120 ++++++++++++++++++++++++++++++++++++-------- + block/bfq-iosched.h | 16 +++++- + block/mq-deadline.c | 110 +++++++++++++++++++++++++++++++++------- + 3 files changed, 203 insertions(+), 43 deletions(-) + +diff --git a/block/bfq-iosched.c b/block/bfq-iosched.c +index 4b88a54a9b76..88df08a246fa 100644 +--- a/block/bfq-iosched.c ++++ b/block/bfq-iosched.c +@@ -467,6 +467,21 @@ static struct bfq_io_cq *bfq_bic_lookup(struct request_queue *q) + return icq; + } + ++static struct bfq_io_cq *bfq_bic_try_lookup(struct request_queue *q) ++{ ++ if (!current->io_context) ++ return NULL; ++ if (spin_trylock_irq(&q->queue_lock)) { ++ struct bfq_io_cq *icq; ++ ++ icq = icq_to_bic(ioc_lookup_icq(q)); ++ spin_unlock_irq(&q->queue_lock); ++ return icq; ++ } ++ ++ return NULL; ++} ++ + /* + * Scheduler run of queue, if there are requests pending and no one in the + * driver that will restart queueing. +@@ -2454,10 +2469,21 @@ static bool bfq_bio_merge(struct request_queue *q, struct bio *bio, + * returned by bfq_bic_lookup does not go away before + * bfqd->lock is taken. 
+ */ +- struct bfq_io_cq *bic = bfq_bic_lookup(q); ++ struct bfq_io_cq *bic = bfq_bic_try_lookup(q); + bool ret; + +- spin_lock_irq(&bfqd->lock); ++ /* ++ * bio merging is called for every bio queued, and it's very easy ++ * to run into contention because of that. If we fail getting ++ * the bfqd lock, just skip this merge attempt. For related IO, the ++ * plug will be the successful merging point. If we get here, we ++ * already failed doing the obvious merge. Chances of actually ++ * getting a merge off this path are a lot slimmer, so skipping an ++ * occasional lookup that will most likely not succeed anyway should ++ * not be a problem. ++ */ ++ if (!spin_trylock_irq(&bfqd->lock)) ++ return false; + + if (bic) { + /* +@@ -5148,6 +5174,10 @@ static bool bfq_has_work(struct blk_mq_hw_ctx *hctx) + { + struct bfq_data *bfqd = hctx->queue->elevator->elevator_data; + ++ if (!list_empty_careful(&bfqd->at_head) || ++ !list_empty_careful(&bfqd->at_tail)) ++ return true; ++ + /* + * Avoiding lock: a race on bfqd->queued should cause at + * most a call to dispatch for nothing +@@ -5297,15 +5327,61 @@ static inline void bfq_update_dispatch_stats(struct request_queue *q, + bool idle_timer_disabled) {} + #endif /* CONFIG_BFQ_CGROUP_DEBUG */ + ++static void bfq_insert_request(struct request_queue *q, struct request *rq, ++ blk_insert_t flags, struct list_head *free); ++ ++static void __bfq_do_insert(struct request_queue *q, blk_insert_t flags, ++ struct list_head *list, struct list_head *free) ++{ ++ while (!list_empty(list)) { ++ struct request *rq; ++ ++ rq = list_first_entry(list, struct request, queuelist); ++ list_del_init(&rq->queuelist); ++ bfq_insert_request(q, rq, flags, free); ++ } ++} ++ ++static void bfq_do_insert(struct request_queue *q, struct list_head *free) ++{ ++ struct bfq_data *bfqd = q->elevator->elevator_data; ++ LIST_HEAD(at_head); ++ LIST_HEAD(at_tail); ++ ++ spin_lock(&bfqd->insert_lock); ++ list_splice_init(&bfqd->at_head, &at_head); ++ 
list_splice_init(&bfqd->at_tail, &at_tail); ++ spin_unlock(&bfqd->insert_lock); ++ ++ __bfq_do_insert(q, BLK_MQ_INSERT_AT_HEAD, &at_head, free); ++ __bfq_do_insert(q, 0, &at_tail, free); ++} ++ + static struct request *bfq_dispatch_request(struct blk_mq_hw_ctx *hctx) + { +- struct bfq_data *bfqd = hctx->queue->elevator->elevator_data; ++ struct request_queue *q = hctx->queue; ++ struct bfq_data *bfqd = q->elevator->elevator_data; + struct request *rq; + struct bfq_queue *in_serv_queue; + bool waiting_rq, idle_timer_disabled = false; ++ LIST_HEAD(free); ++ ++ /* ++ * If someone else is already dispatching, skip this one. This will ++ * defer the next dispatch event to when something completes, and could ++ * potentially lower the queue depth for contended cases. ++ * ++ * See the logic in blk_mq_do_dispatch_sched(), which loops and ++ * retries if nothing is dispatched. ++ */ ++ if (test_bit(BFQ_DISPATCHING, &bfqd->run_state) || ++ test_and_set_bit_lock(BFQ_DISPATCHING, &bfqd->run_state)) ++ return NULL; + + spin_lock_irq(&bfqd->lock); + ++ bfq_do_insert(hctx->queue, &free); ++ + in_serv_queue = bfqd->in_service_queue; + waiting_rq = in_serv_queue && bfq_bfqq_wait_request(in_serv_queue); + +@@ -5315,7 +5391,9 @@ static struct request *bfq_dispatch_request(struct blk_mq_hw_ctx *hctx) + waiting_rq && !bfq_bfqq_wait_request(in_serv_queue); + } + ++ clear_bit_unlock(BFQ_DISPATCHING, &bfqd->run_state); + spin_unlock_irq(&bfqd->lock); ++ blk_mq_free_requests(&free); + bfq_update_dispatch_stats(hctx->queue, rq, + idle_timer_disabled ? 
in_serv_queue : NULL, + idle_timer_disabled); +@@ -6236,27 +6314,21 @@ static inline void bfq_update_insert_stats(struct request_queue *q, + + static struct bfq_queue *bfq_init_rq(struct request *rq); + +-static void bfq_insert_request(struct blk_mq_hw_ctx *hctx, struct request *rq, +- blk_insert_t flags) ++static void bfq_insert_request(struct request_queue *q, struct request *rq, ++ blk_insert_t flags, struct list_head *free) + { +- struct request_queue *q = hctx->queue; + struct bfq_data *bfqd = q->elevator->elevator_data; + struct bfq_queue *bfqq; + bool idle_timer_disabled = false; + blk_opf_t cmd_flags; +- LIST_HEAD(free); + + #ifdef CONFIG_BFQ_GROUP_IOSCHED + if (!cgroup_subsys_on_dfl(io_cgrp_subsys) && rq->bio) + bfqg_stats_update_legacy_io(q, rq); + #endif +- spin_lock_irq(&bfqd->lock); + bfqq = bfq_init_rq(rq); +- if (blk_mq_sched_try_insert_merge(q, rq, &free)) { +- spin_unlock_irq(&bfqd->lock); +- blk_mq_free_requests(&free); ++ if (blk_mq_sched_try_insert_merge(q, rq, free)) + return; +- } + + trace_block_rq_insert(rq); + +@@ -6286,8 +6358,6 @@ static void bfq_insert_request(struct blk_mq_hw_ctx *hctx, struct request *rq, + * merge). 
+ */ + cmd_flags = rq->cmd_flags; +- spin_unlock_irq(&bfqd->lock); +- + bfq_update_insert_stats(q, bfqq, idle_timer_disabled, + cmd_flags); + } +@@ -6296,13 +6366,15 @@ static void bfq_insert_requests(struct blk_mq_hw_ctx *hctx, + struct list_head *list, + blk_insert_t flags) + { +- while (!list_empty(list)) { +- struct request *rq; ++ struct request_queue *q = hctx->queue; ++ struct bfq_data *bfqd = q->elevator->elevator_data; + +- rq = list_first_entry(list, struct request, queuelist); +- list_del_init(&rq->queuelist); +- bfq_insert_request(hctx, rq, flags); +- } ++ spin_lock_irq(&bfqd->insert_lock); ++ if (flags & BLK_MQ_INSERT_AT_HEAD) ++ list_splice_init(list, &bfqd->at_head); ++ else ++ list_splice_init(list, &bfqd->at_tail); ++ spin_unlock_irq(&bfqd->insert_lock); + } + + static void bfq_update_hw_tag(struct bfq_data *bfqd) +@@ -7211,6 +7283,12 @@ static int bfq_init_queue(struct request_queue *q, struct elevator_type *e) + q->elevator = eq; + spin_unlock_irq(&q->queue_lock); + ++ spin_lock_init(&bfqd->lock); ++ spin_lock_init(&bfqd->insert_lock); ++ ++ INIT_LIST_HEAD(&bfqd->at_head); ++ INIT_LIST_HEAD(&bfqd->at_tail); ++ + /* + * Our fallback bfqq if bfq_find_alloc_queue() runs into OOM issues. 
+ * Grab a permanent reference to it, so that the normal code flow +@@ -7329,8 +7407,6 @@ static int bfq_init_queue(struct request_queue *q, struct elevator_type *e) + /* see comments on the definition of next field inside bfq_data */ + bfqd->actuator_load_threshold = 4; + +- spin_lock_init(&bfqd->lock); +- + /* + * The invocation of the next bfq_create_group_hierarchy + * function is the head of a chain of function calls +diff --git a/block/bfq-iosched.h b/block/bfq-iosched.h +index 467e8cfc41a2..f44f5d4ec2f4 100644 +--- a/block/bfq-iosched.h ++++ b/block/bfq-iosched.h +@@ -504,12 +504,26 @@ struct bfq_io_cq { + unsigned int requests; /* Number of requests this process has in flight */ + }; + ++enum { ++ BFQ_DISPATCHING = 0, ++}; ++ + /** + * struct bfq_data - per-device data structure. + * + * All the fields are protected by @lock. + */ + struct bfq_data { ++ struct { ++ spinlock_t lock; ++ spinlock_t insert_lock; ++ } ____cacheline_aligned_in_smp; ++ ++ unsigned long run_state; ++ ++ struct list_head at_head; ++ struct list_head at_tail; ++ + /* device request queue */ + struct request_queue *queue; + /* dispatch queue */ +@@ -795,8 +809,6 @@ struct bfq_data { + /* fallback dummy bfqq for extreme OOM conditions */ + struct bfq_queue oom_bfqq; + +- spinlock_t lock; +- + /* + * bic associated with the task issuing current bio for + * merging. 
This and the next field are used as a support to +diff --git a/block/mq-deadline.c b/block/mq-deadline.c +index acdc28756d9d..8b214233a061 100644 +--- a/block/mq-deadline.c ++++ b/block/mq-deadline.c +@@ -79,10 +79,23 @@ struct dd_per_prio { + struct io_stats_per_prio stats; + }; + ++enum { ++ DD_DISPATCHING = 0, ++}; ++ + struct deadline_data { + /* + * run time data + */ ++ struct { ++ spinlock_t lock; ++ spinlock_t insert_lock; ++ } ____cacheline_aligned_in_smp; ++ ++ unsigned long run_state; ++ ++ struct list_head at_head; ++ struct list_head at_tail; + + struct dd_per_prio per_prio[DD_PRIO_COUNT]; + +@@ -100,8 +113,6 @@ struct deadline_data { + int front_merges; + u32 async_depth; + int prio_aging_expire; +- +- spinlock_t lock; + }; + + /* Maps an I/O priority class to a deadline scheduler priority. */ +@@ -112,6 +123,9 @@ static const enum dd_prio ioprio_class_to_prio[] = { + [IOPRIO_CLASS_IDLE] = DD_IDLE_PRIO, + }; + ++static void dd_insert_request(struct request_queue *q, struct request *rq, ++ blk_insert_t flags, struct list_head *free); ++ + static inline struct rb_root * + deadline_rb_root(struct dd_per_prio *per_prio, struct request *rq) + { +@@ -451,6 +465,33 @@ static struct request *dd_dispatch_prio_aged_requests(struct deadline_data *dd, + return NULL; + } + ++static void __dd_do_insert(struct request_queue *q, blk_insert_t flags, ++ struct list_head *list, struct list_head *free) ++{ ++ while (!list_empty(list)) { ++ struct request *rq; ++ ++ rq = list_first_entry(list, struct request, queuelist); ++ list_del_init(&rq->queuelist); ++ dd_insert_request(q, rq, flags, free); ++ } ++} ++ ++static void dd_do_insert(struct request_queue *q, struct list_head *free) ++{ ++ struct deadline_data *dd = q->elevator->elevator_data; ++ LIST_HEAD(at_head); ++ LIST_HEAD(at_tail); ++ ++ spin_lock(&dd->insert_lock); ++ list_splice_init(&dd->at_head, &at_head); ++ list_splice_init(&dd->at_tail, &at_tail); ++ spin_unlock(&dd->insert_lock); ++ ++ __dd_do_insert(q, 
BLK_MQ_INSERT_AT_HEAD, &at_head, free); ++ __dd_do_insert(q, 0, &at_tail, free); ++} ++ + /* + * Called from blk_mq_run_hw_queue() -> __blk_mq_sched_dispatch_requests(). + * +@@ -461,12 +502,27 @@ static struct request *dd_dispatch_prio_aged_requests(struct deadline_data *dd, + */ + static struct request *dd_dispatch_request(struct blk_mq_hw_ctx *hctx) + { +- struct deadline_data *dd = hctx->queue->elevator->elevator_data; ++ struct request_queue *q = hctx->queue; ++ struct deadline_data *dd = q->elevator->elevator_data; + const unsigned long now = jiffies; + struct request *rq; + enum dd_prio prio; ++ LIST_HEAD(free); ++ ++ /* ++ * If someone else is already dispatching, skip this one. This will ++ * defer the next dispatch event to when something completes, and could ++ * potentially lower the queue depth for contended cases. ++ * ++ * See the logic in blk_mq_do_dispatch_sched(), which loops and ++ * retries if nothing is dispatched. ++ */ ++ if (test_bit(DD_DISPATCHING, &dd->run_state) || ++ test_and_set_bit_lock(DD_DISPATCHING, &dd->run_state)) ++ return NULL; + + spin_lock(&dd->lock); ++ dd_do_insert(q, &free); + rq = dd_dispatch_prio_aged_requests(dd, now); + if (rq) + goto unlock; +@@ -482,8 +538,10 @@ static struct request *dd_dispatch_request(struct blk_mq_hw_ctx *hctx) + } + + unlock: ++ clear_bit_unlock(DD_DISPATCHING, &dd->run_state); + spin_unlock(&dd->lock); + ++ blk_mq_free_requests(&free); + return rq; + } + +@@ -585,6 +643,12 @@ static int dd_init_sched(struct request_queue *q, struct elevator_type *e) + + eq->elevator_data = dd; + ++ spin_lock_init(&dd->lock); ++ spin_lock_init(&dd->insert_lock); ++ ++ INIT_LIST_HEAD(&dd->at_head); ++ INIT_LIST_HEAD(&dd->at_tail); ++ + for (prio = 0; prio <= DD_PRIO_MAX; prio++) { + struct dd_per_prio *per_prio = &dd->per_prio[prio]; + +@@ -601,7 +665,6 @@ static int dd_init_sched(struct request_queue *q, struct elevator_type *e) + dd->last_dir = DD_WRITE; + dd->fifo_batch = fifo_batch; + dd->prio_aging_expire = 
prio_aging_expire; +- spin_lock_init(&dd->lock); + + /* We dispatch from request queue wide instead of hw queue */ + blk_queue_flag_set(QUEUE_FLAG_SQ_SCHED, q); +@@ -657,7 +720,19 @@ static bool dd_bio_merge(struct request_queue *q, struct bio *bio, + struct request *free = NULL; + bool ret; + +- spin_lock(&dd->lock); ++ /* ++ * bio merging is called for every bio queued, and it's very easy ++ * to run into contention because of that. If we fail getting ++ * the dd lock, just skip this merge attempt. For related IO, the ++ * plug will be the successful merging point. If we get here, we ++ * already failed doing the obvious merge. Chances of actually ++ * getting a merge off this path is a lot slimmer, so skipping an ++ * occasional lookup that will most likely not succeed anyway should ++ * not be a problem. ++ */ ++ if (!spin_trylock(&dd->lock)) ++ return false; ++ + ret = blk_mq_sched_try_merge(q, bio, nr_segs, &free); + spin_unlock(&dd->lock); + +@@ -670,10 +745,9 @@ static bool dd_bio_merge(struct request_queue *q, struct bio *bio, + /* + * add rq to rbtree and fifo + */ +-static void dd_insert_request(struct blk_mq_hw_ctx *hctx, struct request *rq, ++static void dd_insert_request(struct request_queue *q, struct request *rq, + blk_insert_t flags, struct list_head *free) + { +- struct request_queue *q = hctx->queue; + struct deadline_data *dd = q->elevator->elevator_data; + const enum dd_data_dir data_dir = rq_data_dir(rq); + u16 ioprio = req_get_ioprio(rq); +@@ -727,19 +801,13 @@ static void dd_insert_requests(struct blk_mq_hw_ctx *hctx, + { + struct request_queue *q = hctx->queue; + struct deadline_data *dd = q->elevator->elevator_data; +- LIST_HEAD(free); +- +- spin_lock(&dd->lock); +- while (!list_empty(list)) { +- struct request *rq; +- +- rq = list_first_entry(list, struct request, queuelist); +- list_del_init(&rq->queuelist); +- dd_insert_request(hctx, rq, flags, &free); +- } +- spin_unlock(&dd->lock); + +- blk_mq_free_requests(&free); ++
spin_lock(&dd->insert_lock); ++ if (flags & BLK_MQ_INSERT_AT_HEAD) ++ list_splice_init(list, &dd->at_head); ++ else ++ list_splice_init(list, &dd->at_tail); ++ spin_unlock(&dd->insert_lock); + } + + /* Callback from inside blk_mq_rq_ctx_init(). */ +@@ -780,6 +848,10 @@ static bool dd_has_work(struct blk_mq_hw_ctx *hctx) + struct deadline_data *dd = hctx->queue->elevator->elevator_data; + enum dd_prio prio; + ++ if (!list_empty_careful(&dd->at_head) || ++ !list_empty_careful(&dd->at_tail)) ++ return true; ++ + for (prio = 0; prio <= DD_PRIO_MAX; prio++) + if (dd_has_work_for_prio(&dd->per_prio[prio])) + return true; +-- +2.46.0 + +From 693c2b3378b5d86ce5ed5236d114da0606e0ee76 Mon Sep 17 00:00:00 2001 +From: Peter Jung +Date: Wed, 14 Aug 2024 16:19:02 +0200 +Subject: [PATCH 04/12] cachy + +Signed-off-by: Peter Jung +--- + .gitignore | 6 + + .../admin-guide/kernel-parameters.txt | 12 + + MAINTAINERS | 7 + + Makefile | 9 +- + arch/x86/Kconfig.cpu | 432 ++- + arch/x86/Makefile | 45 +- + arch/x86/include/asm/pci.h | 6 + + arch/x86/include/asm/vermagic.h | 76 + + arch/x86/pci/common.c | 7 +- + arch/x86/xen/setup.c | 5 +- + block/bfq-iosched.c | 6 + + block/elevator.c | 10 + + drivers/Makefile | 13 +- + drivers/ata/ahci.c | 23 +- + drivers/cpufreq/Kconfig.x86 | 2 - + drivers/cpufreq/cpufreq.c | 27 +- + drivers/cpufreq/intel_pstate.c | 2 + + drivers/gpu/drm/amd/amdgpu/amdgpu.h | 1 + + drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c | 10 + + drivers/gpu/drm/amd/display/Kconfig | 6 + + .../gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c | 2 +- + .../amd/display/amdgpu_dm/amdgpu_dm_color.c | 2 +- + .../amd/display/amdgpu_dm/amdgpu_dm_crtc.c | 6 +- + .../amd/display/amdgpu_dm/amdgpu_dm_plane.c | 6 +- + drivers/gpu/drm/amd/pm/amdgpu_pm.c | 3 + + drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c | 14 +- + drivers/i2c/busses/Kconfig | 9 + + drivers/i2c/busses/Makefile | 1 + + drivers/i2c/busses/i2c-nct6775.c | 648 ++++ + drivers/i2c/busses/i2c-piix4.c | 4 +- + drivers/input/evdev.c | 19 +- + 
drivers/md/dm-crypt.c | 5 + + drivers/media/v4l2-core/Kconfig | 5 + + drivers/media/v4l2-core/Makefile | 2 + + drivers/media/v4l2-core/v4l2loopback.c | 3184 +++++++++++++++++ + drivers/media/v4l2-core/v4l2loopback.h | 98 + + .../media/v4l2-core/v4l2loopback_formats.h | 445 +++ + drivers/pci/controller/Makefile | 6 + + drivers/pci/controller/intel-nvme-remap.c | 462 +++ + drivers/pci/quirks.c | 101 + + include/linux/cpufreq.h | 6 - + include/linux/minmax.h | 7 + + include/linux/pageblock-flags.h | 4 +- + include/linux/pagemap.h | 2 +- + include/linux/user_namespace.h | 4 + + init/Kconfig | 26 + + kernel/Kconfig.hz | 24 + + kernel/fork.c | 14 + + kernel/sched/fair.c | 13 + + kernel/sched/sched.h | 2 +- + kernel/sysctl.c | 12 + + kernel/user_namespace.c | 7 + + mm/Kconfig | 2 +- + mm/compaction.c | 4 + + mm/huge_memory.c | 4 + + mm/page-writeback.c | 8 + + mm/page_alloc.c | 4 + + mm/swap.c | 5 + + mm/vmpressure.c | 4 + + mm/vmscan.c | 8 + + scripts/Makefile.package | 14 + + scripts/package/PKGBUILD | 108 + + 62 files changed, 5910 insertions(+), 99 deletions(-) + create mode 100644 drivers/i2c/busses/i2c-nct6775.c + create mode 100644 drivers/media/v4l2-core/v4l2loopback.c + create mode 100644 drivers/media/v4l2-core/v4l2loopback.h + create mode 100644 drivers/media/v4l2-core/v4l2loopback_formats.h + create mode 100644 drivers/pci/controller/intel-nvme-remap.c + create mode 100644 scripts/package/PKGBUILD + +diff --git a/.gitignore b/.gitignore +index c59dc60ba62e..7902adf4f7f1 100644 +--- a/.gitignore ++++ b/.gitignore +@@ -92,6 +92,12 @@ modules.order + # + /tar-install/ + ++# ++# pacman files (make pacman-pkg) ++# ++/PKGBUILD ++/pacman/ ++ + # + # We don't want to ignore the following even if they are dot-files + # +diff --git a/Documentation/admin-guide/kernel-parameters.txt b/Documentation/admin-guide/kernel-parameters.txt +index c82446cef8e2..ab5ca3af35d2 100644 +--- a/Documentation/admin-guide/kernel-parameters.txt ++++ 
b/Documentation/admin-guide/kernel-parameters.txt +@@ -2229,6 +2229,9 @@ + disable + Do not enable intel_pstate as the default + scaling driver for the supported processors ++ enable ++ Enable intel_pstate in case "disable" was passed ++ previously in the kernel boot parameters + active + Use intel_pstate driver to bypass the scaling + governors layer of cpufreq and provides it own +@@ -4447,6 +4450,15 @@ + nomsi [MSI] If the PCI_MSI kernel config parameter is + enabled, this kernel boot option can be used to + disable the use of MSI interrupts system-wide. ++ pcie_acs_override = ++ [PCIE] Override missing PCIe ACS support for: ++ downstream ++ All downstream ports - full ACS capabilities ++ multifunction ++ All multifunction devices - multifunction ACS subset ++ id:nnnn:nnnn ++ Specific device - full ACS capabilities ++ Specified as vid:did (vendor/device ID) in hex + noioapicquirk [APIC] Disable all boot interrupt quirks. + Safety option to keep boot IRQs enabled. This + should never be necessary.
+diff --git a/MAINTAINERS b/MAINTAINERS +index 958e935449e5..b27470be2e6a 100644 +--- a/MAINTAINERS ++++ b/MAINTAINERS +@@ -11978,6 +11978,13 @@ F: include/uapi/linux/nfsd/ + F: include/uapi/linux/sunrpc/ + F: net/sunrpc/ + ++KERNEL PACMAN PACKAGING (in addition to generic KERNEL BUILD) ++M: Thomas Weißschuh ++R: Christian Heusel ++R: Nathan Chancellor ++S: Maintained ++F: scripts/package/PKGBUILD ++ + KERNEL REGRESSIONS + M: Thorsten Leemhuis + L: regressions@lists.linux.dev +diff --git a/Makefile b/Makefile +index f9badb79ae8f..fbe293960f60 100644 +--- a/Makefile ++++ b/Makefile +@@ -817,6 +817,9 @@ KBUILD_CFLAGS += -fno-delete-null-pointer-checks + ifdef CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE + KBUILD_CFLAGS += -O2 + KBUILD_RUSTFLAGS += -Copt-level=2 ++else ifdef CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3 ++KBUILD_CFLAGS += -O3 ++KBUILD_RUSTFLAGS += -Copt-level=3 + else ifdef CONFIG_CC_OPTIMIZE_FOR_SIZE + KBUILD_CFLAGS += -Os + KBUILD_RUSTFLAGS += -Copt-level=s +@@ -1005,9 +1008,9 @@ KBUILD_CFLAGS += -fno-strict-overflow + # Make sure -fstack-check isn't enabled (like gentoo apparently did) + KBUILD_CFLAGS += -fno-stack-check + +-# conserve stack if available ++# conserve stack, ivopts and modulo-sched if available + ifdef CONFIG_CC_IS_GCC +-KBUILD_CFLAGS += -fconserve-stack ++KBUILD_CFLAGS += -fconserve-stack -fivopts -fmodulo-sched -fno-tree-vectorize + endif + + # change __FILE__ to the relative path from the srctree +@@ -1497,7 +1500,7 @@ CLEAN_FILES += vmlinux.symvers modules-only.symvers \ + # Directories & files removed with 'make mrproper' + MRPROPER_FILES += include/config include/generated \ + arch/$(SRCARCH)/include/generated .objdiff \ +- debian snap tar-install \ ++ debian snap tar-install PKGBUILD pacman \ + .config .config.old .version \ + Module.symvers \ + certs/signing_key.pem \ +diff --git a/arch/x86/Kconfig.cpu b/arch/x86/Kconfig.cpu +index 2a7279d80460..3b077b9f9291 100644 +--- a/arch/x86/Kconfig.cpu ++++ b/arch/x86/Kconfig.cpu +@@ -157,7 +157,7 @@ 
config MPENTIUM4 + + + config MK6 +- bool "K6/K6-II/K6-III" ++ bool "AMD K6/K6-II/K6-III" + depends on X86_32 + help + Select this for an AMD K6-family processor. Enables use of +@@ -165,7 +165,7 @@ config MK6 + flags to GCC. + + config MK7 +- bool "Athlon/Duron/K7" ++ bool "AMD Athlon/Duron/K7" + depends on X86_32 + help + Select this for an AMD Athlon K7-family processor. Enables use of +@@ -173,12 +173,114 @@ config MK7 + flags to GCC. + + config MK8 +- bool "Opteron/Athlon64/Hammer/K8" ++ bool "AMD Opteron/Athlon64/Hammer/K8" + help + Select this for an AMD Opteron or Athlon64 Hammer-family processor. + Enables use of some extended instructions, and passes appropriate + optimization flags to GCC. + ++config MK8SSE3 ++ bool "AMD Opteron/Athlon64/Hammer/K8 with SSE3" ++ help ++ Select this for improved AMD Opteron or Athlon64 Hammer-family processors. ++ Enables use of some extended instructions, and passes appropriate ++ optimization flags to GCC. ++ ++config MK10 ++ bool "AMD 61xx/7x50/PhenomX3/X4/II/K10" ++ help ++ Select this for an AMD 61xx Eight-Core Magny-Cours, Athlon X2 7x50, ++ Phenom X3/X4/II, Athlon II X2/X3/X4, or Turion II-family processor. ++ Enables use of some extended instructions, and passes appropriate ++ optimization flags to GCC. ++ ++config MBARCELONA ++ bool "AMD Barcelona" ++ help ++ Select this for AMD Family 10h Barcelona processors. ++ ++ Enables -march=barcelona ++ ++config MBOBCAT ++ bool "AMD Bobcat" ++ help ++ Select this for AMD Family 14h Bobcat processors. ++ ++ Enables -march=btver1 ++ ++config MJAGUAR ++ bool "AMD Jaguar" ++ help ++ Select this for AMD Family 16h Jaguar processors. ++ ++ Enables -march=btver2 ++ ++config MBULLDOZER ++ bool "AMD Bulldozer" ++ help ++ Select this for AMD Family 15h Bulldozer processors. ++ ++ Enables -march=bdver1 ++ ++config MPILEDRIVER ++ bool "AMD Piledriver" ++ help ++ Select this for AMD Family 15h Piledriver processors. 
++ ++ Enables -march=bdver2 ++ ++config MSTEAMROLLER ++ bool "AMD Steamroller" ++ help ++ Select this for AMD Family 15h Steamroller processors. ++ ++ Enables -march=bdver3 ++ ++config MEXCAVATOR ++ bool "AMD Excavator" ++ help ++ Select this for AMD Family 15h Excavator processors. ++ ++ Enables -march=bdver4 ++ ++config MZEN ++ bool "AMD Zen" ++ help ++ Select this for AMD Family 17h Zen processors. ++ ++ Enables -march=znver1 ++ ++config MZEN2 ++ bool "AMD Zen 2" ++ help ++ Select this for AMD Family 17h Zen 2 processors. ++ ++ Enables -march=znver2 ++ ++config MZEN3 ++ bool "AMD Zen 3" ++ depends on (CC_IS_GCC && GCC_VERSION >= 100300) || (CC_IS_CLANG && CLANG_VERSION >= 120000) ++ help ++ Select this for AMD Family 19h Zen 3 processors. ++ ++ Enables -march=znver3 ++ ++config MZEN4 ++ bool "AMD Zen 4" ++ depends on (CC_IS_GCC && GCC_VERSION >= 130000) || (CC_IS_CLANG && CLANG_VERSION >= 160000) ++ help ++ Select this for AMD Family 19h Zen 4 processors. ++ ++ Enables -march=znver4 ++ ++config MZEN5 ++ bool "AMD Zen 5" ++ depends on (CC_IS_GCC && GCC_VERSION >= 140000) || (CC_IS_CLANG && CLANG_VERSION >= 180000) ++ help ++ Select this for AMD Family 1Ah Zen 5 processors. ++ ++ Enables -march=znver5 ++ + config MCRUSOE + bool "Crusoe" + depends on X86_32 +@@ -270,7 +372,7 @@ config MPSC + in /proc/cpuinfo. Family 15 is an older Xeon, Family 6 a newer one. + + config MCORE2 +- bool "Core 2/newer Xeon" ++ bool "Intel Core 2" + help + + Select this for Intel Core 2 and newer Core 2 Xeons (Xeon 51xx and +@@ -278,6 +380,8 @@ config MCORE2 + family in /proc/cpuinfo. Newer ones have 6 and older ones 15 + (not a typo) + ++ Enables -march=core2 ++ + config MATOM + bool "Intel Atom" + help +@@ -287,6 +391,212 @@ config MATOM + accordingly optimized code. Use a recent GCC with specific Atom + support in order to fully benefit from selecting this option. 
+ ++config MNEHALEM ++ bool "Intel Nehalem" ++ select X86_P6_NOP ++ help ++ ++ Select this for 1st Gen Core processors in the Nehalem family. ++ ++ Enables -march=nehalem ++ ++config MWESTMERE ++ bool "Intel Westmere" ++ select X86_P6_NOP ++ help ++ ++ Select this for the Intel Westmere formerly Nehalem-C family. ++ ++ Enables -march=westmere ++ ++config MSILVERMONT ++ bool "Intel Silvermont" ++ select X86_P6_NOP ++ help ++ ++ Select this for the Intel Silvermont platform. ++ ++ Enables -march=silvermont ++ ++config MGOLDMONT ++ bool "Intel Goldmont" ++ select X86_P6_NOP ++ help ++ ++ Select this for the Intel Goldmont platform including Apollo Lake and Denverton. ++ ++ Enables -march=goldmont ++ ++config MGOLDMONTPLUS ++ bool "Intel Goldmont Plus" ++ select X86_P6_NOP ++ help ++ ++ Select this for the Intel Goldmont Plus platform including Gemini Lake. ++ ++ Enables -march=goldmont-plus ++ ++config MSANDYBRIDGE ++ bool "Intel Sandy Bridge" ++ select X86_P6_NOP ++ help ++ ++ Select this for 2nd Gen Core processors in the Sandy Bridge family. ++ ++ Enables -march=sandybridge ++ ++config MIVYBRIDGE ++ bool "Intel Ivy Bridge" ++ select X86_P6_NOP ++ help ++ ++ Select this for 3rd Gen Core processors in the Ivy Bridge family. ++ ++ Enables -march=ivybridge ++ ++config MHASWELL ++ bool "Intel Haswell" ++ select X86_P6_NOP ++ help ++ ++ Select this for 4th Gen Core processors in the Haswell family. ++ ++ Enables -march=haswell ++ ++config MBROADWELL ++ bool "Intel Broadwell" ++ select X86_P6_NOP ++ help ++ ++ Select this for 5th Gen Core processors in the Broadwell family. ++ ++ Enables -march=broadwell ++ ++config MSKYLAKE ++ bool "Intel Skylake" ++ select X86_P6_NOP ++ help ++ ++ Select this for 6th Gen Core processors in the Skylake family. ++ ++ Enables -march=skylake ++ ++config MSKYLAKEX ++ bool "Intel Skylake X" ++ select X86_P6_NOP ++ help ++ ++ Select this for 6th Gen Core processors in the Skylake X family. 
++ ++ Enables -march=skylake-avx512 ++ ++config MCANNONLAKE ++ bool "Intel Cannon Lake" ++ select X86_P6_NOP ++ help ++ ++ Select this for 8th Gen Core processors ++ ++ Enables -march=cannonlake ++ ++config MICELAKE ++ bool "Intel Ice Lake" ++ select X86_P6_NOP ++ help ++ ++ Select this for 10th Gen Core processors in the Ice Lake family. ++ ++ Enables -march=icelake-client ++ ++config MCASCADELAKE ++ bool "Intel Cascade Lake" ++ select X86_P6_NOP ++ help ++ ++ Select this for Xeon processors in the Cascade Lake family. ++ ++ Enables -march=cascadelake ++ ++config MCOOPERLAKE ++ bool "Intel Cooper Lake" ++ depends on (CC_IS_GCC && GCC_VERSION > 100100) || (CC_IS_CLANG && CLANG_VERSION >= 100000) ++ select X86_P6_NOP ++ help ++ ++ Select this for Xeon processors in the Cooper Lake family. ++ ++ Enables -march=cooperlake ++ ++config MTIGERLAKE ++ bool "Intel Tiger Lake" ++ depends on (CC_IS_GCC && GCC_VERSION > 100100) || (CC_IS_CLANG && CLANG_VERSION >= 100000) ++ select X86_P6_NOP ++ help ++ ++ Select this for third-generation 10 nm process processors in the Tiger Lake family. ++ ++ Enables -march=tigerlake ++ ++config MSAPPHIRERAPIDS ++ bool "Intel Sapphire Rapids" ++ depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) ++ select X86_P6_NOP ++ help ++ ++ Select this for fourth-generation 10 nm process processors in the Sapphire Rapids family. ++ ++ Enables -march=sapphirerapids ++ ++config MROCKETLAKE ++ bool "Intel Rocket Lake" ++ depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) ++ select X86_P6_NOP ++ help ++ ++ Select this for eleventh-generation processors in the Rocket Lake family. ++ ++ Enables -march=rocketlake ++ ++config MALDERLAKE ++ bool "Intel Alder Lake" ++ depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) ++ select X86_P6_NOP ++ help ++ ++ Select this for twelfth-generation processors in the Alder Lake family. 
++ ++ Enables -march=alderlake ++ ++config MRAPTORLAKE ++ bool "Intel Raptor Lake" ++ depends on (CC_IS_GCC && GCC_VERSION >= 130000) || (CC_IS_CLANG && CLANG_VERSION >= 150500) ++ select X86_P6_NOP ++ help ++ ++ Select this for thirteenth-generation processors in the Raptor Lake family. ++ ++ Enables -march=raptorlake ++ ++config MMETEORLAKE ++ bool "Intel Meteor Lake" ++ depends on (CC_IS_GCC && GCC_VERSION >= 130000) || (CC_IS_CLANG && CLANG_VERSION >= 150500) ++ select X86_P6_NOP ++ help ++ ++ Select this for fourteenth-generation processors in the Meteor Lake family. ++ ++ Enables -march=meteorlake ++ ++config MEMERALDRAPIDS ++ bool "Intel Emerald Rapids" ++ depends on (CC_IS_GCC && GCC_VERSION > 130000) || (CC_IS_CLANG && CLANG_VERSION >= 150500) ++ select X86_P6_NOP ++ help ++ ++ Select this for fifth-generation 10 nm process processors in the Emerald Rapids family. ++ ++ Enables -march=emeraldrapids ++ + config GENERIC_CPU + bool "Generic-x86-64" + depends on X86_64 +@@ -294,6 +604,50 @@ config GENERIC_CPU + Generic x86-64 CPU. + Run equally well on all x86-64 CPUs. + ++config GENERIC_CPU2 ++ bool "Generic-x86-64-v2" ++ depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) ++ depends on X86_64 ++ help ++ Generic x86-64 CPU. ++ Run equally well on all x86-64 CPUs with min support of x86-64-v2. ++ ++config GENERIC_CPU3 ++ bool "Generic-x86-64-v3" ++ depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) ++ depends on X86_64 ++ help ++ Generic x86-64-v3 CPU with v3 instructions. ++ Run equally well on all x86-64 CPUs with min support of x86-64-v3. ++ ++config GENERIC_CPU4 ++ bool "Generic-x86-64-v4" ++ depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) ++ depends on X86_64 ++ help ++ Generic x86-64 CPU with v4 instructions. ++ Run equally well on all x86-64 CPUs with min support of x86-64-v4. 
++ ++config MNATIVE_INTEL ++ bool "Intel-Native optimizations autodetected by the compiler" ++ help ++ ++ Clang 3.8, GCC 4.2 and above support -march=native, which automatically detects ++ the optimum settings to use based on your processor. Do NOT use this ++ for AMD CPUs. Intel Only! ++ ++ Enables -march=native ++ ++config MNATIVE_AMD ++ bool "AMD-Native optimizations autodetected by the compiler" ++ help ++ ++ Clang 3.8, GCC 4.2 and above support -march=native, which automatically detects ++ the optimum settings to use based on your processor. Do NOT use this ++ for Intel CPUs. AMD Only! ++ ++ Enables -march=native ++ + endchoice + + config X86_GENERIC +@@ -318,9 +672,17 @@ config X86_INTERNODE_CACHE_SHIFT + config X86_L1_CACHE_SHIFT + int + default "7" if MPENTIUM4 || MPSC +- default "6" if MK7 || MK8 || MPENTIUMM || MCORE2 || MATOM || MVIAC7 || X86_GENERIC || GENERIC_CPU ++ default "6" if MK7 || MK8 || MPENTIUMM || MCORE2 || MATOM || MVIAC7 || MK8SSE3 || MK10 \ ++ || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER \ ++ || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MZEN4 || MZEN5 || MNEHALEM || MWESTMERE || MSILVERMONT \ ++ || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL \ ++ || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE \ ++ || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MRAPTORLAKE || MMETEORLAKE \ ++ || MEMERALDRAPIDS || MNATIVE_INTEL || MNATIVE_AMD || X86_GENERIC || GENERIC_CPU || GENERIC_CPU2 \ ++ || GENERIC_CPU3 || GENERIC_CPU4 + default "4" if MELAN || M486SX || M486 || MGEODEGX1 +- default "5" if MWINCHIP3D || MWINCHIPC6 || MCRUSOE || MEFFICEON || MCYRIXIII || MK6 || MPENTIUMIII || MPENTIUMII || M686 || M586MMX || M586TSC || M586 || MVIAC3_2 || MGEODE_LX ++ default "5" if MWINCHIP3D || MWINCHIPC6 || MCRUSOE || MEFFICEON || MCYRIXIII || MK6 || MPENTIUMIII \ ++ || MPENTIUMII || M686 || M586MMX || M586TSC || M586 || MVIAC3_2 || MGEODE_LX 
+ + config X86_F00F_BUG + def_bool y +@@ -332,15 +694,27 @@ config X86_INVD_BUG + + config X86_ALIGNMENT_16 + def_bool y +- depends on MWINCHIP3D || MWINCHIPC6 || MCYRIXIII || MELAN || MK6 || M586MMX || M586TSC || M586 || M486SX || M486 || MVIAC3_2 || MGEODEGX1 ++ depends on MWINCHIP3D || MWINCHIPC6 || MCYRIXIII || MELAN || MK6 || M586MMX || M586TSC \ ++ || M586 || M486SX || M486 || MVIAC3_2 || MGEODEGX1 + + config X86_INTEL_USERCOPY + def_bool y +- depends on MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M586MMX || X86_GENERIC || MK8 || MK7 || MEFFICEON || MCORE2 ++ depends on MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M586MMX || X86_GENERIC \ ++ || MK8 || MK7 || MEFFICEON || MCORE2 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT \ ++ || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX \ ++ || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS \ ++ || MROCKETLAKE || MALDERLAKE || MRAPTORLAKE || MMETEORLAKE || MEMERALDRAPIDS || MNATIVE_INTEL + + config X86_USE_PPRO_CHECKSUM + def_bool y +- depends on MWINCHIP3D || MWINCHIPC6 || MCYRIXIII || MK7 || MK6 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || MK8 || MVIAC3_2 || MVIAC7 || MEFFICEON || MGEODE_LX || MCORE2 || MATOM ++ depends on MWINCHIP3D || MWINCHIPC6 || MCYRIXIII || MK7 || MK6 || MPENTIUM4 || MPENTIUMM \ ++ || MPENTIUMIII || MPENTIUMII || M686 || MK8 || MVIAC3_2 || MVIAC7 || MEFFICEON || MGEODE_LX \ ++ || MCORE2 || MATOM || MK8SSE3 || MK10 || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER \ ++ || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MZEN4 || MZEN5 || MNEHALEM \ ++ || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE \ ++ || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE \ ++ || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE \ ++ || MALDERLAKE || MRAPTORLAKE || 
MMETEORLAKE || MEMERALDRAPIDS || MNATIVE_INTEL || MNATIVE_AMD + + # + # P6_NOPs are a relatively minor optimization that require a family >= +@@ -356,11 +730,22 @@ config X86_USE_PPRO_CHECKSUM + config X86_P6_NOP + def_bool y + depends on X86_64 +- depends on (MCORE2 || MPENTIUM4 || MPSC) ++ depends on (MCORE2 || MPENTIUM4 || MPSC || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT \ ++ || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE \ ++ || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE \ ++ || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MRAPTORLAKE || MMETEORLAKE || MEMERALDRAPIDS \ ++ || MNATIVE_INTEL) + + config X86_TSC + def_bool y +- depends on (MWINCHIP3D || MCRUSOE || MEFFICEON || MCYRIXIII || MK7 || MK6 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || M586MMX || M586TSC || MK8 || MVIAC3_2 || MVIAC7 || MGEODEGX1 || MGEODE_LX || MCORE2 || MATOM) || X86_64 ++ depends on (MWINCHIP3D || MCRUSOE || MEFFICEON || MCYRIXIII || MK7 || MK6 || MPENTIUM4 || MPENTIUMM \ ++ || MPENTIUMIII || MPENTIUMII || M686 || M586MMX || M586TSC || MK8 || MVIAC3_2 || MVIAC7 || MGEODEGX1 \ ++ || MGEODE_LX || MCORE2 || MATOM || MK8SSE3 || MK10 || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER \ ++ || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MZEN4 || MZEN5 || MNEHALEM \ ++ || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL \ ++ || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE \ ++ || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MRAPTORLAKE || MMETEORLAKE || MEMERALDRAPIDS \ ++ || MNATIVE_INTEL || MNATIVE_AMD) || X86_64 + + config X86_HAVE_PAE + def_bool y +@@ -368,18 +753,37 @@ config X86_HAVE_PAE + + config X86_CMPXCHG64 + def_bool y +- depends on X86_HAVE_PAE || M586TSC || M586MMX || MK6 || MK7 ++ depends on X86_PAE || X86_64 || MCORE2 || MPENTIUM4 || 
MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 \ ++ || M586TSC || M586MMX || MATOM || MGEODE_LX || MGEODEGX1 || MK6 || MK7 || MK8 || MK8SSE3 || MK10 \ ++ || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR || MZEN \ ++ || MZEN2 || MZEN3 || MZEN4 || MZEN5 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS \ ++ || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE \ ++ || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE \ ++ || MALDERLAKE || MRAPTORLAKE || MMETEORLAKE || MEMERALDRAPIDS || MNATIVE_INTEL || MNATIVE_AMD + + # this should be set for all -march=.. options where the compiler + # generates cmov. + config X86_CMOV + def_bool y +- depends on (MK8 || MK7 || MCORE2 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || MVIAC3_2 || MVIAC7 || MCRUSOE || MEFFICEON || X86_64 || MATOM || MGEODE_LX) ++ depends on (MK8 || MK7 || MCORE2 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 \ ++ || MVIAC3_2 || MVIAC7 || MCRUSOE || MEFFICEON || X86_64 || MATOM || MGEODE_LX || MK8SSE3 || MK10 \ ++ || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR \ ++ || MZEN || MZEN2 || MZEN3 || MZEN4 || MZEN5 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT \ ++ || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX \ ++ || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS \ ++ || MROCKETLAKE || MALDERLAKE || MRAPTORLAKE || MMETEORLAKE || MEMERALDRAPIDS || MNATIVE_INTEL || MNATIVE_AMD) + + config X86_MINIMUM_CPU_FAMILY + int + default "64" if X86_64 +- default "6" if X86_32 && (MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || MVIAC3_2 || MVIAC7 || MEFFICEON || MATOM || MCORE2 || MK7 || MK8) ++ default "6" if X86_32 && (MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 \ ++ || 
MVIAC3_2 || MVIAC7 || MEFFICEON || MATOM || MCORE2 || MK7 || MK8 || MK8SSE3 \ ++ || MK10 || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER \ ++ || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MZEN4 || MZEN5 || MNEHALEM || MWESTMERE || MSILVERMONT \ ++ || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL \ ++ || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE \ ++ || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MRAPTORLAKE || MMETEORLAKE \ ++ || MEMERALDRAPIDS || MNATIVE_INTEL || MNATIVE_AMD) + default "5" if X86_32 && X86_CMPXCHG64 + default "4" + +diff --git a/arch/x86/Makefile b/arch/x86/Makefile +index 801fd85c3ef6..93cc88b59cbb 100644 +--- a/arch/x86/Makefile ++++ b/arch/x86/Makefile +@@ -176,8 +176,49 @@ else + # FIXME - should be integrated in Makefile.cpu (Makefile_32.cpu) + cflags-$(CONFIG_MK8) += -march=k8 + cflags-$(CONFIG_MPSC) += -march=nocona +- cflags-$(CONFIG_MCORE2) += -march=core2 +- cflags-$(CONFIG_MATOM) += -march=atom ++ cflags-$(CONFIG_MK8SSE3) += -march=k8-sse3 ++ cflags-$(CONFIG_MK10) += -march=amdfam10 ++ cflags-$(CONFIG_MBARCELONA) += -march=barcelona ++ cflags-$(CONFIG_MBOBCAT) += -march=btver1 ++ cflags-$(CONFIG_MJAGUAR) += -march=btver2 ++ cflags-$(CONFIG_MBULLDOZER) += -march=bdver1 ++ cflags-$(CONFIG_MPILEDRIVER) += -march=bdver2 -mno-tbm ++ cflags-$(CONFIG_MSTEAMROLLER) += -march=bdver3 -mno-tbm ++ cflags-$(CONFIG_MEXCAVATOR) += -march=bdver4 -mno-tbm ++ cflags-$(CONFIG_MZEN) += -march=znver1 ++ cflags-$(CONFIG_MZEN2) += -march=znver2 ++ cflags-$(CONFIG_MZEN3) += -march=znver3 ++ cflags-$(CONFIG_MZEN4) += -march=znver4 ++ cflags-$(CONFIG_MZEN5) += -march=znver5 ++ cflags-$(CONFIG_MNATIVE_INTEL) += -march=native ++ cflags-$(CONFIG_MNATIVE_AMD) += -march=native ++ cflags-$(CONFIG_MATOM) += -march=bonnell ++ cflags-$(CONFIG_MCORE2) += -march=core2 ++ cflags-$(CONFIG_MNEHALEM) += -march=nehalem ++ cflags-$(CONFIG_MWESTMERE) += -march=westmere ++
cflags-$(CONFIG_MSILVERMONT) += -march=silvermont ++ cflags-$(CONFIG_MGOLDMONT) += -march=goldmont ++ cflags-$(CONFIG_MGOLDMONTPLUS) += -march=goldmont-plus ++ cflags-$(CONFIG_MSANDYBRIDGE) += -march=sandybridge ++ cflags-$(CONFIG_MIVYBRIDGE) += -march=ivybridge ++ cflags-$(CONFIG_MHASWELL) += -march=haswell ++ cflags-$(CONFIG_MBROADWELL) += -march=broadwell ++ cflags-$(CONFIG_MSKYLAKE) += -march=skylake ++ cflags-$(CONFIG_MSKYLAKEX) += -march=skylake-avx512 ++ cflags-$(CONFIG_MCANNONLAKE) += -march=cannonlake ++ cflags-$(CONFIG_MICELAKE) += -march=icelake-client ++ cflags-$(CONFIG_MCASCADELAKE) += -march=cascadelake ++ cflags-$(CONFIG_MCOOPERLAKE) += -march=cooperlake ++ cflags-$(CONFIG_MTIGERLAKE) += -march=tigerlake ++ cflags-$(CONFIG_MSAPPHIRERAPIDS) += -march=sapphirerapids ++ cflags-$(CONFIG_MROCKETLAKE) += -march=rocketlake ++ cflags-$(CONFIG_MALDERLAKE) += -march=alderlake ++ cflags-$(CONFIG_MRAPTORLAKE) += -march=raptorlake ++ cflags-$(CONFIG_MMETEORLAKE) += -march=meteorlake ++ cflags-$(CONFIG_MEMERALDRAPIDS) += -march=emeraldrapids ++ cflags-$(CONFIG_GENERIC_CPU2) += -march=x86-64-v2 ++ cflags-$(CONFIG_GENERIC_CPU3) += -march=x86-64-v3 ++ cflags-$(CONFIG_GENERIC_CPU4) += -march=x86-64-v4 + cflags-$(CONFIG_GENERIC_CPU) += -mtune=generic + KBUILD_CFLAGS += $(cflags-y) + +diff --git a/arch/x86/include/asm/pci.h b/arch/x86/include/asm/pci.h +index b3ab80a03365..5e883b397ff3 100644 +--- a/arch/x86/include/asm/pci.h ++++ b/arch/x86/include/asm/pci.h +@@ -26,6 +26,7 @@ struct pci_sysdata { + #if IS_ENABLED(CONFIG_VMD) + struct pci_dev *vmd_dev; /* VMD Device if in Intel VMD domain */ + #endif ++ struct pci_dev *nvme_remap_dev; /* AHCI Device if NVME remapped bus */ + }; + + extern int pci_routeirq; +@@ -69,6 +70,11 @@ static inline bool is_vmd(struct pci_bus *bus) + #define is_vmd(bus) false + #endif /* CONFIG_VMD */ + ++static inline bool is_nvme_remap(struct pci_bus *bus) ++{ ++ return to_pci_sysdata(bus)->nvme_remap_dev != NULL; ++} ++ + /* Can be used to 
override the logic in pci_scan_bus for skipping + already-configured bus numbers - to be used for buggy BIOSes + or architectures with incomplete PCI setup by the loader */ +diff --git a/arch/x86/include/asm/vermagic.h b/arch/x86/include/asm/vermagic.h +index 75884d2cdec3..7acca9b5a9d5 100644 +--- a/arch/x86/include/asm/vermagic.h ++++ b/arch/x86/include/asm/vermagic.h +@@ -17,6 +17,54 @@ + #define MODULE_PROC_FAMILY "586MMX " + #elif defined CONFIG_MCORE2 + #define MODULE_PROC_FAMILY "CORE2 " ++#elif defined CONFIG_MNATIVE_INTEL ++#define MODULE_PROC_FAMILY "NATIVE_INTEL " ++#elif defined CONFIG_MNATIVE_AMD ++#define MODULE_PROC_FAMILY "NATIVE_AMD " ++#elif defined CONFIG_MNEHALEM ++#define MODULE_PROC_FAMILY "NEHALEM " ++#elif defined CONFIG_MWESTMERE ++#define MODULE_PROC_FAMILY "WESTMERE " ++#elif defined CONFIG_MSILVERMONT ++#define MODULE_PROC_FAMILY "SILVERMONT " ++#elif defined CONFIG_MGOLDMONT ++#define MODULE_PROC_FAMILY "GOLDMONT " ++#elif defined CONFIG_MGOLDMONTPLUS ++#define MODULE_PROC_FAMILY "GOLDMONTPLUS " ++#elif defined CONFIG_MSANDYBRIDGE ++#define MODULE_PROC_FAMILY "SANDYBRIDGE " ++#elif defined CONFIG_MIVYBRIDGE ++#define MODULE_PROC_FAMILY "IVYBRIDGE " ++#elif defined CONFIG_MHASWELL ++#define MODULE_PROC_FAMILY "HASWELL " ++#elif defined CONFIG_MBROADWELL ++#define MODULE_PROC_FAMILY "BROADWELL " ++#elif defined CONFIG_MSKYLAKE ++#define MODULE_PROC_FAMILY "SKYLAKE " ++#elif defined CONFIG_MSKYLAKEX ++#define MODULE_PROC_FAMILY "SKYLAKEX " ++#elif defined CONFIG_MCANNONLAKE ++#define MODULE_PROC_FAMILY "CANNONLAKE " ++#elif defined CONFIG_MICELAKE ++#define MODULE_PROC_FAMILY "ICELAKE " ++#elif defined CONFIG_MCASCADELAKE ++#define MODULE_PROC_FAMILY "CASCADELAKE " ++#elif defined CONFIG_MCOOPERLAKE ++#define MODULE_PROC_FAMILY "COOPERLAKE " ++#elif defined CONFIG_MTIGERLAKE ++#define MODULE_PROC_FAMILY "TIGERLAKE " ++#elif defined CONFIG_MSAPPHIRERAPIDS ++#define MODULE_PROC_FAMILY "SAPPHIRERAPIDS " ++#elif defined CONFIG_MROCKETLAKE
++#define MODULE_PROC_FAMILY "ROCKETLAKE " ++#elif defined CONFIG_MALDERLAKE ++#define MODULE_PROC_FAMILY "ALDERLAKE " ++#elif defined CONFIG_MRAPTORLAKE ++#define MODULE_PROC_FAMILY "RAPTORLAKE " ++#elif defined CONFIG_MMETEORLAKE ++#define MODULE_PROC_FAMILY "METEORLAKE " ++#elif defined CONFIG_MEMERALDRAPIDS ++#define MODULE_PROC_FAMILY "EMERALDRAPIDS " + #elif defined CONFIG_MATOM + #define MODULE_PROC_FAMILY "ATOM " + #elif defined CONFIG_M686 +@@ -35,6 +83,34 @@ + #define MODULE_PROC_FAMILY "K7 " + #elif defined CONFIG_MK8 + #define MODULE_PROC_FAMILY "K8 " ++#elif defined CONFIG_MK8SSE3 ++#define MODULE_PROC_FAMILY "K8SSE3 " ++#elif defined CONFIG_MK10 ++#define MODULE_PROC_FAMILY "K10 " ++#elif defined CONFIG_MBARCELONA ++#define MODULE_PROC_FAMILY "BARCELONA " ++#elif defined CONFIG_MBOBCAT ++#define MODULE_PROC_FAMILY "BOBCAT " ++#elif defined CONFIG_MBULLDOZER ++#define MODULE_PROC_FAMILY "BULLDOZER " ++#elif defined CONFIG_MPILEDRIVER ++#define MODULE_PROC_FAMILY "PILEDRIVER " ++#elif defined CONFIG_MSTEAMROLLER ++#define MODULE_PROC_FAMILY "STEAMROLLER " ++#elif defined CONFIG_MJAGUAR ++#define MODULE_PROC_FAMILY "JAGUAR " ++#elif defined CONFIG_MEXCAVATOR ++#define MODULE_PROC_FAMILY "EXCAVATOR " ++#elif defined CONFIG_MZEN ++#define MODULE_PROC_FAMILY "ZEN " ++#elif defined CONFIG_MZEN2 ++#define MODULE_PROC_FAMILY "ZEN2 " ++#elif defined CONFIG_MZEN3 ++#define MODULE_PROC_FAMILY "ZEN3 " ++#elif defined CONFIG_MZEN4 ++#define MODULE_PROC_FAMILY "ZEN4 " ++#elif defined CONFIG_MZEN5 ++#define MODULE_PROC_FAMILY "ZEN5 " + #elif defined CONFIG_MELAN + #define MODULE_PROC_FAMILY "ELAN " + #elif defined CONFIG_MCRUSOE +diff --git a/arch/x86/pci/common.c b/arch/x86/pci/common.c +index ddb798603201..7c20387d8202 100644 +--- a/arch/x86/pci/common.c ++++ b/arch/x86/pci/common.c +@@ -723,12 +723,15 @@ int pci_ext_cfg_avail(void) + return 0; + } + +-#if IS_ENABLED(CONFIG_VMD) + struct pci_dev *pci_real_dma_dev(struct pci_dev *dev) + { ++#if 
IS_ENABLED(CONFIG_VMD) + if (is_vmd(dev->bus)) + return to_pci_sysdata(dev->bus)->vmd_dev; ++#endif ++ ++ if (is_nvme_remap(dev->bus)) ++ return to_pci_sysdata(dev->bus)->nvme_remap_dev; + + return dev; + } +-#endif +diff --git a/arch/x86/xen/setup.c b/arch/x86/xen/setup.c +index 380591028cb8..f6f5aa569367 100644 +--- a/arch/x86/xen/setup.c ++++ b/arch/x86/xen/setup.c +@@ -691,6 +691,7 @@ char * __init xen_memory_setup(void) + struct xen_memory_map memmap; + unsigned long max_pages; + unsigned long extra_pages = 0; ++ unsigned long maxmem_pages; + int i; + int op; + +@@ -762,8 +763,8 @@ char * __init xen_memory_setup(void) + * Make sure we have no memory above max_pages, as this area + * isn't handled by the p2m management. + */ +- extra_pages = min3(EXTRA_MEM_RATIO * min(max_pfn, PFN_DOWN(MAXMEM)), +- extra_pages, max_pages - max_pfn); ++ maxmem_pages = EXTRA_MEM_RATIO * min(max_pfn, PFN_DOWN(MAXMEM)); ++ extra_pages = min3(maxmem_pages, extra_pages, max_pages - max_pfn); + i = 0; + addr = xen_e820_table.entries[0].addr; + size = xen_e820_table.entries[0].size; +diff --git a/block/bfq-iosched.c b/block/bfq-iosched.c +index 88df08a246fa..deecce63d0fc 100644 +--- a/block/bfq-iosched.c ++++ b/block/bfq-iosched.c +@@ -7703,6 +7703,7 @@ MODULE_ALIAS("bfq-iosched"); + static int __init bfq_init(void) + { + int ret; ++ char msg[60] = "BFQ I/O-scheduler: BFQ-CachyOS v6.10"; + + #ifdef CONFIG_BFQ_GROUP_IOSCHED + ret = blkcg_policy_register(&blkcg_policy_bfq); +@@ -7734,6 +7735,11 @@ static int __init bfq_init(void) + if (ret) + goto slab_kill; + ++#ifdef CONFIG_BFQ_GROUP_IOSCHED ++ strcat(msg, " (with cgroups support)"); ++#endif ++ pr_info("%s", msg); ++ + return 0; + + slab_kill: +diff --git a/block/elevator.c b/block/elevator.c +index f64ebd726e58..4f1ccf8cf250 100644 +--- a/block/elevator.c ++++ b/block/elevator.c +@@ -567,9 +567,19 @@ static struct elevator_type *elevator_get_default(struct request_queue *q) + + if (q->nr_hw_queues != 1 && + 
!blk_mq_is_shared_tags(q->tag_set->flags)) ++#if defined(CONFIG_CACHY) && defined(CONFIG_MQ_IOSCHED_KYBER) ++ return elevator_find_get(q, "kyber"); ++#elif defined(CONFIG_CACHY) ++ return elevator_find_get(q, "mq-deadline"); ++#else + return NULL; ++#endif + ++#if defined(CONFIG_CACHY) && defined(CONFIG_IOSCHED_BFQ) ++ return elevator_find_get(q, "bfq"); ++#else + return elevator_find_get(q, "mq-deadline"); ++#endif + } + + /* +diff --git a/drivers/Makefile b/drivers/Makefile +index fe9ceb0d2288..b58955caf19b 100644 +--- a/drivers/Makefile ++++ b/drivers/Makefile +@@ -61,14 +61,8 @@ obj-y += char/ + # iommu/ comes before gpu as gpu are using iommu controllers + obj-y += iommu/ + +-# gpu/ comes after char for AGP vs DRM startup and after iommu +-obj-y += gpu/ +- + obj-$(CONFIG_CONNECTOR) += connector/ + +-# i810fb depends on char/agp/ +-obj-$(CONFIG_FB_I810) += video/fbdev/i810/ +- + obj-$(CONFIG_PARPORT) += parport/ + obj-y += base/ block/ misc/ mfd/ nfc/ + obj-$(CONFIG_LIBNVDIMM) += nvdimm/ +@@ -80,6 +74,13 @@ obj-y += macintosh/ + obj-y += scsi/ + obj-y += nvme/ + obj-$(CONFIG_ATA) += ata/ ++ ++# gpu/ comes after char for AGP vs DRM startup and after iommu ++obj-y += gpu/ ++ ++# i810fb depends on char/agp/ ++obj-$(CONFIG_FB_I810) += video/fbdev/i810/ ++ + obj-$(CONFIG_TARGET_CORE) += target/ + obj-$(CONFIG_MTD) += mtd/ + obj-$(CONFIG_SPI) += spi/ +diff --git a/drivers/ata/ahci.c b/drivers/ata/ahci.c +index fc6fd583faf8..f79e205a51dd 100644 +--- a/drivers/ata/ahci.c ++++ b/drivers/ata/ahci.c +@@ -1618,7 +1618,7 @@ static irqreturn_t ahci_thunderx_irq_handler(int irq, void *dev_instance) + } + #endif + +-static void ahci_remap_check(struct pci_dev *pdev, int bar, ++static int ahci_remap_check(struct pci_dev *pdev, int bar, + struct ahci_host_priv *hpriv) + { + int i; +@@ -1631,7 +1631,7 @@ static void ahci_remap_check(struct pci_dev *pdev, int bar, + pci_resource_len(pdev, bar) < SZ_512K || + bar != AHCI_PCI_BAR_STANDARD || + !(readl(hpriv->mmio + AHCI_VSCAP) & 1)) 
+- return; ++ return 0; + + cap = readq(hpriv->mmio + AHCI_REMAP_CAP); + for (i = 0; i < AHCI_MAX_REMAP; i++) { +@@ -1646,18 +1646,11 @@ static void ahci_remap_check(struct pci_dev *pdev, int bar, + } + + if (!hpriv->remapped_nvme) +- return; +- +- dev_warn(&pdev->dev, "Found %u remapped NVMe devices.\n", +- hpriv->remapped_nvme); +- dev_warn(&pdev->dev, +- "Switch your BIOS from RAID to AHCI mode to use them.\n"); ++ return 0; + +- /* +- * Don't rely on the msi-x capability in the remap case, +- * share the legacy interrupt across ahci and remapped devices. +- */ +- hpriv->flags |= AHCI_HFLAG_NO_MSI; ++ /* Abort probe, allowing intel-nvme-remap to step in when available */ ++ dev_info(&pdev->dev, "Device will be handled by intel-nvme-remap.\n"); ++ return -ENODEV; + } + + static int ahci_get_irq_vector(struct ata_host *host, int port) +@@ -1894,7 +1887,9 @@ static int ahci_init_one(struct pci_dev *pdev, const struct pci_device_id *ent) + hpriv->mmio = pcim_iomap_table(pdev)[ahci_pci_bar]; + + /* detect remapped nvme devices */ +- ahci_remap_check(pdev, ahci_pci_bar, hpriv); ++ rc = ahci_remap_check(pdev, ahci_pci_bar, hpriv); ++ if (rc) ++ return rc; + + sysfs_add_file_to_group(&pdev->dev.kobj, + &dev_attr_remapped_nvme.attr, +diff --git a/drivers/cpufreq/Kconfig.x86 b/drivers/cpufreq/Kconfig.x86 +index 97c2d4f15d76..5a3af44d785a 100644 +--- a/drivers/cpufreq/Kconfig.x86 ++++ b/drivers/cpufreq/Kconfig.x86 +@@ -9,7 +9,6 @@ config X86_INTEL_PSTATE + select ACPI_PROCESSOR if ACPI + select ACPI_CPPC_LIB if X86_64 && ACPI && SCHED_MC_PRIO + select CPU_FREQ_GOV_PERFORMANCE +- select CPU_FREQ_GOV_SCHEDUTIL if SMP + help + This driver provides a P state for Intel core processors. 
+ The driver implements an internal governor and will become +@@ -39,7 +38,6 @@ config X86_AMD_PSTATE + depends on X86 && ACPI + select ACPI_PROCESSOR + select ACPI_CPPC_LIB if X86_64 +- select CPU_FREQ_GOV_SCHEDUTIL if SMP + help + This driver adds a CPUFreq driver which utilizes a fine grain + processor performance frequency control range instead of legacy +diff --git a/drivers/cpufreq/cpufreq.c b/drivers/cpufreq/cpufreq.c +index 270ea04fb616..d7ae2a6bcd7b 100644 +--- a/drivers/cpufreq/cpufreq.c ++++ b/drivers/cpufreq/cpufreq.c +@@ -575,30 +575,11 @@ unsigned int cpufreq_policy_transition_delay_us(struct cpufreq_policy *policy) + return policy->transition_delay_us; + + latency = policy->cpuinfo.transition_latency / NSEC_PER_USEC; +- if (latency) { +- unsigned int max_delay_us = 2 * MSEC_PER_SEC; ++ if (latency) ++ /* Give a 50% breathing room between updates */ ++ return latency + (latency >> 1); + +- /* +- * If the platform already has high transition_latency, use it +- * as-is. +- */ +- if (latency > max_delay_us) +- return latency; +- +- /* +- * For platforms that can change the frequency very fast (< 2 +- * us), the above formula gives a decent transition delay. But +- * for platforms where transition_latency is in milliseconds, it +- * ends up giving unrealistic values. +- * +- * Cap the default transition delay to 2 ms, which seems to be +- * a reasonable amount of time after which we should reevaluate +- * the frequency. 
+- */ +- return min(latency * LATENCY_MULTIPLIER, max_delay_us); +- } +- +- return LATENCY_MULTIPLIER; ++ return USEC_PER_MSEC; + } + EXPORT_SYMBOL_GPL(cpufreq_policy_transition_delay_us); + +diff --git a/drivers/cpufreq/intel_pstate.c b/drivers/cpufreq/intel_pstate.c +index c31914a9876f..1035c074f36a 100644 +--- a/drivers/cpufreq/intel_pstate.c ++++ b/drivers/cpufreq/intel_pstate.c +@@ -3550,6 +3550,8 @@ static int __init intel_pstate_setup(char *str) + + if (!strcmp(str, "disable")) + no_load = 1; ++ else if (!strcmp(str, "enable")) ++ no_load = 0; + else if (!strcmp(str, "active")) + default_driver = &intel_pstate; + else if (!strcmp(str, "passive")) +diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu.h b/drivers/gpu/drm/amd/amdgpu/amdgpu.h +index f87d53e183c3..c489d3b2576b 100644 +--- a/drivers/gpu/drm/amd/amdgpu/amdgpu.h ++++ b/drivers/gpu/drm/amd/amdgpu/amdgpu.h +@@ -159,6 +159,7 @@ struct amdgpu_watchdog_timer { + */ + extern int amdgpu_modeset; + extern unsigned int amdgpu_vram_limit; ++extern int amdgpu_ignore_min_pcap; + extern int amdgpu_vis_vram_limit; + extern int amdgpu_gart_size; + extern int amdgpu_gtt_size; +diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c +index ea14f1c8f430..bb0b636d0d75 100644 +--- a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c ++++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c +@@ -132,6 +132,7 @@ enum AMDGPU_DEBUG_MASK { + }; + + unsigned int amdgpu_vram_limit = UINT_MAX; ++int amdgpu_ignore_min_pcap = 0; /* do not ignore by default */ + int amdgpu_vis_vram_limit; + int amdgpu_gart_size = -1; /* auto */ + int amdgpu_gtt_size = -1; /* auto */ +@@ -243,6 +244,15 @@ struct amdgpu_watchdog_timer amdgpu_watchdog_timer = { + .period = 0x0, /* default to 0x0 (timeout disable) */ + }; + ++/** ++ * DOC: ignore_min_pcap (int) ++ * Ignore the minimum power cap. ++ * Useful on graphics cards where the minimum power cap is very high. ++ * The default is 0 (Do not ignore). 
++ */ ++MODULE_PARM_DESC(ignore_min_pcap, "Ignore the minimum power cap"); ++module_param_named(ignore_min_pcap, amdgpu_ignore_min_pcap, int, 0600); ++ + /** + * DOC: vramlimit (int) + * Restrict the total amount of VRAM in MiB for testing. The default is 0 (Use full VRAM). +diff --git a/drivers/gpu/drm/amd/display/Kconfig b/drivers/gpu/drm/amd/display/Kconfig +index 47b8b49da8a7..943959d1f401 100644 +--- a/drivers/gpu/drm/amd/display/Kconfig ++++ b/drivers/gpu/drm/amd/display/Kconfig +@@ -51,4 +51,10 @@ config DRM_AMD_SECURE_DISPLAY + This option enables the calculation of crc of specific region via + debugfs. Cooperate with specific DMCU FW. + ++config AMD_PRIVATE_COLOR ++ bool "Enable KMS color management by AMD for AMD" ++ default n ++ help ++ This option extends the KMS color management API with AMD driver-specific properties to enhance the color management support on AMD Steam Deck. ++ + endmenu +diff --git a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c +index 964bb6d0a383..a5f10700c16b 100644 +--- a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c ++++ b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c +@@ -4119,7 +4119,7 @@ static int amdgpu_dm_mode_config_init(struct amdgpu_device *adev) + return r; + } + +-#ifdef AMD_PRIVATE_COLOR ++#ifdef CONFIG_AMD_PRIVATE_COLOR + if (amdgpu_dm_create_color_properties(adev)) + return -ENOMEM; + #endif +diff --git a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_color.c b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_color.c +index ebabfe3a512f..4d3ebcaacca1 100644 +--- a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_color.c ++++ b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_color.c +@@ -97,7 +97,7 @@ static inline struct fixed31_32 amdgpu_dm_fixpt_from_s3132(__u64 x) + return val; + } + +-#ifdef AMD_PRIVATE_COLOR ++#ifdef CONFIG_AMD_PRIVATE_COLOR + /* Pre-defined Transfer Functions (TF) + * + * AMD driver supports pre-defined mathematical functions for 
transferring +diff --git a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_crtc.c b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_crtc.c +index e23a0a276e33..dd83cf50a89b 100644 +--- a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_crtc.c ++++ b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_crtc.c +@@ -338,7 +338,7 @@ static int amdgpu_dm_crtc_late_register(struct drm_crtc *crtc) + } + #endif + +-#ifdef AMD_PRIVATE_COLOR ++#ifdef CONFIG_AMD_PRIVATE_COLOR + /** + * dm_crtc_additional_color_mgmt - enable additional color properties + * @crtc: DRM CRTC +@@ -420,7 +420,7 @@ static const struct drm_crtc_funcs amdgpu_dm_crtc_funcs = { + #if defined(CONFIG_DEBUG_FS) + .late_register = amdgpu_dm_crtc_late_register, + #endif +-#ifdef AMD_PRIVATE_COLOR ++#ifdef CONFIG_AMD_PRIVATE_COLOR + .atomic_set_property = amdgpu_dm_atomic_crtc_set_property, + .atomic_get_property = amdgpu_dm_atomic_crtc_get_property, + #endif +@@ -599,7 +599,7 @@ int amdgpu_dm_crtc_init(struct amdgpu_display_manager *dm, + + drm_mode_crtc_set_gamma_size(&acrtc->base, MAX_COLOR_LEGACY_LUT_ENTRIES); + +-#ifdef AMD_PRIVATE_COLOR ++#ifdef CONFIG_AMD_PRIVATE_COLOR + dm_crtc_additional_color_mgmt(&acrtc->base); + #endif + return 0; +diff --git a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_plane.c b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_plane.c +index 8a4c40b4c27e..779880c64575 100644 +--- a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_plane.c ++++ b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_plane.c +@@ -1468,7 +1468,7 @@ static void amdgpu_dm_plane_drm_plane_destroy_state(struct drm_plane *plane, + drm_atomic_helper_plane_destroy_state(plane, state); + } + +-#ifdef AMD_PRIVATE_COLOR ++#ifdef CONFIG_AMD_PRIVATE_COLOR + static void + dm_atomic_plane_attach_color_mgmt_properties(struct amdgpu_display_manager *dm, + struct drm_plane *plane) +@@ -1659,7 +1659,7 @@ static const struct drm_plane_funcs dm_plane_funcs = { + .atomic_duplicate_state = 
amdgpu_dm_plane_drm_plane_duplicate_state, + .atomic_destroy_state = amdgpu_dm_plane_drm_plane_destroy_state, + .format_mod_supported = amdgpu_dm_plane_format_mod_supported, +-#ifdef AMD_PRIVATE_COLOR ++#ifdef CONFIG_AMD_PRIVATE_COLOR + .atomic_set_property = dm_atomic_plane_set_property, + .atomic_get_property = dm_atomic_plane_get_property, + #endif +@@ -1742,7 +1742,7 @@ int amdgpu_dm_plane_init(struct amdgpu_display_manager *dm, + + drm_plane_helper_add(plane, &dm_plane_helper_funcs); + +-#ifdef AMD_PRIVATE_COLOR ++#ifdef CONFIG_AMD_PRIVATE_COLOR + dm_atomic_plane_attach_color_mgmt_properties(dm, plane); + #endif + /* Create (reset) the plane state */ +diff --git a/drivers/gpu/drm/amd/pm/amdgpu_pm.c b/drivers/gpu/drm/amd/pm/amdgpu_pm.c +index c11952a4389b..52f54a228b39 100644 +--- a/drivers/gpu/drm/amd/pm/amdgpu_pm.c ++++ b/drivers/gpu/drm/amd/pm/amdgpu_pm.c +@@ -3155,6 +3155,9 @@ static ssize_t amdgpu_hwmon_show_power_cap_min(struct device *dev, + struct device_attribute *attr, + char *buf) + { ++ if (amdgpu_ignore_min_pcap) ++ return sysfs_emit(buf, "%i\n", 0); ++ + return amdgpu_hwmon_show_power_cap_generic(dev, attr, buf, PP_PWR_LIMIT_MIN); + } + +diff --git a/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c b/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c +index 06409133b09b..9335adb556ce 100644 +--- a/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c ++++ b/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c +@@ -2749,7 +2749,10 @@ int smu_get_power_limit(void *handle, + *limit = smu->max_power_limit; + break; + case SMU_PPT_LIMIT_MIN: +- *limit = smu->min_power_limit; ++ if (amdgpu_ignore_min_pcap) ++ *limit = 0; ++ else ++ *limit = smu->min_power_limit; + break; + default: + return -EINVAL; +@@ -2773,7 +2776,14 @@ static int smu_set_power_limit(void *handle, uint32_t limit) + if (smu->ppt_funcs->set_power_limit) + return smu->ppt_funcs->set_power_limit(smu, limit_type, limit); + +- if ((limit > smu->max_power_limit) || (limit < smu->min_power_limit)) { ++ if (amdgpu_ignore_min_pcap) 
{ ++ if ((limit > smu->max_power_limit)) { ++ dev_err(smu->adev->dev, ++ "New power limit (%d) is over the max allowed %d\n", ++ limit, smu->max_power_limit); ++ return -EINVAL; ++ } ++ } else if ((limit > smu->max_power_limit) || (limit < smu->min_power_limit)) { + dev_err(smu->adev->dev, + "New power limit (%d) is out of range [%d,%d]\n", + limit, smu->min_power_limit, smu->max_power_limit); +diff --git a/drivers/i2c/busses/Kconfig b/drivers/i2c/busses/Kconfig +index fe6e8a1bb607..1488a904e3bf 100644 +--- a/drivers/i2c/busses/Kconfig ++++ b/drivers/i2c/busses/Kconfig +@@ -238,6 +238,15 @@ config I2C_CHT_WC + combined with a FUSB302 Type-C port-controller as such it is advised + to also select CONFIG_TYPEC_FUSB302=m. + ++config I2C_NCT6775 ++ tristate "Nuvoton NCT6775 and compatible SMBus controller" ++ help ++ If you say yes to this option, support will be included for the ++ Nuvoton NCT6775 and compatible SMBus controllers. ++ ++ This driver can also be built as a module. If so, the module ++ will be called i2c-nct6775. 
++ + config I2C_NFORCE2 + tristate "Nvidia nForce2, nForce3 and nForce4" + depends on PCI && HAS_IOPORT +diff --git a/drivers/i2c/busses/Makefile b/drivers/i2c/busses/Makefile +index 78d0561339e5..9ea3a294f9f0 100644 +--- a/drivers/i2c/busses/Makefile ++++ b/drivers/i2c/busses/Makefile +@@ -20,6 +20,7 @@ obj-$(CONFIG_I2C_CHT_WC) += i2c-cht-wc.o + obj-$(CONFIG_I2C_I801) += i2c-i801.o + obj-$(CONFIG_I2C_ISCH) += i2c-isch.o + obj-$(CONFIG_I2C_ISMT) += i2c-ismt.o ++obj-$(CONFIG_I2C_NCT6775) += i2c-nct6775.o + obj-$(CONFIG_I2C_NFORCE2) += i2c-nforce2.o + obj-$(CONFIG_I2C_NFORCE2_S4985) += i2c-nforce2-s4985.o + obj-$(CONFIG_I2C_NVIDIA_GPU) += i2c-nvidia-gpu.o +diff --git a/drivers/i2c/busses/i2c-nct6775.c b/drivers/i2c/busses/i2c-nct6775.c +new file mode 100644 +index 000000000000..fdbd9a1c8d7a +--- /dev/null ++++ b/drivers/i2c/busses/i2c-nct6775.c +@@ -0,0 +1,648 @@ ++/* ++ * i2c-nct6775 - Driver for the SMBus master functionality of ++ * Nuvoton NCT677x Super-I/O chips ++ * ++ * Copyright (C) 2019 Adam Honse ++ * ++ * Derived from nct6775 hwmon driver ++ * Copyright (C) 2012 Guenter Roeck ++ * ++ * This program is free software; you can redistribute it and/or modify ++ * it under the terms of the GNU General Public License as published by ++ * the Free Software Foundation; either version 2 of the License, or ++ * (at your option) any later version. ++ * ++ * This program is distributed in the hope that it will be useful, ++ * but WITHOUT ANY WARRANTY; without even the implied warranty of ++ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the ++ * GNU General Public License for more details. ++ * ++ * You should have received a copy of the GNU General Public License ++ * along with this program; if not, write to the Free Software ++ * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. 
++ * ++ */ ++ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++ ++#define DRVNAME "i2c-nct6775" ++ ++/* Nuvoton SMBus address offsets */ ++#define SMBHSTDAT (0 + nuvoton_nct6793d_smba) ++#define SMBBLKSZ (1 + nuvoton_nct6793d_smba) ++#define SMBHSTCMD (2 + nuvoton_nct6793d_smba) ++#define SMBHSTIDX (3 + nuvoton_nct6793d_smba) //Index field is the Command field on other controllers ++#define SMBHSTCTL (4 + nuvoton_nct6793d_smba) ++#define SMBHSTADD (5 + nuvoton_nct6793d_smba) ++#define SMBHSTERR (9 + nuvoton_nct6793d_smba) ++#define SMBHSTSTS (0xE + nuvoton_nct6793d_smba) ++ ++/* Command register */ ++#define NCT6793D_READ_BYTE 0 ++#define NCT6793D_READ_WORD 1 ++#define NCT6793D_READ_BLOCK 2 ++#define NCT6793D_BLOCK_WRITE_READ_PROC_CALL 3 ++#define NCT6793D_PROC_CALL 4 ++#define NCT6793D_WRITE_BYTE 8 ++#define NCT6793D_WRITE_WORD 9 ++#define NCT6793D_WRITE_BLOCK 10 ++ ++/* Control register */ ++#define NCT6793D_MANUAL_START 128 ++#define NCT6793D_SOFT_RESET 64 ++ ++/* Error register */ ++#define NCT6793D_NO_ACK 32 ++ ++/* Status register */ ++#define NCT6793D_FIFO_EMPTY 1 ++#define NCT6793D_FIFO_FULL 2 ++#define NCT6793D_MANUAL_ACTIVE 4 ++ ++#define NCT6775_LD_SMBUS 0x0B ++ ++/* Other settings */ ++#define MAX_RETRIES 400 ++ ++enum kinds { nct6106, nct6775, nct6776, nct6779, nct6791, nct6792, nct6793, ++ nct6795, nct6796, nct6798 }; ++ ++struct nct6775_sio_data { ++ int sioreg; ++ enum kinds kind; ++}; ++ ++/* used to set data->name = nct6775_device_names[data->sio_kind] */ ++static const char * const nct6775_device_names[] = { ++ "nct6106", ++ "nct6775", ++ "nct6776", ++ "nct6779", ++ "nct6791", ++ "nct6792", ++ "nct6793", ++ "nct6795", ++ "nct6796", ++ "nct6798", ++}; ++ ++static const char * const nct6775_sio_names[] __initconst = { ++ "NCT6106D", ++ "NCT6775F", ++ "NCT6776D/F", ++ "NCT6779D", ++ "NCT6791D", ++ 
"NCT6792D", ++ "NCT6793D", ++ "NCT6795D", ++ "NCT6796D", ++ "NCT6798D", ++}; ++ ++#define SIO_REG_LDSEL 0x07 /* Logical device select */ ++#define SIO_REG_DEVID 0x20 /* Device ID (2 bytes) */ ++#define SIO_REG_SMBA 0x62 /* SMBus base address register */ ++ ++#define SIO_NCT6106_ID 0xc450 ++#define SIO_NCT6775_ID 0xb470 ++#define SIO_NCT6776_ID 0xc330 ++#define SIO_NCT6779_ID 0xc560 ++#define SIO_NCT6791_ID 0xc800 ++#define SIO_NCT6792_ID 0xc910 ++#define SIO_NCT6793_ID 0xd120 ++#define SIO_NCT6795_ID 0xd350 ++#define SIO_NCT6796_ID 0xd420 ++#define SIO_NCT6798_ID 0xd428 ++#define SIO_ID_MASK 0xFFF0 ++ ++static inline void ++superio_outb(int ioreg, int reg, int val) ++{ ++ outb(reg, ioreg); ++ outb(val, ioreg + 1); ++} ++ ++static inline int ++superio_inb(int ioreg, int reg) ++{ ++ outb(reg, ioreg); ++ return inb(ioreg + 1); ++} ++ ++static inline void ++superio_select(int ioreg, int ld) ++{ ++ outb(SIO_REG_LDSEL, ioreg); ++ outb(ld, ioreg + 1); ++} ++ ++static inline int ++superio_enter(int ioreg) ++{ ++ /* ++ * Try to reserve and for exclusive access. ++ */ ++ if (!request_muxed_region(ioreg, 2, DRVNAME)) ++ return -EBUSY; ++ ++ outb(0x87, ioreg); ++ outb(0x87, ioreg); ++ ++ return 0; ++} ++ ++static inline void ++superio_exit(int ioreg) ++{ ++ outb(0xaa, ioreg); ++ outb(0x02, ioreg); ++ outb(0x02, ioreg + 1); ++ release_region(ioreg, 2); ++} ++ ++/* ++ * ISA constants ++ */ ++ ++#define IOREGION_ALIGNMENT (~7) ++#define IOREGION_LENGTH 2 ++#define ADDR_REG_OFFSET 0 ++#define DATA_REG_OFFSET 1 ++ ++#define NCT6775_REG_BANK 0x4E ++#define NCT6775_REG_CONFIG 0x40 ++ ++static struct i2c_adapter *nct6775_adapter; ++ ++struct i2c_nct6775_adapdata { ++ unsigned short smba; ++}; ++ ++/* Return negative errno on error. 
*/ ++static s32 nct6775_access(struct i2c_adapter * adap, u16 addr, ++ unsigned short flags, char read_write, ++ u8 command, int size, union i2c_smbus_data * data) ++{ ++ struct i2c_nct6775_adapdata *adapdata = i2c_get_adapdata(adap); ++ unsigned short nuvoton_nct6793d_smba = adapdata->smba; ++ int i, len, cnt; ++ union i2c_smbus_data tmp_data; ++ int timeout = 0; ++ ++ tmp_data.word = 0; ++ cnt = 0; ++ len = 0; ++ ++ outb_p(NCT6793D_SOFT_RESET, SMBHSTCTL); ++ ++ switch (size) { ++ case I2C_SMBUS_QUICK: ++ outb_p((addr << 1) | read_write, ++ SMBHSTADD); ++ break; ++ case I2C_SMBUS_BYTE_DATA: ++ tmp_data.byte = data->byte; ++ fallthrough; ++ case I2C_SMBUS_BYTE: ++ outb_p((addr << 1) | read_write, ++ SMBHSTADD); ++ outb_p(command, SMBHSTIDX); ++ if (read_write == I2C_SMBUS_WRITE) { ++ outb_p(tmp_data.byte, SMBHSTDAT); ++ outb_p(NCT6793D_WRITE_BYTE, SMBHSTCMD); ++ } ++ else { ++ outb_p(NCT6793D_READ_BYTE, SMBHSTCMD); ++ } ++ break; ++ case I2C_SMBUS_WORD_DATA: ++ outb_p((addr << 1) | read_write, ++ SMBHSTADD); ++ outb_p(command, SMBHSTIDX); ++ if (read_write == I2C_SMBUS_WRITE) { ++ outb_p(data->word & 0xff, SMBHSTDAT); ++ outb_p((data->word & 0xff00) >> 8, SMBHSTDAT); ++ outb_p(NCT6793D_WRITE_WORD, SMBHSTCMD); ++ } ++ else { ++ outb_p(NCT6793D_READ_WORD, SMBHSTCMD); ++ } ++ break; ++ case I2C_SMBUS_BLOCK_DATA: ++ outb_p((addr << 1) | read_write, ++ SMBHSTADD); ++ outb_p(command, SMBHSTIDX); ++ if (read_write == I2C_SMBUS_WRITE) { ++ len = data->block[0]; ++ if (len == 0 || len > I2C_SMBUS_BLOCK_MAX) ++ return -EINVAL; ++ outb_p(len, SMBBLKSZ); ++ ++ cnt = 1; ++ if (len >= 4) { ++ for (i = cnt; i <= 4; i++) { ++ outb_p(data->block[i], SMBHSTDAT); ++ } ++ ++ len -= 4; ++ cnt += 4; ++ } ++ else { ++ for (i = cnt; i <= len; i++ ) { ++ outb_p(data->block[i], SMBHSTDAT); ++ } ++ ++ len = 0; ++ } ++ ++ outb_p(NCT6793D_WRITE_BLOCK, SMBHSTCMD); ++ } ++ else { ++ return -ENOTSUPP; ++ } ++ break; ++ default: ++ dev_warn(&adap->dev, "Unsupported transaction %d\n", size); ++ 
return -EOPNOTSUPP; ++ } ++ ++ outb_p(NCT6793D_MANUAL_START, SMBHSTCTL); ++ ++ while ((size == I2C_SMBUS_BLOCK_DATA) && (len > 0)) { ++ if (read_write == I2C_SMBUS_WRITE) { ++ timeout = 0; ++ while ((inb_p(SMBHSTSTS) & NCT6793D_FIFO_EMPTY) == 0) ++ { ++ if(timeout > MAX_RETRIES) ++ { ++ return -ETIMEDOUT; ++ } ++ usleep_range(250, 500); ++ timeout++; ++ } ++ ++ //Load more bytes into FIFO ++ if (len >= 4) { ++ for (i = cnt; i < (cnt + 4); i++) { ++ outb_p(data->block[i], SMBHSTDAT); ++ } ++ ++ len -= 4; ++ cnt += 4; ++ } ++ else { ++ for (i = cnt; i < (cnt + len); i++) { ++ outb_p(data->block[i], SMBHSTDAT); ++ } ++ ++ len = 0; ++ } ++ } ++ else { ++ return -ENOTSUPP; ++ } ++ ++ } ++ ++ //wait for manual mode to complete ++ timeout = 0; ++ while ((inb_p(SMBHSTSTS) & NCT6793D_MANUAL_ACTIVE) != 0) ++ { ++ if(timeout > MAX_RETRIES) ++ { ++ return -ETIMEDOUT; ++ } ++ usleep_range(250, 500); ++ timeout++; ++ } ++ ++ if ((inb_p(SMBHSTERR) & NCT6793D_NO_ACK) != 0) { ++ return -ENXIO; ++ } ++ else if ((read_write == I2C_SMBUS_WRITE) || (size == I2C_SMBUS_QUICK)) { ++ return 0; ++ } ++ ++ switch (size) { ++ case I2C_SMBUS_BYTE: ++ case I2C_SMBUS_BYTE_DATA: ++ data->byte = inb_p(SMBHSTDAT); ++ break; ++ case I2C_SMBUS_WORD_DATA: ++ data->word = inb_p(SMBHSTDAT) + (inb_p(SMBHSTDAT) << 8); ++ break; ++ } ++ return 0; ++} ++ ++static u32 nct6775_func(struct i2c_adapter *adapter) ++{ ++ return I2C_FUNC_SMBUS_QUICK | I2C_FUNC_SMBUS_BYTE | ++ I2C_FUNC_SMBUS_BYTE_DATA | I2C_FUNC_SMBUS_WORD_DATA | ++ I2C_FUNC_SMBUS_BLOCK_DATA; ++} ++ ++static const struct i2c_algorithm smbus_algorithm = { ++ .smbus_xfer = nct6775_access, ++ .functionality = nct6775_func, ++}; ++ ++static int nct6775_add_adapter(unsigned short smba, const char *name, struct i2c_adapter **padap) ++{ ++ struct i2c_adapter *adap; ++ struct i2c_nct6775_adapdata *adapdata; ++ int retval; ++ ++ adap = kzalloc(sizeof(*adap), GFP_KERNEL); ++ if (adap == NULL) { ++ return -ENOMEM; ++ } ++ ++ adap->owner = THIS_MODULE; ++
adap->class = I2C_CLASS_HWMON; ++ adap->algo = &smbus_algorithm; ++ ++ adapdata = kzalloc(sizeof(*adapdata), GFP_KERNEL); ++ if (adapdata == NULL) { ++ kfree(adap); ++ return -ENOMEM; ++ } ++ ++ adapdata->smba = smba; ++ ++ snprintf(adap->name, sizeof(adap->name), ++ "SMBus NCT67xx adapter%s at %04x", name, smba); ++ ++ i2c_set_adapdata(adap, adapdata); ++ ++ retval = i2c_add_adapter(adap); ++ if (retval) { ++ kfree(adapdata); ++ kfree(adap); ++ return retval; ++ } ++ ++ *padap = adap; ++ return 0; ++} ++ ++static void nct6775_remove_adapter(struct i2c_adapter *adap) ++{ ++ struct i2c_nct6775_adapdata *adapdata = i2c_get_adapdata(adap); ++ ++ if (adapdata->smba) { ++ i2c_del_adapter(adap); ++ kfree(adapdata); ++ kfree(adap); ++ } ++} ++ ++//static SIMPLE_DEV_PM_OPS(nct6775_dev_pm_ops, nct6775_suspend, nct6775_resume); ++ ++/* ++ * when Super-I/O functions move to a separate file, the Super-I/O ++ * bus will manage the lifetime of the device and this module will only keep ++ * track of the nct6775 driver. 
But since we use platform_device_alloc(), we ++ * must keep track of the device ++ */ ++static struct platform_device *pdev[2]; ++ ++static int nct6775_probe(struct platform_device *pdev) ++{ ++ struct device *dev = &pdev->dev; ++ struct nct6775_sio_data *sio_data = dev_get_platdata(dev); ++ struct resource *res; ++ ++ res = platform_get_resource(pdev, IORESOURCE_IO, 0); ++ if (!devm_request_region(&pdev->dev, res->start, IOREGION_LENGTH, ++ DRVNAME)) ++ return -EBUSY; ++ ++ switch (sio_data->kind) { ++ case nct6791: ++ case nct6792: ++ case nct6793: ++ case nct6795: ++ case nct6796: ++ case nct6798: ++ nct6775_add_adapter(res->start, "", &nct6775_adapter); ++ break; ++ default: ++ return -ENODEV; ++ } ++ ++ return 0; ++} ++/* ++static void nct6791_enable_io_mapping(int sioaddr) ++{ ++ int val; ++ ++ val = superio_inb(sioaddr, NCT6791_REG_HM_IO_SPACE_LOCK_ENABLE); ++ if (val & 0x10) { ++ pr_info("Enabling hardware monitor logical device mappings.\n"); ++ superio_outb(sioaddr, NCT6791_REG_HM_IO_SPACE_LOCK_ENABLE, ++ val & ~0x10); ++ } ++}*/ ++ ++static struct platform_driver i2c_nct6775_driver = { ++ .driver = { ++ .name = DRVNAME, ++// .pm = &nct6775_dev_pm_ops, ++ }, ++ .probe = nct6775_probe, ++}; ++ ++static void __exit i2c_nct6775_exit(void) ++{ ++ int i; ++ ++ if(nct6775_adapter) ++ nct6775_remove_adapter(nct6775_adapter); ++ ++ for (i = 0; i < ARRAY_SIZE(pdev); i++) { ++ if (pdev[i]) ++ platform_device_unregister(pdev[i]); ++ } ++ platform_driver_unregister(&i2c_nct6775_driver); ++} ++ ++/* nct6775_find() looks for a '627 in the Super-I/O config space */ ++static int __init nct6775_find(int sioaddr, struct nct6775_sio_data *sio_data) ++{ ++ u16 val; ++ int err; ++ int addr; ++ ++ err = superio_enter(sioaddr); ++ if (err) ++ return err; ++ ++ val = (superio_inb(sioaddr, SIO_REG_DEVID) << 8) | ++ superio_inb(sioaddr, SIO_REG_DEVID + 1); ++ ++ switch (val & SIO_ID_MASK) { ++ case SIO_NCT6106_ID: ++ sio_data->kind = nct6106; ++ break; ++ case SIO_NCT6775_ID: ++ 
sio_data->kind = nct6775; ++ break; ++ case SIO_NCT6776_ID: ++ sio_data->kind = nct6776; ++ break; ++ case SIO_NCT6779_ID: ++ sio_data->kind = nct6779; ++ break; ++ case SIO_NCT6791_ID: ++ sio_data->kind = nct6791; ++ break; ++ case SIO_NCT6792_ID: ++ sio_data->kind = nct6792; ++ break; ++ case SIO_NCT6793_ID: ++ sio_data->kind = nct6793; ++ break; ++ case SIO_NCT6795_ID: ++ sio_data->kind = nct6795; ++ break; ++ case SIO_NCT6796_ID: ++ sio_data->kind = nct6796; ++ break; ++ case SIO_NCT6798_ID: ++ sio_data->kind = nct6798; ++ break; ++ default: ++ if (val != 0xffff) ++ pr_debug("unsupported chip ID: 0x%04x\n", val); ++ superio_exit(sioaddr); ++ return -ENODEV; ++ } ++ ++ /* We have a known chip, find the SMBus I/O address */ ++ superio_select(sioaddr, NCT6775_LD_SMBUS); ++ val = (superio_inb(sioaddr, SIO_REG_SMBA) << 8) ++ | superio_inb(sioaddr, SIO_REG_SMBA + 1); ++ addr = val & IOREGION_ALIGNMENT; ++ if (addr == 0) { ++ pr_err("Refusing to enable a Super-I/O device with a base I/O port 0\n"); ++ superio_exit(sioaddr); ++ return -ENODEV; ++ } ++ ++ //if (sio_data->kind == nct6791 || sio_data->kind == nct6792 || ++ // sio_data->kind == nct6793 || sio_data->kind == nct6795 || ++ // sio_data->kind == nct6796) ++ // nct6791_enable_io_mapping(sioaddr); ++ ++ superio_exit(sioaddr); ++ pr_info("Found %s or compatible chip at %#x:%#x\n", ++ nct6775_sio_names[sio_data->kind], sioaddr, addr); ++ sio_data->sioreg = sioaddr; ++ ++ return addr; ++} ++ ++static int __init i2c_nct6775_init(void) ++{ ++ int i, err; ++ bool found = false; ++ int address; ++ struct resource res; ++ struct nct6775_sio_data sio_data; ++ int sioaddr[2] = { 0x2e, 0x4e }; ++ ++ err = platform_driver_register(&i2c_nct6775_driver); ++ if (err) ++ return err; ++ ++ /* ++ * initialize sio_data->kind and sio_data->sioreg. 
++ * ++ * when Super-I/O functions move to a separate file, the Super-I/O ++ * driver will probe 0x2e and 0x4e and auto-detect the presence of a ++ * nct6775 hardware monitor, and call probe() ++ */ ++ for (i = 0; i < ARRAY_SIZE(pdev); i++) { ++ address = nct6775_find(sioaddr[i], &sio_data); ++ if (address <= 0) ++ continue; ++ ++ found = true; ++ ++ pdev[i] = platform_device_alloc(DRVNAME, address); ++ if (!pdev[i]) { ++ err = -ENOMEM; ++ goto exit_device_unregister; ++ } ++ ++ err = platform_device_add_data(pdev[i], &sio_data, ++ sizeof(struct nct6775_sio_data)); ++ if (err) ++ goto exit_device_put; ++ ++ memset(&res, 0, sizeof(res)); ++ res.name = DRVNAME; ++ res.start = address; ++ res.end = address + IOREGION_LENGTH - 1; ++ res.flags = IORESOURCE_IO; ++ ++ err = acpi_check_resource_conflict(&res); ++ if (err) { ++ platform_device_put(pdev[i]); ++ pdev[i] = NULL; ++ continue; ++ } ++ ++ err = platform_device_add_resources(pdev[i], &res, 1); ++ if (err) ++ goto exit_device_put; ++ ++ /* platform_device_add calls probe() */ ++ err = platform_device_add(pdev[i]); ++ if (err) ++ goto exit_device_put; ++ } ++ if (!found) { ++ err = -ENODEV; ++ goto exit_unregister; ++ } ++ ++ return 0; ++ ++exit_device_put: ++ platform_device_put(pdev[i]); ++exit_device_unregister: ++ while (--i >= 0) { ++ if (pdev[i]) ++ platform_device_unregister(pdev[i]); ++ } ++exit_unregister: ++ platform_driver_unregister(&i2c_nct6775_driver); ++ return err; ++} ++ ++MODULE_AUTHOR("Adam Honse "); ++MODULE_DESCRIPTION("SMBus driver for NCT6775F and compatible chips"); ++MODULE_LICENSE("GPL"); ++ ++module_init(i2c_nct6775_init); ++module_exit(i2c_nct6775_exit); +diff --git a/drivers/i2c/busses/i2c-piix4.c b/drivers/i2c/busses/i2c-piix4.c +index 6a0392172b2f..e7dd007bf6b1 100644 +--- a/drivers/i2c/busses/i2c-piix4.c ++++ b/drivers/i2c/busses/i2c-piix4.c +@@ -568,11 +568,11 @@ static int piix4_transaction(struct i2c_adapter *piix4_adapter) + if (srvrworks_csb5_delay) /* Extra delay for 
SERVERWORKS_CSB5 */ + usleep_range(2000, 2100); + else +- usleep_range(250, 500); ++ usleep_range(25, 50); + + while ((++timeout < MAX_TIMEOUT) && + ((temp = inb_p(SMBHSTSTS)) & 0x01)) +- usleep_range(250, 500); ++ usleep_range(25, 50); + + /* If the SMBus is still busy, we give up */ + if (timeout == MAX_TIMEOUT) { +diff --git a/drivers/input/evdev.c b/drivers/input/evdev.c +index 51e0c4954600..35c3ad741870 100644 +--- a/drivers/input/evdev.c ++++ b/drivers/input/evdev.c +@@ -46,6 +46,7 @@ struct evdev_client { + struct fasync_struct *fasync; + struct evdev *evdev; + struct list_head node; ++ struct rcu_head rcu; + enum input_clock_type clk_type; + bool revoked; + unsigned long *evmasks[EV_CNT]; +@@ -377,13 +378,22 @@ static void evdev_attach_client(struct evdev *evdev, + spin_unlock(&evdev->client_lock); + } + ++static void evdev_reclaim_client(struct rcu_head *rp) ++{ ++ struct evdev_client *client = container_of(rp, struct evdev_client, rcu); ++ unsigned int i; ++ for (i = 0; i < EV_CNT; ++i) ++ bitmap_free(client->evmasks[i]); ++ kvfree(client); ++} ++ + static void evdev_detach_client(struct evdev *evdev, + struct evdev_client *client) + { + spin_lock(&evdev->client_lock); + list_del_rcu(&client->node); + spin_unlock(&evdev->client_lock); +- synchronize_rcu(); ++ call_rcu(&client->rcu, evdev_reclaim_client); + } + + static int evdev_open_device(struct evdev *evdev) +@@ -436,7 +446,6 @@ static int evdev_release(struct inode *inode, struct file *file) + { + struct evdev_client *client = file->private_data; + struct evdev *evdev = client->evdev; +- unsigned int i; + + mutex_lock(&evdev->mutex); + +@@ -448,11 +457,6 @@ static int evdev_release(struct inode *inode, struct file *file) + + evdev_detach_client(evdev, client); + +- for (i = 0; i < EV_CNT; ++i) +- bitmap_free(client->evmasks[i]); +- +- kvfree(client); +- + evdev_close_device(evdev); + + return 0; +@@ -495,7 +499,6 @@ static int evdev_open(struct inode *inode, struct file *file) + + err_free_client: + 
evdev_detach_client(evdev, client); +- kvfree(client); + return error; + } + +diff --git a/drivers/md/dm-crypt.c b/drivers/md/dm-crypt.c +index 1b7a97cc3779..37e9e43908ab 100644 +--- a/drivers/md/dm-crypt.c ++++ b/drivers/md/dm-crypt.c +@@ -3284,6 +3284,11 @@ static int crypt_ctr(struct dm_target *ti, unsigned int argc, char **argv) + goto bad; + } + ++#ifdef CONFIG_CACHY ++ set_bit(DM_CRYPT_NO_READ_WORKQUEUE, &cc->flags); ++ set_bit(DM_CRYPT_NO_WRITE_WORKQUEUE, &cc->flags); ++#endif ++ + ret = crypt_ctr_cipher(ti, argv[0], argv[1]); + if (ret < 0) + goto bad; +diff --git a/drivers/media/v4l2-core/Kconfig b/drivers/media/v4l2-core/Kconfig +index 331b8e535e5b..80dabeebf580 100644 +--- a/drivers/media/v4l2-core/Kconfig ++++ b/drivers/media/v4l2-core/Kconfig +@@ -40,6 +40,11 @@ config VIDEO_TUNER + config V4L2_JPEG_HELPER + tristate + ++config V4L2_LOOPBACK ++ tristate "V4L2 loopback device" ++ help ++ V4L2 loopback device ++ + # Used by drivers that need v4l2-h264.ko + config V4L2_H264 + tristate +diff --git a/drivers/media/v4l2-core/Makefile b/drivers/media/v4l2-core/Makefile +index 2177b9d63a8f..c179507cedc4 100644 +--- a/drivers/media/v4l2-core/Makefile ++++ b/drivers/media/v4l2-core/Makefile +@@ -33,5 +33,7 @@ obj-$(CONFIG_V4L2_JPEG_HELPER) += v4l2-jpeg.o + obj-$(CONFIG_V4L2_MEM2MEM_DEV) += v4l2-mem2mem.o + obj-$(CONFIG_V4L2_VP9) += v4l2-vp9.o + ++obj-$(CONFIG_V4L2_LOOPBACK) += v4l2loopback.o ++ + obj-$(CONFIG_VIDEO_TUNER) += tuner.o + obj-$(CONFIG_VIDEO_DEV) += v4l2-dv-timings.o videodev.o +diff --git a/drivers/media/v4l2-core/v4l2loopback.c b/drivers/media/v4l2-core/v4l2loopback.c +new file mode 100644 +index 000000000000..25cb1beb26e5 +--- /dev/null ++++ b/drivers/media/v4l2-core/v4l2loopback.c +@@ -0,0 +1,3184 @@ ++/* -*- c-file-style: "linux" -*- */ ++/* ++ * v4l2loopback.c -- video4linux2 loopback driver ++ * ++ * Copyright (C) 2005-2009 Vasily Levin (vasaka@gmail.com) ++ * Copyright (C) 2010-2023 IOhannes m zmoelnig (zmoelnig@iem.at) ++ * Copyright (C) 
2011 Stefan Diewald (stefan.diewald@mytum.de) ++ * Copyright (C) 2012 Anton Novikov (random.plant@gmail.com) ++ * ++ * This program is free software; you can redistribute it and/or modify ++ * it under the terms of the GNU General Public License as published by ++ * the Free Software Foundation; either version 2 of the License, or ++ * (at your option) any later version. ++ * ++ */ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++ ++#include ++#include "v4l2loopback.h" ++ ++#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 0, 0) ++#error This module is not supported on kernels before 4.0.0. ++#endif ++ ++#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 3, 0) ++#define strscpy strlcpy ++#endif ++ ++#if defined(timer_setup) && defined(from_timer) ++#define HAVE_TIMER_SETUP ++#endif ++ ++#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 7, 0) ++#define VFL_TYPE_VIDEO VFL_TYPE_GRABBER ++#endif ++ ++#define V4L2LOOPBACK_VERSION_CODE \ ++ KERNEL_VERSION(V4L2LOOPBACK_VERSION_MAJOR, V4L2LOOPBACK_VERSION_MINOR, \ ++ V4L2LOOPBACK_VERSION_BUGFIX) ++ ++MODULE_DESCRIPTION("V4L2 loopback video device"); ++MODULE_AUTHOR("Vasily Levin, " ++ "IOhannes m zmoelnig ," ++ "Stefan Diewald," ++ "Anton Novikov" ++ "et al."); ++#ifdef SNAPSHOT_VERSION ++MODULE_VERSION(__stringify(SNAPSHOT_VERSION)); ++#else ++MODULE_VERSION("" __stringify(V4L2LOOPBACK_VERSION_MAJOR) "." __stringify( ++ V4L2LOOPBACK_VERSION_MINOR) "." __stringify(V4L2LOOPBACK_VERSION_BUGFIX)); ++#endif ++MODULE_LICENSE("GPL"); ++ ++/* ++ * helpers ++ */ ++#define dprintk(fmt, args...) 
\ ++ do { \ ++ if (debug > 0) { \ ++ printk(KERN_INFO "v4l2-loopback[" __stringify( \ ++ __LINE__) "], pid(%d): " fmt, \ ++ task_pid_nr(current), ##args); \ ++ } \ ++ } while (0) ++ ++#define MARK() \ ++ do { \ ++ if (debug > 1) { \ ++ printk(KERN_INFO "%s:%d[%s], pid(%d)\n", __FILE__, \ ++ __LINE__, __func__, task_pid_nr(current)); \ ++ } \ ++ } while (0) ++ ++#define dprintkrw(fmt, args...) \ ++ do { \ ++ if (debug > 2) { \ ++ printk(KERN_INFO "v4l2-loopback[" __stringify( \ ++ __LINE__) "], pid(%d): " fmt, \ ++ task_pid_nr(current), ##args); \ ++ } \ ++ } while (0) ++ ++static inline void v4l2l_get_timestamp(struct v4l2_buffer *b) ++{ ++ struct timespec64 ts; ++ ktime_get_ts64(&ts); ++ ++ b->timestamp.tv_sec = ts.tv_sec; ++ b->timestamp.tv_usec = (ts.tv_nsec / NSEC_PER_USEC); ++ b->flags |= V4L2_BUF_FLAG_TIMESTAMP_MONOTONIC; ++} ++ ++#if BITS_PER_LONG == 32 ++#include /* do_div() for 64bit division */ ++static inline int v4l2l_mod64(const s64 A, const u32 B) ++{ ++ u64 a = (u64)A; ++ u32 b = B; ++ ++ if (A > 0) ++ return do_div(a, b); ++ a = -A; ++ return -do_div(a, b); ++} ++#else ++static inline int v4l2l_mod64(const s64 A, const u32 B) ++{ ++ return A % B; ++} ++#endif ++ ++#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 16, 0) ++typedef unsigned __poll_t; ++#endif ++ ++/* module constants ++ * can be overridden during he build process using something like ++ * make KCPPFLAGS="-DMAX_DEVICES=100" ++ */ ++ ++/* maximum number of v4l2loopback devices that can be created */ ++#ifndef MAX_DEVICES ++#define MAX_DEVICES 8 ++#endif ++ ++/* whether the default is to announce capabilities exclusively or not */ ++#ifndef V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS ++#define V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS 0 ++#endif ++ ++/* when a producer is considered to have gone stale */ ++#ifndef MAX_TIMEOUT ++#define MAX_TIMEOUT (100 * 1000) /* in msecs */ ++#endif ++ ++/* max buffers that can be mapped, actually they ++ * are all mapped to max_buffers buffers */ ++#ifndef MAX_BUFFERS 
++#define MAX_BUFFERS 32 ++#endif ++ ++/* module parameters */ ++static int debug = 0; ++module_param(debug, int, S_IRUGO | S_IWUSR); ++MODULE_PARM_DESC(debug, "debugging level (higher values == more verbose)"); ++ ++#define V4L2LOOPBACK_DEFAULT_MAX_BUFFERS 2 ++static int max_buffers = V4L2LOOPBACK_DEFAULT_MAX_BUFFERS; ++module_param(max_buffers, int, S_IRUGO); ++MODULE_PARM_DESC(max_buffers, ++ "how many buffers should be allocated [DEFAULT: " __stringify( ++ V4L2LOOPBACK_DEFAULT_MAX_BUFFERS) "]"); ++ ++/* how many times a device can be opened ++ * the per-module default value can be overridden on a per-device basis using ++ * the /sys/devices interface ++ * ++ * note that max_openers should be at least 2 in order to get a working system: ++ * one opener for the producer and one opener for the consumer ++ * however, we leave that to the user ++ */ ++#define V4L2LOOPBACK_DEFAULT_MAX_OPENERS 10 ++static int max_openers = V4L2LOOPBACK_DEFAULT_MAX_OPENERS; ++module_param(max_openers, int, S_IRUGO | S_IWUSR); ++MODULE_PARM_DESC( ++ max_openers, ++ "how many users can open the loopback device [DEFAULT: " __stringify( ++ V4L2LOOPBACK_DEFAULT_MAX_OPENERS) "]"); ++ ++static int devices = -1; ++module_param(devices, int, 0); ++MODULE_PARM_DESC(devices, "how many devices should be created"); ++ ++static int video_nr[MAX_DEVICES] = { [0 ...(MAX_DEVICES - 1)] = -1 }; ++module_param_array(video_nr, int, NULL, 0444); ++MODULE_PARM_DESC(video_nr, ++ "video device numbers (-1=auto, 0=/dev/video0, etc.)"); ++ ++static char *card_label[MAX_DEVICES]; ++module_param_array(card_label, charp, NULL, 0000); ++MODULE_PARM_DESC(card_label, "card labels for each device"); ++ ++static bool exclusive_caps[MAX_DEVICES] = { ++ [0 ...(MAX_DEVICES - 1)] = V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS ++}; ++module_param_array(exclusive_caps, bool, NULL, 0444); ++/* FIXXME: wording */ ++MODULE_PARM_DESC( ++ exclusive_caps, ++ "whether to announce OUTPUT/CAPTURE capabilities exclusively or not [DEFAULT: " 
__stringify( ++ V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS) "]"); ++ ++/* format specifications */ ++#define V4L2LOOPBACK_SIZE_MIN_WIDTH 2 ++#define V4L2LOOPBACK_SIZE_MIN_HEIGHT 1 ++#define V4L2LOOPBACK_SIZE_DEFAULT_MAX_WIDTH 8192 ++#define V4L2LOOPBACK_SIZE_DEFAULT_MAX_HEIGHT 8192 ++ ++#define V4L2LOOPBACK_SIZE_DEFAULT_WIDTH 640 ++#define V4L2LOOPBACK_SIZE_DEFAULT_HEIGHT 480 ++ ++static int max_width = V4L2LOOPBACK_SIZE_DEFAULT_MAX_WIDTH; ++module_param(max_width, int, S_IRUGO); ++MODULE_PARM_DESC(max_width, ++ "maximum allowed frame width [DEFAULT: " __stringify( ++ V4L2LOOPBACK_SIZE_DEFAULT_MAX_WIDTH) "]"); ++static int max_height = V4L2LOOPBACK_SIZE_DEFAULT_MAX_HEIGHT; ++module_param(max_height, int, S_IRUGO); ++MODULE_PARM_DESC(max_height, ++ "maximum allowed frame height [DEFAULT: " __stringify( ++ V4L2LOOPBACK_SIZE_DEFAULT_MAX_HEIGHT) "]"); ++ ++static DEFINE_IDR(v4l2loopback_index_idr); ++static DEFINE_MUTEX(v4l2loopback_ctl_mutex); ++ ++/* frame intervals */ ++#define V4L2LOOPBACK_FPS_MIN 0 ++#define V4L2LOOPBACK_FPS_MAX 1000 ++ ++/* control IDs */ ++#define V4L2LOOPBACK_CID_BASE (V4L2_CID_USER_BASE | 0xf000) ++#define CID_KEEP_FORMAT (V4L2LOOPBACK_CID_BASE + 0) ++#define CID_SUSTAIN_FRAMERATE (V4L2LOOPBACK_CID_BASE + 1) ++#define CID_TIMEOUT (V4L2LOOPBACK_CID_BASE + 2) ++#define CID_TIMEOUT_IMAGE_IO (V4L2LOOPBACK_CID_BASE + 3) ++ ++static int v4l2loopback_s_ctrl(struct v4l2_ctrl *ctrl); ++static const struct v4l2_ctrl_ops v4l2loopback_ctrl_ops = { ++ .s_ctrl = v4l2loopback_s_ctrl, ++}; ++static const struct v4l2_ctrl_config v4l2loopback_ctrl_keepformat = { ++ // clang-format off ++ .ops = &v4l2loopback_ctrl_ops, ++ .id = CID_KEEP_FORMAT, ++ .name = "keep_format", ++ .type = V4L2_CTRL_TYPE_BOOLEAN, ++ .min = 0, ++ .max = 1, ++ .step = 1, ++ .def = 0, ++ // clang-format on ++}; ++static const struct v4l2_ctrl_config v4l2loopback_ctrl_sustainframerate = { ++ // clang-format off ++ .ops = &v4l2loopback_ctrl_ops, ++ .id = CID_SUSTAIN_FRAMERATE, ++ .name = 
"sustain_framerate", ++ .type = V4L2_CTRL_TYPE_BOOLEAN, ++ .min = 0, ++ .max = 1, ++ .step = 1, ++ .def = 0, ++ // clang-format on ++}; ++static const struct v4l2_ctrl_config v4l2loopback_ctrl_timeout = { ++ // clang-format off ++ .ops = &v4l2loopback_ctrl_ops, ++ .id = CID_TIMEOUT, ++ .name = "timeout", ++ .type = V4L2_CTRL_TYPE_INTEGER, ++ .min = 0, ++ .max = MAX_TIMEOUT, ++ .step = 1, ++ .def = 0, ++ // clang-format on ++}; ++static const struct v4l2_ctrl_config v4l2loopback_ctrl_timeoutimageio = { ++ // clang-format off ++ .ops = &v4l2loopback_ctrl_ops, ++ .id = CID_TIMEOUT_IMAGE_IO, ++ .name = "timeout_image_io", ++ .type = V4L2_CTRL_TYPE_BUTTON, ++ .min = 0, ++ .max = 1, ++ .step = 1, ++ .def = 0, ++ // clang-format on ++}; ++ ++/* module structures */ ++struct v4l2loopback_private { ++ int device_nr; ++}; ++ ++/* TODO(vasaka) use typenames which are common to kernel, but first find out if ++ * it is needed */ ++/* struct keeping state and settings of loopback device */ ++ ++struct v4l2l_buffer { ++ struct v4l2_buffer buffer; ++ struct list_head list_head; ++ int use_count; ++}; ++ ++struct v4l2_loopback_device { ++ struct v4l2_device v4l2_dev; ++ struct v4l2_ctrl_handler ctrl_handler; ++ struct video_device *vdev; ++ /* pixel and stream format */ ++ struct v4l2_pix_format pix_format; ++ bool pix_format_has_valid_sizeimage; ++ struct v4l2_captureparm capture_param; ++ unsigned long frame_jiffies; ++ ++ /* ctrls */ ++ int keep_format; /* CID_KEEP_FORMAT; stay ready_for_capture even when all ++ openers close() the device */ ++ int sustain_framerate; /* CID_SUSTAIN_FRAMERATE; duplicate frames to maintain ++ (close to) nominal framerate */ ++ ++ /* buffers stuff */ ++ u8 *image; /* pointer to actual buffers data */ ++ unsigned long int imagesize; /* size of buffers data */ ++ int buffers_number; /* should not be big, 4 is a good choice */ ++ struct v4l2l_buffer buffers[MAX_BUFFERS]; /* inner driver buffers */ ++ int used_buffers; /* number of the actually used 
buffers */ ++ int max_openers; /* how many times can this device be opened */ ++ ++ s64 write_position; /* number of last written frame + 1 */ ++ struct list_head outbufs_list; /* buffers in output DQBUF order */ ++ int bufpos2index ++ [MAX_BUFFERS]; /* mapping of (read/write_position % used_buffers) ++ * to inner buffer index */ ++ long buffer_size; ++ ++ /* sustain_framerate stuff */ ++ struct timer_list sustain_timer; ++ unsigned int reread_count; ++ ++ /* timeout stuff */ ++ unsigned long timeout_jiffies; /* CID_TIMEOUT; 0 means disabled */ ++ int timeout_image_io; /* CID_TIMEOUT_IMAGE_IO; next opener will ++ * read/write to timeout_image */ ++ u8 *timeout_image; /* copy of it will be captured when timeout passes */ ++ struct v4l2l_buffer timeout_image_buffer; ++ struct timer_list timeout_timer; ++ int timeout_happened; ++ ++ /* sync stuff */ ++ atomic_t open_count; ++ ++ int ready_for_capture; /* set to the number of writers that opened the ++ * device and negotiated format. */ ++ int ready_for_output; /* set to true when no writer is currently attached ++ * this differs slightly from !ready_for_capture, ++ * e.g. when using fallback images */ ++ int active_readers; /* increase if any reader starts streaming */ ++ int announce_all_caps; /* set to false, if device caps (OUTPUT/CAPTURE) ++ * should only be announced if the resp. 
"ready" ++ * flag is set; default=TRUE */ ++ ++ int min_width, max_width; ++ int min_height, max_height; ++ ++ char card_label[32]; ++ ++ wait_queue_head_t read_event; ++ spinlock_t lock, list_lock; ++}; ++ ++/* types of opener shows what opener wants to do with loopback */ ++enum opener_type { ++ // clang-format off ++ UNNEGOTIATED = 0, ++ READER = 1, ++ WRITER = 2, ++ // clang-format on ++}; ++ ++/* struct keeping state and type of opener */ ++struct v4l2_loopback_opener { ++ enum opener_type type; ++ s64 read_position; /* number of last processed frame + 1 or ++ * write_position - 1 if reader went out of sync */ ++ unsigned int reread_count; ++ struct v4l2_buffer *buffers; ++ int buffers_number; /* should not be big, 4 is a good choice */ ++ int timeout_image_io; ++ ++ struct v4l2_fh fh; ++}; ++ ++#define fh_to_opener(ptr) container_of((ptr), struct v4l2_loopback_opener, fh) ++ ++/* this is heavily inspired by the bttv driver found in the linux kernel */ ++struct v4l2l_format { ++ char *name; ++ int fourcc; /* video4linux 2 */ ++ int depth; /* bit/pixel */ ++ int flags; ++}; ++/* set the v4l2l_format.flags to PLANAR for non-packed formats */ ++#define FORMAT_FLAGS_PLANAR 0x01 ++#define FORMAT_FLAGS_COMPRESSED 0x02 ++ ++#include "v4l2loopback_formats.h" ++ ++#ifndef V4L2_TYPE_IS_CAPTURE ++#define V4L2_TYPE_IS_CAPTURE(type) \ ++ ((type) == V4L2_BUF_TYPE_VIDEO_CAPTURE || \ ++ (type) == V4L2_BUF_TYPE_VIDEO_CAPTURE_MPLANE) ++#endif /* V4L2_TYPE_IS_CAPTURE */ ++#ifndef V4L2_TYPE_IS_OUTPUT ++#define V4L2_TYPE_IS_OUTPUT(type) \ ++ ((type) == V4L2_BUF_TYPE_VIDEO_OUTPUT || \ ++ (type) == V4L2_BUF_TYPE_VIDEO_OUTPUT_MPLANE) ++#endif /* V4L2_TYPE_IS_OUTPUT */ ++ ++/* whether the format can be changed */ ++/* the format is fixated if we ++ - have writers (ready_for_capture>0) ++ - and/or have readers (active_readers>0) ++*/ ++#define V4L2LOOPBACK_IS_FIXED_FMT(device) \ ++ (device->ready_for_capture > 0 || device->active_readers > 0 || \ ++ device->keep_format) ++ ++static 
const unsigned int FORMATS = ARRAY_SIZE(formats); ++ ++static char *fourcc2str(unsigned int fourcc, char buf[4]) ++{ ++ buf[0] = (fourcc >> 0) & 0xFF; ++ buf[1] = (fourcc >> 8) & 0xFF; ++ buf[2] = (fourcc >> 16) & 0xFF; ++ buf[3] = (fourcc >> 24) & 0xFF; ++ ++ return buf; ++} ++ ++static const struct v4l2l_format *format_by_fourcc(int fourcc) ++{ ++ unsigned int i; ++ ++ for (i = 0; i < FORMATS; i++) { ++ if (formats[i].fourcc == fourcc) ++ return formats + i; ++ } ++ ++ dprintk("unsupported format '%c%c%c%c'\n", (fourcc >> 0) & 0xFF, ++ (fourcc >> 8) & 0xFF, (fourcc >> 16) & 0xFF, ++ (fourcc >> 24) & 0xFF); ++ return NULL; ++} ++ ++static void pix_format_set_size(struct v4l2_pix_format *f, ++ const struct v4l2l_format *fmt, ++ unsigned int width, unsigned int height) ++{ ++ f->width = width; ++ f->height = height; ++ ++ if (fmt->flags & FORMAT_FLAGS_PLANAR) { ++ f->bytesperline = width; /* Y plane */ ++ f->sizeimage = (width * height * fmt->depth) >> 3; ++ } else if (fmt->flags & FORMAT_FLAGS_COMPRESSED) { ++ /* doesn't make sense for compressed formats */ ++ f->bytesperline = 0; ++ f->sizeimage = (width * height * fmt->depth) >> 3; ++ } else { ++ f->bytesperline = (width * fmt->depth) >> 3; ++ f->sizeimage = height * f->bytesperline; ++ } ++} ++ ++static int v4l2l_fill_format(struct v4l2_format *fmt, int capture, ++ const u32 minwidth, const u32 maxwidth, ++ const u32 minheight, const u32 maxheight) ++{ ++ u32 width = fmt->fmt.pix.width, height = fmt->fmt.pix.height; ++ u32 pixelformat = fmt->fmt.pix.pixelformat; ++ struct v4l2_format fmt0 = *fmt; ++ u32 bytesperline = 0, sizeimage = 0; ++ if (!width) ++ width = V4L2LOOPBACK_SIZE_DEFAULT_WIDTH; ++ if (!height) ++ height = V4L2LOOPBACK_SIZE_DEFAULT_HEIGHT; ++ if (width < minwidth) ++ width = minwidth; ++ if (width > maxwidth) ++ width = maxwidth; ++ if (height < minheight) ++ height = minheight; ++ if (height > maxheight) ++ height = maxheight; ++ ++ /* sets: width,height,pixelformat,bytesperline,sizeimage */ ++ 
if (!(V4L2_TYPE_IS_MULTIPLANAR(fmt0.type))) { ++ fmt0.fmt.pix.bytesperline = 0; ++ fmt0.fmt.pix.sizeimage = 0; ++ } ++ ++ if (0) { ++ ; ++#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 2, 0) ++ } else if (!v4l2_fill_pixfmt(&fmt0.fmt.pix, pixelformat, width, ++ height)) { ++ ; ++ } else if (!v4l2_fill_pixfmt_mp(&fmt0.fmt.pix_mp, pixelformat, width, ++ height)) { ++ ; ++#endif ++ } else { ++ const struct v4l2l_format *format = ++ format_by_fourcc(pixelformat); ++ if (!format) ++ return -EINVAL; ++ pix_format_set_size(&fmt0.fmt.pix, format, width, height); ++ fmt0.fmt.pix.pixelformat = format->fourcc; ++ } ++ ++ if (V4L2_TYPE_IS_MULTIPLANAR(fmt0.type)) { ++ *fmt = fmt0; ++ ++ if ((fmt->fmt.pix_mp.colorspace == V4L2_COLORSPACE_DEFAULT) || ++ (fmt->fmt.pix_mp.colorspace > V4L2_COLORSPACE_DCI_P3)) ++ fmt->fmt.pix_mp.colorspace = V4L2_COLORSPACE_SRGB; ++ if (V4L2_FIELD_ANY == fmt->fmt.pix_mp.field) ++ fmt->fmt.pix_mp.field = V4L2_FIELD_NONE; ++ if (capture) ++ fmt->type = V4L2_BUF_TYPE_VIDEO_CAPTURE_MPLANE; ++ else ++ fmt->type = V4L2_BUF_TYPE_VIDEO_OUTPUT_MPLANE; ++ } else { ++ bytesperline = fmt->fmt.pix.bytesperline; ++ sizeimage = fmt->fmt.pix.sizeimage; ++ ++ *fmt = fmt0; ++ ++ if (!fmt->fmt.pix.bytesperline) ++ fmt->fmt.pix.bytesperline = bytesperline; ++ if (!fmt->fmt.pix.sizeimage) ++ fmt->fmt.pix.sizeimage = sizeimage; ++ ++ if ((fmt->fmt.pix.colorspace == V4L2_COLORSPACE_DEFAULT) || ++ (fmt->fmt.pix.colorspace > V4L2_COLORSPACE_DCI_P3)) ++ fmt->fmt.pix.colorspace = V4L2_COLORSPACE_SRGB; ++ if (V4L2_FIELD_ANY == fmt->fmt.pix.field) ++ fmt->fmt.pix.field = V4L2_FIELD_NONE; ++ if (capture) ++ fmt->type = V4L2_BUF_TYPE_VIDEO_CAPTURE; ++ else ++ fmt->type = V4L2_BUF_TYPE_VIDEO_OUTPUT; ++ } ++ ++ return 0; ++} ++ ++/* Checks if v4l2l_fill_format() has set a valid, fixed sizeimage val. 
*/ ++static bool v4l2l_pix_format_has_valid_sizeimage(struct v4l2_format *fmt) ++{ ++#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 2, 0) ++ const struct v4l2_format_info *info; ++ ++ info = v4l2_format_info(fmt->fmt.pix.pixelformat); ++ if (info && info->mem_planes == 1) ++ return true; ++#endif ++ ++ return false; ++} ++ ++static int pix_format_eq(const struct v4l2_pix_format *ref, ++ const struct v4l2_pix_format *tgt, int strict) ++{ ++ /* check if the two formats are equivalent. ++ * ANY fields are handled gracefully ++ */ ++#define _pix_format_eq0(x) \ ++ if (ref->x != tgt->x) \ ++ result = 0 ++#define _pix_format_eq1(x, def) \ ++ do { \ ++ if ((def != tgt->x) && (ref->x != tgt->x)) { \ ++ printk(KERN_INFO #x " failed"); \ ++ result = 0; \ ++ } \ ++ } while (0) ++ int result = 1; ++ _pix_format_eq0(width); ++ _pix_format_eq0(height); ++ _pix_format_eq0(pixelformat); ++ if (!strict) ++ return result; ++ _pix_format_eq1(field, V4L2_FIELD_ANY); ++ _pix_format_eq0(bytesperline); ++ _pix_format_eq0(sizeimage); ++ _pix_format_eq1(colorspace, V4L2_COLORSPACE_DEFAULT); ++ return result; ++} ++ ++static struct v4l2_loopback_device *v4l2loopback_getdevice(struct file *f); ++static int inner_try_setfmt(struct file *file, struct v4l2_format *fmt) ++{ ++ int capture = V4L2_TYPE_IS_CAPTURE(fmt->type); ++ struct v4l2_loopback_device *dev; ++ int needschange = 0; ++ char buf[5]; ++ buf[4] = 0; ++ ++ dev = v4l2loopback_getdevice(file); ++ ++ needschange = !(pix_format_eq(&dev->pix_format, &fmt->fmt.pix, 0)); ++ if (V4L2LOOPBACK_IS_FIXED_FMT(dev)) { ++ fmt->fmt.pix = dev->pix_format; ++ if (needschange) { ++ if (dev->active_readers > 0 && capture) { ++ /* cannot call fmt_cap while there are readers */ ++ return -EBUSY; ++ } ++ if (dev->ready_for_capture > 0 && !capture) { ++ /* cannot call fmt_out while there are writers */ ++ return -EBUSY; ++ } ++ } ++ } ++ if (v4l2l_fill_format(fmt, capture, dev->min_width, dev->max_width, ++ dev->min_height, dev->max_height) != 0) { ++ return 
-EINVAL; ++ } ++ ++ if (1) { ++ char buf[5]; ++ buf[4] = 0; ++ dprintk("capFOURCC=%s\n", ++ fourcc2str(dev->pix_format.pixelformat, buf)); ++ } ++ return 0; ++} ++ ++static int set_timeperframe(struct v4l2_loopback_device *dev, ++ struct v4l2_fract *tpf) ++{ ++ if ((tpf->denominator < 1) || (tpf->numerator < 1)) { ++ return -EINVAL; ++ } ++ dev->capture_param.timeperframe = *tpf; ++ dev->frame_jiffies = max(1UL, msecs_to_jiffies(1000) * tpf->numerator / ++ tpf->denominator); ++ return 0; ++} ++ ++static struct v4l2_loopback_device *v4l2loopback_cd2dev(struct device *cd); ++ ++/* device attributes */ ++/* available via sysfs: /sys/devices/virtual/video4linux/video* */ ++ ++static ssize_t attr_show_format(struct device *cd, ++ struct device_attribute *attr, char *buf) ++{ ++ /* gets the current format as "FOURCC:WxH@f/s", e.g. "YUYV:320x240@1000/30" */ ++ struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); ++ const struct v4l2_fract *tpf; ++ char buf4cc[5], buf_fps[32]; ++ ++ if (!dev || !V4L2LOOPBACK_IS_FIXED_FMT(dev)) ++ return 0; ++ tpf = &dev->capture_param.timeperframe; ++ ++ fourcc2str(dev->pix_format.pixelformat, buf4cc); ++ buf4cc[4] = 0; ++ if (tpf->numerator == 1) ++ snprintf(buf_fps, sizeof(buf_fps), "%d", tpf->denominator); ++ else ++ snprintf(buf_fps, sizeof(buf_fps), "%d/%d", tpf->denominator, ++ tpf->numerator); ++ return sprintf(buf, "%4s:%dx%d@%s\n", buf4cc, dev->pix_format.width, ++ dev->pix_format.height, buf_fps); ++} ++ ++static ssize_t attr_store_format(struct device *cd, ++ struct device_attribute *attr, const char *buf, ++ size_t len) ++{ ++ struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); ++ int fps_num = 0, fps_den = 1; ++ ++ if (!dev) ++ return -ENODEV; ++ ++ /* only fps changing is supported */ ++ if (sscanf(buf, "@%d/%d", &fps_num, &fps_den) > 0) { ++ struct v4l2_fract f = { .numerator = fps_den, ++ .denominator = fps_num }; ++ int err = 0; ++ if ((err = set_timeperframe(dev, &f)) < 0) ++ return err; ++ return len; ++ } 
++ return -EINVAL; ++} ++ ++static DEVICE_ATTR(format, S_IRUGO | S_IWUSR, attr_show_format, ++ attr_store_format); ++ ++static ssize_t attr_show_buffers(struct device *cd, ++ struct device_attribute *attr, char *buf) ++{ ++ struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); ++ ++ if (!dev) ++ return -ENODEV; ++ ++ return sprintf(buf, "%d\n", dev->used_buffers); ++} ++ ++static DEVICE_ATTR(buffers, S_IRUGO, attr_show_buffers, NULL); ++ ++static ssize_t attr_show_maxopeners(struct device *cd, ++ struct device_attribute *attr, char *buf) ++{ ++ struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); ++ ++ if (!dev) ++ return -ENODEV; ++ ++ return sprintf(buf, "%d\n", dev->max_openers); ++} ++ ++static ssize_t attr_store_maxopeners(struct device *cd, ++ struct device_attribute *attr, ++ const char *buf, size_t len) ++{ ++ struct v4l2_loopback_device *dev = NULL; ++ unsigned long curr = 0; ++ ++ if (kstrtoul(buf, 0, &curr)) ++ return -EINVAL; ++ ++ dev = v4l2loopback_cd2dev(cd); ++ if (!dev) ++ return -ENODEV; ++ ++ if (dev->max_openers == curr) ++ return len; ++ ++ if (curr > __INT_MAX__ || dev->open_count.counter > curr) { ++ /* request to limit to less openers as are currently attached to us */ ++ return -EINVAL; ++ } ++ ++ dev->max_openers = (int)curr; ++ ++ return len; ++} ++ ++static DEVICE_ATTR(max_openers, S_IRUGO | S_IWUSR, attr_show_maxopeners, ++ attr_store_maxopeners); ++ ++static ssize_t attr_show_state(struct device *cd, struct device_attribute *attr, ++ char *buf) ++{ ++ struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); ++ ++ if (!dev) ++ return -ENODEV; ++ ++ if (dev->ready_for_capture) ++ return sprintf(buf, "capture\n"); ++ if (dev->ready_for_output) ++ return sprintf(buf, "output\n"); ++ ++ return -EAGAIN; ++} ++ ++static DEVICE_ATTR(state, S_IRUGO, attr_show_state, NULL); ++ ++static void v4l2loopback_remove_sysfs(struct video_device *vdev) ++{ ++#define V4L2_SYSFS_DESTROY(x) device_remove_file(&vdev->dev, &dev_attr_##x) ++ ++ 
if (vdev) { ++ V4L2_SYSFS_DESTROY(format); ++ V4L2_SYSFS_DESTROY(buffers); ++ V4L2_SYSFS_DESTROY(max_openers); ++ V4L2_SYSFS_DESTROY(state); ++ /* ... */ ++ } ++} ++ ++static void v4l2loopback_create_sysfs(struct video_device *vdev) ++{ ++ int res = 0; ++ ++#define V4L2_SYSFS_CREATE(x) \ ++ res = device_create_file(&vdev->dev, &dev_attr_##x); \ ++ if (res < 0) \ ++ break ++ if (!vdev) ++ return; ++ do { ++ V4L2_SYSFS_CREATE(format); ++ V4L2_SYSFS_CREATE(buffers); ++ V4L2_SYSFS_CREATE(max_openers); ++ V4L2_SYSFS_CREATE(state); ++ /* ... */ ++ } while (0); ++ ++ if (res >= 0) ++ return; ++ dev_err(&vdev->dev, "%s error: %d\n", __func__, res); ++} ++ ++/* Event APIs */ ++ ++#define V4L2LOOPBACK_EVENT_BASE (V4L2_EVENT_PRIVATE_START) ++#define V4L2LOOPBACK_EVENT_OFFSET 0x08E00000 ++#define V4L2_EVENT_PRI_CLIENT_USAGE \ ++ (V4L2LOOPBACK_EVENT_BASE + V4L2LOOPBACK_EVENT_OFFSET + 1) ++ ++struct v4l2_event_client_usage { ++ __u32 count; ++}; ++ ++/* global module data */ ++/* find a device based on it's device-number (e.g. 
'3' for /dev/video3) */ ++struct v4l2loopback_lookup_cb_data { ++ int device_nr; ++ struct v4l2_loopback_device *device; ++}; ++static int v4l2loopback_lookup_cb(int id, void *ptr, void *data) ++{ ++ struct v4l2_loopback_device *device = ptr; ++ struct v4l2loopback_lookup_cb_data *cbdata = data; ++ if (cbdata && device && device->vdev) { ++ if (device->vdev->num == cbdata->device_nr) { ++ cbdata->device = device; ++ cbdata->device_nr = id; ++ return 1; ++ } ++ } ++ return 0; ++} ++static int v4l2loopback_lookup(int device_nr, ++ struct v4l2_loopback_device **device) ++{ ++ struct v4l2loopback_lookup_cb_data data = { ++ .device_nr = device_nr, ++ .device = NULL, ++ }; ++ int err = idr_for_each(&v4l2loopback_index_idr, &v4l2loopback_lookup_cb, ++ &data); ++ if (1 == err) { ++ if (device) ++ *device = data.device; ++ return data.device_nr; ++ } ++ return -ENODEV; ++} ++static struct v4l2_loopback_device *v4l2loopback_cd2dev(struct device *cd) ++{ ++ struct video_device *loopdev = to_video_device(cd); ++ struct v4l2loopback_private *ptr = ++ (struct v4l2loopback_private *)video_get_drvdata(loopdev); ++ int nr = ptr->device_nr; ++ ++ return idr_find(&v4l2loopback_index_idr, nr); ++} ++ ++static struct v4l2_loopback_device *v4l2loopback_getdevice(struct file *f) ++{ ++ struct v4l2loopback_private *ptr = video_drvdata(f); ++ int nr = ptr->device_nr; ++ ++ return idr_find(&v4l2loopback_index_idr, nr); ++} ++ ++/* forward declarations */ ++static void client_usage_queue_event(struct video_device *vdev); ++static void init_buffers(struct v4l2_loopback_device *dev); ++static int allocate_buffers(struct v4l2_loopback_device *dev); ++static void free_buffers(struct v4l2_loopback_device *dev); ++static void try_free_buffers(struct v4l2_loopback_device *dev); ++static int allocate_timeout_image(struct v4l2_loopback_device *dev); ++static void check_timers(struct v4l2_loopback_device *dev); ++static const struct v4l2_file_operations v4l2_loopback_fops; ++static const struct 
v4l2_ioctl_ops v4l2_loopback_ioctl_ops; ++ ++/* Queue helpers */ ++/* next functions sets buffer flags and adjusts counters accordingly */ ++static inline void set_done(struct v4l2l_buffer *buffer) ++{ ++ buffer->buffer.flags &= ~V4L2_BUF_FLAG_QUEUED; ++ buffer->buffer.flags |= V4L2_BUF_FLAG_DONE; ++} ++ ++static inline void set_queued(struct v4l2l_buffer *buffer) ++{ ++ buffer->buffer.flags &= ~V4L2_BUF_FLAG_DONE; ++ buffer->buffer.flags |= V4L2_BUF_FLAG_QUEUED; ++} ++ ++static inline void unset_flags(struct v4l2l_buffer *buffer) ++{ ++ buffer->buffer.flags &= ~V4L2_BUF_FLAG_QUEUED; ++ buffer->buffer.flags &= ~V4L2_BUF_FLAG_DONE; ++} ++ ++/* V4L2 ioctl caps and params calls */ ++/* returns device capabilities ++ * called on VIDIOC_QUERYCAP ++ */ ++static int vidioc_querycap(struct file *file, void *priv, ++ struct v4l2_capability *cap) ++{ ++ struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); ++ int device_nr = ++ ((struct v4l2loopback_private *)video_get_drvdata(dev->vdev)) ++ ->device_nr; ++ __u32 capabilities = V4L2_CAP_STREAMING | V4L2_CAP_READWRITE; ++ ++ strscpy(cap->driver, "v4l2 loopback", sizeof(cap->driver)); ++ snprintf(cap->card, sizeof(cap->card), "%s", dev->card_label); ++ snprintf(cap->bus_info, sizeof(cap->bus_info), ++ "platform:v4l2loopback-%03d", device_nr); ++ ++ if (dev->announce_all_caps) { ++ capabilities |= V4L2_CAP_VIDEO_CAPTURE | V4L2_CAP_VIDEO_OUTPUT; ++ } else { ++ if (dev->ready_for_capture) { ++ capabilities |= V4L2_CAP_VIDEO_CAPTURE; ++ } ++ if (dev->ready_for_output) { ++ capabilities |= V4L2_CAP_VIDEO_OUTPUT; ++ } ++ } ++ ++#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 7, 0) ++ dev->vdev->device_caps = ++#endif /* >=linux-4.7.0 */ ++ cap->device_caps = cap->capabilities = capabilities; ++ ++ cap->capabilities |= V4L2_CAP_DEVICE_CAPS; ++ ++ memset(cap->reserved, 0, sizeof(cap->reserved)); ++ return 0; ++} ++ ++static int vidioc_enum_framesizes(struct file *file, void *fh, ++ struct v4l2_frmsizeenum *argp) ++{ ++ struct 
v4l2_loopback_device *dev; ++ ++ /* there can be only one... */ ++ if (argp->index) ++ return -EINVAL; ++ ++ dev = v4l2loopback_getdevice(file); ++ if (V4L2LOOPBACK_IS_FIXED_FMT(dev)) { ++ /* format has already been negotiated ++ * cannot change during runtime ++ */ ++ if (argp->pixel_format != dev->pix_format.pixelformat) ++ return -EINVAL; ++ ++ argp->type = V4L2_FRMSIZE_TYPE_DISCRETE; ++ ++ argp->discrete.width = dev->pix_format.width; ++ argp->discrete.height = dev->pix_format.height; ++ } else { ++ /* if the format has not been negotiated yet, we accept anything ++ */ ++ if (NULL == format_by_fourcc(argp->pixel_format)) ++ return -EINVAL; ++ ++ if (dev->min_width == dev->max_width && ++ dev->min_height == dev->max_height) { ++ argp->type = V4L2_FRMSIZE_TYPE_DISCRETE; ++ ++ argp->discrete.width = dev->min_width; ++ argp->discrete.height = dev->min_height; ++ } else { ++ argp->type = V4L2_FRMSIZE_TYPE_CONTINUOUS; ++ ++ argp->stepwise.min_width = dev->min_width; ++ argp->stepwise.min_height = dev->min_height; ++ ++ argp->stepwise.max_width = dev->max_width; ++ argp->stepwise.max_height = dev->max_height; ++ ++ argp->stepwise.step_width = 1; ++ argp->stepwise.step_height = 1; ++ } ++ } ++ return 0; ++} ++ ++/* returns frameinterval (fps) for the set resolution ++ * called on VIDIOC_ENUM_FRAMEINTERVALS ++ */ ++static int vidioc_enum_frameintervals(struct file *file, void *fh, ++ struct v4l2_frmivalenum *argp) ++{ ++ struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); ++ ++ /* there can be only one... 
*/ ++ if (argp->index) ++ return -EINVAL; ++ ++ if (V4L2LOOPBACK_IS_FIXED_FMT(dev)) { ++ if (argp->width != dev->pix_format.width || ++ argp->height != dev->pix_format.height || ++ argp->pixel_format != dev->pix_format.pixelformat) ++ return -EINVAL; ++ ++ argp->type = V4L2_FRMIVAL_TYPE_DISCRETE; ++ argp->discrete = dev->capture_param.timeperframe; ++ } else { ++ if (argp->width < dev->min_width || ++ argp->width > dev->max_width || ++ argp->height < dev->min_height || ++ argp->height > dev->max_height || ++ NULL == format_by_fourcc(argp->pixel_format)) ++ return -EINVAL; ++ ++ argp->type = V4L2_FRMIVAL_TYPE_CONTINUOUS; ++ argp->stepwise.min.numerator = 1; ++ argp->stepwise.min.denominator = V4L2LOOPBACK_FPS_MAX; ++ argp->stepwise.max.numerator = 1; ++ argp->stepwise.max.denominator = V4L2LOOPBACK_FPS_MIN; ++ argp->stepwise.step.numerator = 1; ++ argp->stepwise.step.denominator = 1; ++ } ++ ++ return 0; ++} ++ ++/* ------------------ CAPTURE ----------------------- */ ++ ++/* returns device formats ++ * called on VIDIOC_ENUM_FMT, with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_CAPTURE ++ */ ++static int vidioc_enum_fmt_cap(struct file *file, void *fh, ++ struct v4l2_fmtdesc *f) ++{ ++ struct v4l2_loopback_device *dev; ++ const struct v4l2l_format *fmt; ++ MARK(); ++ ++ dev = v4l2loopback_getdevice(file); ++ ++ if (f->index) ++ return -EINVAL; ++ ++ if (V4L2LOOPBACK_IS_FIXED_FMT(dev)) { ++ /* format has been fixed, so only one single format is supported */ ++ const __u32 format = dev->pix_format.pixelformat; ++ ++ if ((fmt = format_by_fourcc(format))) { ++ snprintf(f->description, sizeof(f->description), "%s", ++ fmt->name); ++ } else { ++ snprintf(f->description, sizeof(f->description), ++ "[%c%c%c%c]", (format >> 0) & 0xFF, ++ (format >> 8) & 0xFF, (format >> 16) & 0xFF, ++ (format >> 24) & 0xFF); ++ } ++ ++ f->pixelformat = dev->pix_format.pixelformat; ++ } else { ++ return -EINVAL; ++ } ++ f->flags = 0; ++ MARK(); ++ return 0; ++} ++ ++/* returns current video 
format ++ * called on VIDIOC_G_FMT, with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_CAPTURE ++ */ ++static int vidioc_g_fmt_cap(struct file *file, void *priv, ++ struct v4l2_format *fmt) ++{ ++ struct v4l2_loopback_device *dev; ++ MARK(); ++ ++ dev = v4l2loopback_getdevice(file); ++ if (!dev->ready_for_capture && !dev->ready_for_output) ++ return -EINVAL; ++ ++ fmt->fmt.pix = dev->pix_format; ++ MARK(); ++ return 0; ++} ++ ++/* checks if it is OK to change to format fmt; ++ * actual check is done by inner_try_setfmt ++ * just checking that pixelformat is OK and set other parameters, app should ++ * obey this decision ++ * called on VIDIOC_TRY_FMT, with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_CAPTURE ++ */ ++static int vidioc_try_fmt_cap(struct file *file, void *priv, ++ struct v4l2_format *fmt) ++{ ++ int ret = 0; ++ if (!V4L2_TYPE_IS_CAPTURE(fmt->type)) ++ return -EINVAL; ++ ret = inner_try_setfmt(file, fmt); ++ if (-EBUSY == ret) ++ return 0; ++ return ret; ++} ++ ++/* sets new output format, if possible ++ * actually format is set by input and we even do not check it, just return ++ * current one, but it is possible to set subregions of input TODO(vasaka) ++ * called on VIDIOC_S_FMT, with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_CAPTURE ++ */ ++static int vidioc_s_fmt_cap(struct file *file, void *priv, ++ struct v4l2_format *fmt) ++{ ++ int ret; ++ struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); ++ if (!V4L2_TYPE_IS_CAPTURE(fmt->type)) ++ return -EINVAL; ++ ret = inner_try_setfmt(file, fmt); ++ if (!ret) { ++ dev->pix_format = fmt->fmt.pix; ++ } ++ return ret; ++} ++ ++/* ------------------ OUTPUT ----------------------- */ ++ ++/* returns device formats; ++ * LATER: allow all formats ++ * called on VIDIOC_ENUM_FMT, with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_OUTPUT ++ */ ++static int vidioc_enum_fmt_out(struct file *file, void *fh, ++ struct v4l2_fmtdesc *f) ++{ ++ struct v4l2_loopback_device *dev; ++ const struct v4l2l_format *fmt; ++ ++ dev 
= v4l2loopback_getdevice(file); ++ ++ if (V4L2LOOPBACK_IS_FIXED_FMT(dev)) { ++ /* format has been fixed, so only one single format is supported */ ++ const __u32 format = dev->pix_format.pixelformat; ++ ++ if (f->index) ++ return -EINVAL; ++ ++ if ((fmt = format_by_fourcc(format))) { ++ snprintf(f->description, sizeof(f->description), "%s", ++ fmt->name); ++ } else { ++ snprintf(f->description, sizeof(f->description), ++ "[%c%c%c%c]", (format >> 0) & 0xFF, ++ (format >> 8) & 0xFF, (format >> 16) & 0xFF, ++ (format >> 24) & 0xFF); ++ } ++ ++ f->pixelformat = dev->pix_format.pixelformat; ++ } else { ++ /* fill in a dummy format */ ++ /* coverity[unsigned_compare] */ ++ if (f->index < 0 || f->index >= FORMATS) ++ return -EINVAL; ++ ++ fmt = &formats[f->index]; ++ ++ f->pixelformat = fmt->fourcc; ++ snprintf(f->description, sizeof(f->description), "%s", ++ fmt->name); ++ } ++ f->flags = 0; ++ ++ return 0; ++} ++ ++/* returns current video format format fmt */ ++/* NOTE: this is called from the producer ++ * so if format has not been negotiated yet, ++ * it should return ALL of available formats, ++ * called on VIDIOC_G_FMT, with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_OUTPUT ++ */ ++static int vidioc_g_fmt_out(struct file *file, void *priv, ++ struct v4l2_format *fmt) ++{ ++ struct v4l2_loopback_device *dev; ++ MARK(); ++ ++ dev = v4l2loopback_getdevice(file); ++ ++ /* ++ * LATER: this should return the currently valid format ++ * gstreamer doesn't like it, if this returns -EINVAL, as it ++ * then concludes that there is _no_ valid format ++ * CHECK whether this assumption is wrong, ++ * or whether we have to always provide a valid format ++ */ ++ ++ fmt->fmt.pix = dev->pix_format; ++ return 0; ++} ++ ++/* checks if it is OK to change to format fmt; ++ * if format is negotiated do not change it ++ * called on VIDIOC_TRY_FMT with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_OUTPUT ++ */ ++static int vidioc_try_fmt_out(struct file *file, void *priv, ++ struct v4l2_format 
*fmt) ++{ ++ int ret = 0; ++ if (!V4L2_TYPE_IS_OUTPUT(fmt->type)) ++ return -EINVAL; ++ ret = inner_try_setfmt(file, fmt); ++ if (-EBUSY == ret) ++ return 0; ++ return ret; ++} ++ ++/* sets new output format, if possible; ++ * allocate data here because we do not know if it will be streaming or ++ * read/write IO ++ * called on VIDIOC_S_FMT with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_OUTPUT ++ */ ++static int vidioc_s_fmt_out(struct file *file, void *priv, ++ struct v4l2_format *fmt) ++{ ++ struct v4l2_loopback_device *dev; ++ int ret; ++ char buf[5]; ++ buf[4] = 0; ++ if (!V4L2_TYPE_IS_OUTPUT(fmt->type)) ++ return -EINVAL; ++ dev = v4l2loopback_getdevice(file); ++ ++ ret = inner_try_setfmt(file, fmt); ++ if (!ret) { ++ dev->pix_format = fmt->fmt.pix; ++ dev->pix_format_has_valid_sizeimage = ++ v4l2l_pix_format_has_valid_sizeimage(fmt); ++ dprintk("s_fmt_out(%d) %d...%d\n", ret, dev->ready_for_capture, ++ dev->pix_format.sizeimage); ++ dprintk("outFOURCC=%s\n", ++ fourcc2str(dev->pix_format.pixelformat, buf)); ++ ++ if (!dev->ready_for_capture) { ++ dev->buffer_size = ++ PAGE_ALIGN(dev->pix_format.sizeimage); ++ // JMZ: TODO get rid of the next line ++ fmt->fmt.pix.sizeimage = dev->buffer_size; ++ ret = allocate_buffers(dev); ++ } ++ } ++ return ret; ++} ++ ++// #define V4L2L_OVERLAY ++#ifdef V4L2L_OVERLAY ++/* ------------------ OVERLAY ----------------------- */ ++/* currently unsupported */ ++/* GSTreamer's v4l2sink is buggy, as it requires the overlay to work ++ * while it should only require it, if overlay is requested ++ * once the gstreamer element is fixed, remove the overlay dummies ++ */ ++#warning OVERLAY dummies ++static int vidioc_g_fmt_overlay(struct file *file, void *priv, ++ struct v4l2_format *fmt) ++{ ++ return 0; ++} ++ ++static int vidioc_s_fmt_overlay(struct file *file, void *priv, ++ struct v4l2_format *fmt) ++{ ++ return 0; ++} ++#endif /* V4L2L_OVERLAY */ ++ ++/* ------------------ PARAMs ----------------------- */ ++ ++/* get some data 
flow parameters, only capability, fps and readbuffers has ++ * effect on this driver ++ * called on VIDIOC_G_PARM ++ */ ++static int vidioc_g_parm(struct file *file, void *priv, ++ struct v4l2_streamparm *parm) ++{ ++ /* do not care about type of opener, hope these enums would always be ++ * compatible */ ++ struct v4l2_loopback_device *dev; ++ MARK(); ++ ++ dev = v4l2loopback_getdevice(file); ++ parm->parm.capture = dev->capture_param; ++ return 0; ++} ++ ++/* set data flow parameters: the requested timeperframe is applied the same way ++ * for capture and output types, any other buffer type is rejected with -EINVAL ++ * called on VIDIOC_S_PARM; on success parm is refreshed from dev->capture_param ++ */ ++static int vidioc_s_parm(struct file *file, void *priv, ++ struct v4l2_streamparm *parm) ++{ ++ struct v4l2_loopback_device *dev; ++ int err = 0; ++ MARK(); ++ ++ dev = v4l2loopback_getdevice(file); ++ dprintk("vidioc_s_parm called frate=%d/%d\n", ++ parm->parm.capture.timeperframe.numerator, ++ parm->parm.capture.timeperframe.denominator); ++ ++ switch (parm->type) { ++ case V4L2_BUF_TYPE_VIDEO_CAPTURE: ++ if ((err = set_timeperframe( ++ dev, &parm->parm.capture.timeperframe)) < 0) ++ return err; ++ break; ++ case V4L2_BUF_TYPE_VIDEO_OUTPUT: ++ if ((err = set_timeperframe( ++ dev, &parm->parm.capture.timeperframe)) < 0) ++ return err; ++ break; ++ default: ++ return -EINVAL; ++ } ++ ++ parm->parm.capture = dev->capture_param; ++ return 0; ++} ++ ++#ifdef V4L2LOOPBACK_WITH_STD ++/* sets a tv standard, actually we do not need to handle this any special way ++ * added to support effecttv ++ * called on VIDIOC_S_STD ++ */ ++static int vidioc_s_std(struct file *file, void *fh, v4l2_std_id *_std) ++{ ++ v4l2_std_id req_std = 0, supported_std = 0; ++ const v4l2_std_id all_std = V4L2_STD_ALL, no_std = 0; ++ ++ if (_std) { ++ req_std = *_std; ++ *_std = all_std; ++ } ++ ++ /* we support everything in V4L2_STD_ALL, but not more...
*/ ++ supported_std = (all_std & req_std); ++ if (no_std == supported_std) ++ return -EINVAL; ++ ++ return 0; ++} ++ ++/* gets a fake video standard ++ * called on VIDIOC_G_STD ++ */ ++static int vidioc_g_std(struct file *file, void *fh, v4l2_std_id *norm) ++{ ++ if (norm) ++ *norm = V4L2_STD_ALL; ++ return 0; ++} ++/* gets a fake video standard ++ * called on VIDIOC_QUERYSTD ++ */ ++static int vidioc_querystd(struct file *file, void *fh, v4l2_std_id *norm) ++{ ++ if (norm) ++ *norm = V4L2_STD_ALL; ++ return 0; ++} ++#endif /* V4L2LOOPBACK_WITH_STD */ ++ ++static int v4l2loopback_set_ctrl(struct v4l2_loopback_device *dev, u32 id, ++ s64 val) ++{ ++ switch (id) { ++ case CID_KEEP_FORMAT: ++ if (val < 0 || val > 1) ++ return -EINVAL; ++ dev->keep_format = val; ++ try_free_buffers( ++ dev); /* will only free buffers if !keep_format */ ++ break; ++ case CID_SUSTAIN_FRAMERATE: ++ if (val < 0 || val > 1) ++ return -EINVAL; ++ spin_lock_bh(&dev->lock); ++ dev->sustain_framerate = val; ++ check_timers(dev); ++ spin_unlock_bh(&dev->lock); ++ break; ++ case CID_TIMEOUT: ++ if (val < 0 || val > MAX_TIMEOUT) ++ return -EINVAL; ++ spin_lock_bh(&dev->lock); ++ dev->timeout_jiffies = msecs_to_jiffies(val); ++ check_timers(dev); ++ spin_unlock_bh(&dev->lock); ++ allocate_timeout_image(dev); ++ break; ++ case CID_TIMEOUT_IMAGE_IO: ++ dev->timeout_image_io = 1; ++ break; ++ default: ++ return -EINVAL; ++ } ++ return 0; ++} ++ ++static int v4l2loopback_s_ctrl(struct v4l2_ctrl *ctrl) ++{ ++ struct v4l2_loopback_device *dev = container_of( ++ ctrl->handler, struct v4l2_loopback_device, ctrl_handler); ++ return v4l2loopback_set_ctrl(dev, ctrl->id, ctrl->val); ++} ++ ++/* returns set of device outputs, in our case there is only one ++ * called on VIDIOC_ENUMOUTPUT ++ */ ++static int vidioc_enum_output(struct file *file, void *fh, ++ struct v4l2_output *outp) ++{ ++ __u32 index = outp->index; ++ struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); ++ MARK(); ++ ++ if 
(!dev->announce_all_caps && !dev->ready_for_output) ++ return -ENOTTY; ++ ++ if (0 != index) ++ return -EINVAL; ++ ++ /* clear all data (including the reserved fields) */ ++ memset(outp, 0, sizeof(*outp)); ++ ++ outp->index = index; ++ strscpy(outp->name, "loopback in", sizeof(outp->name)); ++ outp->type = V4L2_OUTPUT_TYPE_ANALOG; ++ outp->audioset = 0; ++ outp->modulator = 0; ++#ifdef V4L2LOOPBACK_WITH_STD ++ outp->std = V4L2_STD_ALL; ++#ifdef V4L2_OUT_CAP_STD ++ outp->capabilities |= V4L2_OUT_CAP_STD; ++#endif /* V4L2_OUT_CAP_STD */ ++#endif /* V4L2LOOPBACK_WITH_STD */ ++ ++ return 0; ++} ++ ++/* which output is currently active, ++ * called on VIDIOC_G_OUTPUT ++ */ ++static int vidioc_g_output(struct file *file, void *fh, unsigned int *i) ++{ ++ struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); ++ if (!dev->announce_all_caps && !dev->ready_for_output) ++ return -ENOTTY; ++ if (i) ++ *i = 0; ++ return 0; ++} ++ ++/* set output, can make sense if we have more than one video src, ++ * called on VIDIOC_S_OUTPUT ++ */ ++static int vidioc_s_output(struct file *file, void *fh, unsigned int i) ++{ ++ struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); ++ if (!dev->announce_all_caps && !dev->ready_for_output) ++ return -ENOTTY; ++ ++ if (i) ++ return -EINVAL; ++ ++ return 0; ++} ++ ++/* returns set of device inputs, in our case there is only one, ++ * but later I may add more ++ * called on VIDIOC_ENUMINPUT ++ */ ++static int vidioc_enum_input(struct file *file, void *fh, ++ struct v4l2_input *inp) ++{ ++ struct v4l2_loopback_device *dev; ++ __u32 index = inp->index; ++ MARK(); ++ ++ if (0 != index) ++ return -EINVAL; ++ ++ /* clear all data (including the reserved fields) */ ++ memset(inp, 0, sizeof(*inp)); ++ ++ inp->index = index; ++ strscpy(inp->name, "loopback", sizeof(inp->name)); ++ inp->type = V4L2_INPUT_TYPE_CAMERA; ++ inp->audioset = 0; ++ inp->tuner = 0; ++ inp->status = 0; ++ ++#ifdef V4L2LOOPBACK_WITH_STD ++ inp->std = 
V4L2_STD_ALL; ++#ifdef V4L2_IN_CAP_STD ++ inp->capabilities |= V4L2_IN_CAP_STD; ++#endif ++#endif /* V4L2LOOPBACK_WITH_STD */ ++ ++ dev = v4l2loopback_getdevice(file); ++ if (!dev->ready_for_capture) { ++ inp->status |= V4L2_IN_ST_NO_SIGNAL; ++ } ++ ++ return 0; ++} ++ ++/* which input is currently active, ++ * called on VIDIOC_G_INPUT ++ */ ++static int vidioc_g_input(struct file *file, void *fh, unsigned int *i) ++{ ++ struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); ++ if (!dev->announce_all_caps && !dev->ready_for_capture) ++ return -ENOTTY; ++ if (i) ++ *i = 0; ++ return 0; ++} ++ ++/* set input, can make sense if we have more than one video src, ++ * called on VIDIOC_S_INPUT ++ */ ++static int vidioc_s_input(struct file *file, void *fh, unsigned int i) ++{ ++ struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); ++ if (!dev->announce_all_caps && !dev->ready_for_capture) ++ return -ENOTTY; ++ if (i == 0) ++ return 0; ++ return -EINVAL; ++} ++ ++/* --------------- V4L2 ioctl buffer related calls ----------------- */ ++ ++/* negotiate buffer type ++ * only mmap streaming supported ++ * called on VIDIOC_REQBUFS ++ */ ++static int vidioc_reqbufs(struct file *file, void *fh, ++ struct v4l2_requestbuffers *b) ++{ ++ struct v4l2_loopback_device *dev; ++ struct v4l2_loopback_opener *opener; ++ int i; ++ MARK(); ++ ++ dev = v4l2loopback_getdevice(file); ++ opener = fh_to_opener(fh); ++ ++ dprintk("reqbufs: %d\t%d=%d\n", b->memory, b->count, ++ dev->buffers_number); ++ ++ if (opener->timeout_image_io) { ++ dev->timeout_image_io = 0; ++ if (b->memory != V4L2_MEMORY_MMAP) ++ return -EINVAL; ++ b->count = 2; ++ return 0; ++ } ++ ++ if (V4L2_TYPE_IS_OUTPUT(b->type) && (!dev->ready_for_output)) { ++ return -EBUSY; ++ } ++ ++ init_buffers(dev); ++ switch (b->memory) { ++ case V4L2_MEMORY_MMAP: ++ /* do nothing here, buffers are always allocated */ ++ if (b->count < 1 || dev->buffers_number < 1) ++ return 0; ++ ++ if (b->count > dev->buffers_number) 
++ b->count = dev->buffers_number; ++ ++ /* make sure that outbufs_list contains buffers from 0 to used_buffers-1 ++ * actually, it will have been already populated via v4l2_loopback_init() ++ * at this point */ ++ if (list_empty(&dev->outbufs_list)) { ++ for (i = 0; i < dev->used_buffers; ++i) ++ list_add_tail(&dev->buffers[i].list_head, ++ &dev->outbufs_list); ++ } ++ ++ /* also, if dev->used_buffers is going to be decreased, we should remove ++ * out-of-range buffers from outbufs_list, and fix bufpos2index mapping */ ++ if (b->count < dev->used_buffers) { ++ struct v4l2l_buffer *pos, *n; ++ ++ list_for_each_entry_safe(pos, n, &dev->outbufs_list, ++ list_head) { ++ if (pos->buffer.index >= b->count) ++ list_del(&pos->list_head); ++ } ++ ++ /* after we update dev->used_buffers, buffers in outbufs_list will ++ * correspond to dev->write_position + [0;b->count-1] range */ ++ i = v4l2l_mod64(dev->write_position, b->count); ++ list_for_each_entry(pos, &dev->outbufs_list, ++ list_head) { ++ dev->bufpos2index[i % b->count] = ++ pos->buffer.index; ++ ++i; ++ } ++ } ++ ++ opener->buffers_number = b->count; ++ if (opener->buffers_number < dev->used_buffers) ++ dev->used_buffers = opener->buffers_number; ++ return 0; ++ default: ++ return -EINVAL; ++ } ++} ++ ++/* returns buffer asked for; ++ * give app as many buffers as it wants, if it less than MAX, ++ * but map them in our inner buffers ++ * called on VIDIOC_QUERYBUF ++ */ ++static int vidioc_querybuf(struct file *file, void *fh, struct v4l2_buffer *b) ++{ ++ enum v4l2_buf_type type; ++ int index; ++ struct v4l2_loopback_device *dev; ++ struct v4l2_loopback_opener *opener; ++ ++ MARK(); ++ ++ type = b->type; ++ index = b->index; ++ dev = v4l2loopback_getdevice(file); ++ opener = fh_to_opener(fh); ++ ++ if ((b->type != V4L2_BUF_TYPE_VIDEO_CAPTURE) && ++ (b->type != V4L2_BUF_TYPE_VIDEO_OUTPUT)) { ++ return -EINVAL; ++ } ++ if (b->index > max_buffers) ++ return -EINVAL; ++ ++ if (opener->timeout_image_io) ++ *b = 
dev->timeout_image_buffer.buffer; ++ else ++ *b = dev->buffers[b->index % dev->used_buffers].buffer; ++ ++ b->type = type; ++ b->index = index; ++ dprintkrw("buffer type: %d (of %d with size=%ld)\n", b->memory, ++ dev->buffers_number, dev->buffer_size); ++ ++ /* Hopefully fix 'DQBUF return bad index if queue bigger then 2 for capture' ++ https://github.com/umlaeute/v4l2loopback/issues/60 */ ++ b->flags &= ~V4L2_BUF_FLAG_DONE; ++ b->flags |= V4L2_BUF_FLAG_QUEUED; ++ ++ return 0; ++} ++ ++static void buffer_written(struct v4l2_loopback_device *dev, ++ struct v4l2l_buffer *buf) ++{ ++ del_timer_sync(&dev->sustain_timer); ++ del_timer_sync(&dev->timeout_timer); ++ ++ spin_lock_bh(&dev->list_lock); ++ list_move_tail(&buf->list_head, &dev->outbufs_list); ++ spin_unlock_bh(&dev->list_lock); ++ ++ spin_lock_bh(&dev->lock); ++ dev->bufpos2index[v4l2l_mod64(dev->write_position, dev->used_buffers)] = ++ buf->buffer.index; ++ ++dev->write_position; ++ dev->reread_count = 0; ++ ++ check_timers(dev); ++ spin_unlock_bh(&dev->lock); ++} ++ ++/* put buffer to queue ++ * called on VIDIOC_QBUF ++ */ ++static int vidioc_qbuf(struct file *file, void *fh, struct v4l2_buffer *buf) ++{ ++ struct v4l2_loopback_device *dev; ++ struct v4l2_loopback_opener *opener; ++ struct v4l2l_buffer *b; ++ int index; ++ ++ dev = v4l2loopback_getdevice(file); ++ opener = fh_to_opener(fh); ++ ++ if (buf->index > max_buffers) ++ return -EINVAL; ++ if (opener->timeout_image_io) ++ return 0; ++ ++ index = buf->index % dev->used_buffers; ++ b = &dev->buffers[index]; ++ ++ switch (buf->type) { ++ case V4L2_BUF_TYPE_VIDEO_CAPTURE: ++ dprintkrw( ++ "qbuf(CAPTURE)#%d: buffer#%d @ %p type=%d bytesused=%d length=%d flags=%x field=%d timestamp=%lld.%06ld sequence=%d\n", ++ index, buf->index, buf, buf->type, buf->bytesused, ++ buf->length, buf->flags, buf->field, ++ (long long)buf->timestamp.tv_sec, ++ (long int)buf->timestamp.tv_usec, buf->sequence); ++ set_queued(b); ++ return 0; ++ case V4L2_BUF_TYPE_VIDEO_OUTPUT: 
++ dprintkrw( ++ "qbuf(OUTPUT)#%d: buffer#%d @ %p type=%d bytesused=%d length=%d flags=%x field=%d timestamp=%lld.%06ld sequence=%d\n", ++ index, buf->index, buf, buf->type, buf->bytesused, ++ buf->length, buf->flags, buf->field, ++ (long long)buf->timestamp.tv_sec, ++ (long int)buf->timestamp.tv_usec, buf->sequence); ++ if ((!(b->buffer.flags & V4L2_BUF_FLAG_TIMESTAMP_COPY)) && ++ (buf->timestamp.tv_sec == 0 && buf->timestamp.tv_usec == 0)) ++ v4l2l_get_timestamp(&b->buffer); ++ else { ++ b->buffer.timestamp = buf->timestamp; ++ b->buffer.flags |= V4L2_BUF_FLAG_TIMESTAMP_COPY; ++ } ++ if (dev->pix_format_has_valid_sizeimage) { ++ if (buf->bytesused >= dev->pix_format.sizeimage) { ++ b->buffer.bytesused = dev->pix_format.sizeimage; ++ } else { ++#if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 5, 0) ++ dev_warn_ratelimited( ++ &dev->vdev->dev, ++#else ++ dprintkrw( ++#endif ++ "warning queued output buffer bytesused too small %d < %d\n", ++ buf->bytesused, ++ dev->pix_format.sizeimage); ++ b->buffer.bytesused = buf->bytesused; ++ } ++ } else { ++ b->buffer.bytesused = buf->bytesused; ++ } ++ ++ set_done(b); ++ buffer_written(dev, b); ++ ++ /* Hopefully fix 'DQBUF return bad index if queue bigger then 2 for capture' ++ https://github.com/umlaeute/v4l2loopback/issues/60 */ ++ buf->flags &= ~V4L2_BUF_FLAG_DONE; ++ buf->flags |= V4L2_BUF_FLAG_QUEUED; ++ ++ wake_up_all(&dev->read_event); ++ return 0; ++ default: ++ return -EINVAL; ++ } ++} ++ ++static int can_read(struct v4l2_loopback_device *dev, ++ struct v4l2_loopback_opener *opener) ++{ ++ int ret; ++ ++ spin_lock_bh(&dev->lock); ++ check_timers(dev); ++ ret = dev->write_position > opener->read_position || ++ dev->reread_count > opener->reread_count || dev->timeout_happened; ++ spin_unlock_bh(&dev->lock); ++ return ret; ++} ++ ++static int get_capture_buffer(struct file *file) ++{ ++ struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); ++ struct v4l2_loopback_opener *opener = fh_to_opener(file->private_data); 
++ int pos, ret; ++ int timeout_happened; ++ ++ if ((file->f_flags & O_NONBLOCK) && ++ (dev->write_position <= opener->read_position && ++ dev->reread_count <= opener->reread_count && ++ !dev->timeout_happened)) ++ return -EAGAIN; ++ wait_event_interruptible(dev->read_event, can_read(dev, opener)); ++ ++ spin_lock_bh(&dev->lock); ++ if (dev->write_position == opener->read_position) { ++ if (dev->reread_count > opener->reread_count + 2) ++ opener->reread_count = dev->reread_count - 1; ++ ++opener->reread_count; ++ pos = v4l2l_mod64(opener->read_position + dev->used_buffers - 1, ++ dev->used_buffers); ++ } else { ++ opener->reread_count = 0; ++ if (dev->write_position > ++ opener->read_position + dev->used_buffers) ++ opener->read_position = dev->write_position - 1; ++ pos = v4l2l_mod64(opener->read_position, dev->used_buffers); ++ ++opener->read_position; ++ } ++ timeout_happened = dev->timeout_happened; ++ dev->timeout_happened = 0; ++ spin_unlock_bh(&dev->lock); ++ ++ ret = dev->bufpos2index[pos]; ++ if (timeout_happened) { ++ if (ret < 0) { ++ dprintk("trying to return not mapped buf[%d]\n", ret); ++ return -EFAULT; ++ } ++ /* although allocated on-demand, timeout_image is freed only ++ * in free_buffers(), so we don't need to worry about it being ++ * deallocated suddenly */ ++ memcpy(dev->image + dev->buffers[ret].buffer.m.offset, ++ dev->timeout_image, dev->buffer_size); ++ } ++ return ret; ++} ++ ++/* dequeue a buffer: capture types get the index chosen by get_capture_buffer(), ++ * output types get the entry at the tail of outbufs_list; called on VIDIOC_DQBUF ++ */ ++static int vidioc_dqbuf(struct file *file, void *fh, struct v4l2_buffer *buf) ++{ ++ struct v4l2_loopback_device *dev; ++ struct v4l2_loopback_opener *opener; ++ int index; ++ struct v4l2l_buffer *b; ++ ++ dev = v4l2loopback_getdevice(file); ++ opener = fh_to_opener(fh); ++ if (opener->timeout_image_io) { ++ *buf = dev->timeout_image_buffer.buffer; ++ return 0; ++ } ++ ++ switch (buf->type) { ++ case V4L2_BUF_TYPE_VIDEO_CAPTURE: ++ index = get_capture_buffer(file); ++ if (index < 0) ++ return index; ++
dprintkrw("capture DQBUF pos: %lld index: %d\n", ++ (long long)(opener->read_position - 1), index); ++ if (!(dev->buffers[index].buffer.flags & ++ V4L2_BUF_FLAG_MAPPED)) { ++ dprintk("trying to return not mapped buf[%d]\n", index); ++ return -EINVAL; ++ } ++ unset_flags(&dev->buffers[index]); ++ *buf = dev->buffers[index].buffer; ++ dprintkrw( ++ "dqbuf(CAPTURE)#%d: buffer#%d @ %p type=%d bytesused=%d length=%d flags=%x field=%d timestamp=%lld.%06ld sequence=%d\n", ++ index, buf->index, buf, buf->type, buf->bytesused, ++ buf->length, buf->flags, buf->field, ++ (long long)buf->timestamp.tv_sec, ++ (long int)buf->timestamp.tv_usec, buf->sequence); ++ return 0; ++ case V4L2_BUF_TYPE_VIDEO_OUTPUT: ++ spin_lock_bh(&dev->list_lock); ++ ++ b = list_entry(dev->outbufs_list.prev, struct v4l2l_buffer, ++ list_head); ++ list_move_tail(&b->list_head, &dev->outbufs_list); ++ ++ spin_unlock_bh(&dev->list_lock); ++ dprintkrw("output DQBUF index: %d\n", b->buffer.index); ++ unset_flags(b); ++ *buf = b->buffer; ++ buf->type = V4L2_BUF_TYPE_VIDEO_OUTPUT; ++ dprintkrw( ++ "dqbuf(OUTPUT)#%d: buffer#%d @ %p type=%d bytesused=%d length=%d flags=%x field=%d timestamp=%lld.%06ld sequence=%d\n", ++ b->buffer.index, buf->index, buf, buf->type, buf->bytesused, ++ buf->length, buf->flags, buf->field, ++ (long long)buf->timestamp.tv_sec, ++ (long int)buf->timestamp.tv_usec, buf->sequence); ++ return 0; ++ default: ++ return -EINVAL; ++ } ++} ++ ++/* ------------- STREAMING ------------------- */ ++ ++/* start streaming ++ * called on VIDIOC_STREAMON ++ */ ++static int vidioc_streamon(struct file *file, void *fh, enum v4l2_buf_type type) ++{ ++ struct v4l2_loopback_device *dev; ++ struct v4l2_loopback_opener *opener; ++ MARK(); ++ ++ dev = v4l2loopback_getdevice(file); ++ opener = fh_to_opener(fh); ++ ++ switch (type) { ++ case V4L2_BUF_TYPE_VIDEO_OUTPUT: ++ if (!dev->ready_for_capture) { ++ int ret = allocate_buffers(dev); ++ if (ret < 0) ++ return ret; ++ } ++ opener->type = WRITER; ++
dev->ready_for_output = 0; ++ dev->ready_for_capture++; ++ return 0; ++ case V4L2_BUF_TYPE_VIDEO_CAPTURE: ++ if (!dev->ready_for_capture) ++ return -EIO; ++ if (dev->active_readers > 0) ++ return -EBUSY; ++ opener->type = READER; ++ dev->active_readers++; ++ client_usage_queue_event(dev->vdev); ++ return 0; ++ default: ++ return -EINVAL; ++ } ++ return -EINVAL; ++} ++ ++/* stop streaming ++ * called on VIDIOC_STREAMOFF ++ */ ++static int vidioc_streamoff(struct file *file, void *fh, ++ enum v4l2_buf_type type) ++{ ++ struct v4l2_loopback_device *dev; ++ struct v4l2_loopback_opener *opener; ++ ++ MARK(); ++ dprintk("%d\n", type); ++ ++ dev = v4l2loopback_getdevice(file); ++ opener = fh_to_opener(fh); ++ switch (type) { ++ case V4L2_BUF_TYPE_VIDEO_OUTPUT: ++ if (dev->ready_for_capture > 0) ++ dev->ready_for_capture--; ++ return 0; ++ case V4L2_BUF_TYPE_VIDEO_CAPTURE: ++ if (opener->type == READER) { ++ opener->type = 0; ++ dev->active_readers--; ++ client_usage_queue_event(dev->vdev); ++ } ++ return 0; ++ default: ++ return -EINVAL; ++ } ++ return -EINVAL; ++} ++ ++#ifdef CONFIG_VIDEO_V4L1_COMPAT ++static int vidiocgmbuf(struct file *file, void *fh, struct video_mbuf *p) ++{ ++ struct v4l2_loopback_device *dev; ++ MARK(); ++ ++ dev = v4l2loopback_getdevice(file); ++ p->frames = dev->buffers_number; ++ p->offsets[0] = 0; ++ p->offsets[1] = 0; ++ p->size = dev->buffer_size; ++ return 0; ++} ++#endif ++ ++static void client_usage_queue_event(struct video_device *vdev) ++{ ++ struct v4l2_event ev; ++ struct v4l2_loopback_device *dev; ++ ++ dev = container_of(vdev->v4l2_dev, struct v4l2_loopback_device, ++ v4l2_dev); ++ ++ memset(&ev, 0, sizeof(ev)); ++ ev.type = V4L2_EVENT_PRI_CLIENT_USAGE; ++ ((struct v4l2_event_client_usage *)&ev.u)->count = dev->active_readers; ++ ++ v4l2_event_queue(vdev, &ev); ++} ++ ++static int client_usage_ops_add(struct v4l2_subscribed_event *sev, ++ unsigned elems) ++{ ++ if (!(sev->flags & V4L2_EVENT_SUB_FL_SEND_INITIAL)) ++ return 0; ++ ++ 
client_usage_queue_event(sev->fh->vdev); ++ return 0; ++} ++ ++static void client_usage_ops_replace(struct v4l2_event *old, ++ const struct v4l2_event *new) ++{ ++ *((struct v4l2_event_client_usage *)&old->u) = ++ *((struct v4l2_event_client_usage *)&new->u); ++} ++ ++static void client_usage_ops_merge(const struct v4l2_event *old, ++ struct v4l2_event *new) ++{ ++ *((struct v4l2_event_client_usage *)&new->u) = ++ *((struct v4l2_event_client_usage *)&old->u); ++} ++ ++const struct v4l2_subscribed_event_ops client_usage_ops = { ++ .add = client_usage_ops_add, ++ .replace = client_usage_ops_replace, ++ .merge = client_usage_ops_merge, ++}; ++ ++static int vidioc_subscribe_event(struct v4l2_fh *fh, ++ const struct v4l2_event_subscription *sub) ++{ ++ switch (sub->type) { ++ case V4L2_EVENT_CTRL: ++ return v4l2_ctrl_subscribe_event(fh, sub); ++ case V4L2_EVENT_PRI_CLIENT_USAGE: ++ return v4l2_event_subscribe(fh, sub, 0, &client_usage_ops); ++ } ++ ++ return -EINVAL; ++} ++ ++/* file operations */ ++static void vm_open(struct vm_area_struct *vma) ++{ ++ struct v4l2l_buffer *buf; ++ MARK(); ++ ++ buf = vma->vm_private_data; ++ buf->use_count++; ++ ++ buf->buffer.flags |= V4L2_BUF_FLAG_MAPPED; ++} ++ ++static void vm_close(struct vm_area_struct *vma) ++{ ++ struct v4l2l_buffer *buf; ++ MARK(); ++ ++ buf = vma->vm_private_data; ++ buf->use_count--; ++ ++ if (buf->use_count <= 0) ++ buf->buffer.flags &= ~V4L2_BUF_FLAG_MAPPED; ++} ++ ++static struct vm_operations_struct vm_ops = { ++ .open = vm_open, ++ .close = vm_close, ++}; ++ ++static int v4l2_loopback_mmap(struct file *file, struct vm_area_struct *vma) ++{ ++ u8 *addr; ++ unsigned long start; ++ unsigned long size; ++ struct v4l2_loopback_device *dev; ++ struct v4l2_loopback_opener *opener; ++ struct v4l2l_buffer *buffer = NULL; ++ MARK(); ++ ++ start = (unsigned long)vma->vm_start; ++ size = (unsigned long)(vma->vm_end - vma->vm_start); ++ ++ dev = v4l2loopback_getdevice(file); ++ opener = 
fh_to_opener(file->private_data); ++ ++ if (size > dev->buffer_size) { ++ dprintk("userspace tries to mmap too much, fail\n"); ++ return -EINVAL; ++ } ++ if (opener->timeout_image_io) { ++ /* we are going to map the timeout_image_buffer */ ++ if ((vma->vm_pgoff << PAGE_SHIFT) != ++ dev->buffer_size * MAX_BUFFERS) { ++ dprintk("invalid mmap offset for timeout_image_io mode\n"); ++ return -EINVAL; ++ } ++ } else if ((vma->vm_pgoff << PAGE_SHIFT) > ++ dev->buffer_size * (dev->buffers_number - 1)) { ++ dprintk("userspace tries to mmap too far, fail\n"); ++ return -EINVAL; ++ } ++ ++ /* FIXXXXXME: allocation should not happen here! */ ++ if (NULL == dev->image) ++ if (allocate_buffers(dev) < 0) ++ return -EINVAL; ++ ++ if (opener->timeout_image_io) { ++ buffer = &dev->timeout_image_buffer; ++ addr = dev->timeout_image; ++ } else { ++ int i; ++ for (i = 0; i < dev->buffers_number; ++i) { ++ buffer = &dev->buffers[i]; ++ if ((buffer->buffer.m.offset >> PAGE_SHIFT) == ++ vma->vm_pgoff) ++ break; ++ } ++ ++ if (i >= dev->buffers_number) ++ return -EINVAL; ++ ++ addr = dev->image + (vma->vm_pgoff << PAGE_SHIFT); ++ } ++ ++ while (size > 0) { ++ struct page *page; ++ ++ page = vmalloc_to_page(addr); ++ ++ if (vm_insert_page(vma, start, page) < 0) ++ return -EAGAIN; ++ ++ start += PAGE_SIZE; ++ addr += PAGE_SIZE; ++ size -= PAGE_SIZE; ++ } ++ ++ vma->vm_ops = &vm_ops; ++ vma->vm_private_data = buffer; ++ ++ vm_open(vma); ++ ++ MARK(); ++ return 0; ++} ++ ++static unsigned int v4l2_loopback_poll(struct file *file, ++ struct poll_table_struct *pts) ++{ ++ struct v4l2_loopback_opener *opener; ++ struct v4l2_loopback_device *dev; ++ __poll_t req_events = poll_requested_events(pts); ++ int ret_mask = 0; ++ MARK(); ++ ++ opener = fh_to_opener(file->private_data); ++ dev = v4l2loopback_getdevice(file); ++ ++ if (req_events & POLLPRI) { ++ if (!v4l2_event_pending(&opener->fh)) ++ poll_wait(file, &opener->fh.wait, pts); ++ if (v4l2_event_pending(&opener->fh)) { ++ ret_mask |= POLLPRI; 
++ if (!(req_events & DEFAULT_POLLMASK)) ++ return ret_mask; ++ } ++ } ++ ++ switch (opener->type) { ++ case WRITER: ++ ret_mask |= POLLOUT | POLLWRNORM; ++ break; ++ case READER: ++ if (!can_read(dev, opener)) { ++ if (ret_mask) ++ return ret_mask; ++ poll_wait(file, &dev->read_event, pts); ++ } ++ if (can_read(dev, opener)) ++ ret_mask |= POLLIN | POLLRDNORM; ++ if (v4l2_event_pending(&opener->fh)) ++ ret_mask |= POLLPRI; ++ break; ++ default: ++ break; ++ } ++ ++ MARK(); ++ return ret_mask; ++} ++ ++/* do not want to limit device opens, it can be as many readers as user want, ++ * writers are limited by means of setting writer field */ ++static int v4l2_loopback_open(struct file *file) ++{ ++ struct v4l2_loopback_device *dev; ++ struct v4l2_loopback_opener *opener; ++ MARK(); ++ dev = v4l2loopback_getdevice(file); ++ if (dev->open_count.counter >= dev->max_openers) ++ return -EBUSY; ++ /* kfree on close */ ++ opener = kzalloc(sizeof(*opener), GFP_KERNEL); ++ if (opener == NULL) ++ return -ENOMEM; ++ ++ atomic_inc(&dev->open_count); ++ ++ opener->timeout_image_io = dev->timeout_image_io; ++ if (opener->timeout_image_io) { ++ int r = allocate_timeout_image(dev); ++ ++ if (r < 0) { ++ dprintk("timeout image allocation failed\n"); ++ ++ atomic_dec(&dev->open_count); ++ ++ kfree(opener); ++ return r; ++ } ++ } ++ ++ v4l2_fh_init(&opener->fh, video_devdata(file)); ++ file->private_data = &opener->fh; ++ ++ v4l2_fh_add(&opener->fh); ++ dprintk("opened dev:%p with image:%p\n", dev, dev ? 
dev->image : NULL); ++ MARK(); ++ return 0; ++} ++ ++static int v4l2_loopback_close(struct file *file) ++{ ++ struct v4l2_loopback_opener *opener; ++ struct v4l2_loopback_device *dev; ++ int is_writer = 0, is_reader = 0; ++ MARK(); ++ ++ opener = fh_to_opener(file->private_data); ++ dev = v4l2loopback_getdevice(file); ++ ++ if (WRITER == opener->type) ++ is_writer = 1; ++ if (READER == opener->type) ++ is_reader = 1; ++ ++ atomic_dec(&dev->open_count); ++ if (dev->open_count.counter == 0) { ++ del_timer_sync(&dev->sustain_timer); ++ del_timer_sync(&dev->timeout_timer); ++ } ++ try_free_buffers(dev); ++ ++ v4l2_fh_del(&opener->fh); ++ v4l2_fh_exit(&opener->fh); ++ ++ kfree(opener); ++ if (is_writer) ++ dev->ready_for_output = 1; ++ if (is_reader) { ++ dev->active_readers--; ++ client_usage_queue_event(dev->vdev); ++ } ++ MARK(); ++ return 0; ++} ++ ++static ssize_t v4l2_loopback_read(struct file *file, char __user *buf, ++ size_t count, loff_t *ppos) ++{ ++ int read_index; ++ struct v4l2_loopback_device *dev; ++ struct v4l2_buffer *b; ++ MARK(); ++ ++ dev = v4l2loopback_getdevice(file); ++ ++ read_index = get_capture_buffer(file); ++ if (read_index < 0) ++ return read_index; ++ if (count > dev->buffer_size) ++ count = dev->buffer_size; ++ b = &dev->buffers[read_index].buffer; ++ if (count > b->bytesused) ++ count = b->bytesused; ++ if (copy_to_user((void *)buf, (void *)(dev->image + b->m.offset), ++ count)) { ++ printk(KERN_ERR ++ "v4l2-loopback: failed copy_to_user() in read buf\n"); ++ return -EFAULT; ++ } ++ dprintkrw("leave v4l2_loopback_read()\n"); ++ return count; ++} ++ ++static ssize_t v4l2_loopback_write(struct file *file, const char __user *buf, ++ size_t count, loff_t *ppos) ++{ ++ struct v4l2_loopback_opener *opener; ++ struct v4l2_loopback_device *dev; ++ int write_index; ++ struct v4l2_buffer *b; ++ int err = 0; ++ ++ MARK(); ++ ++ dev = v4l2loopback_getdevice(file); ++ opener = fh_to_opener(file->private_data); ++ ++ if (UNNEGOTIATED == opener->type) 
{ ++ spin_lock(&dev->lock); ++ ++ if (dev->ready_for_output) { ++ err = vidioc_streamon(file, file->private_data, ++ V4L2_BUF_TYPE_VIDEO_OUTPUT); ++ } ++ ++ spin_unlock(&dev->lock); ++ ++ if (err < 0) ++ return err; ++ } ++ ++ if (WRITER != opener->type) ++ return -EINVAL; ++ ++ if (!dev->ready_for_capture) { ++ int ret = allocate_buffers(dev); ++ if (ret < 0) ++ return ret; ++ dev->ready_for_capture = 1; ++ } ++ dprintkrw("v4l2_loopback_write() trying to write %zu bytes\n", count); ++ if (count > dev->buffer_size) ++ count = dev->buffer_size; ++ ++ write_index = v4l2l_mod64(dev->write_position, dev->used_buffers); ++ b = &dev->buffers[write_index].buffer; ++ ++ if (copy_from_user((void *)(dev->image + b->m.offset), (void *)buf, ++ count)) { ++ printk(KERN_ERR ++ "v4l2-loopback: failed copy_from_user() in write buf, could not write %zu\n", ++ count); ++ return -EFAULT; ++ } ++ v4l2l_get_timestamp(b); ++ b->bytesused = count; ++ b->sequence = dev->write_position; ++ buffer_written(dev, &dev->buffers[write_index]); ++ wake_up_all(&dev->read_event); ++ dprintkrw("leave v4l2_loopback_write()\n"); ++ return count; ++} ++ ++/* init functions */ ++/* frees buffers, if already allocated */ ++static void free_buffers(struct v4l2_loopback_device *dev) ++{ ++ MARK(); ++ dprintk("freeing image@%p for dev:%p\n", dev ? 
dev->image : NULL, dev); ++ if (!dev) ++ return; ++ if (dev->image) { ++ vfree(dev->image); ++ dev->image = NULL; ++ } ++ if (dev->timeout_image) { ++ vfree(dev->timeout_image); ++ dev->timeout_image = NULL; ++ } ++ dev->imagesize = 0; ++} ++/* frees buffers, if they are no longer needed */ ++static void try_free_buffers(struct v4l2_loopback_device *dev) ++{ ++ MARK(); ++ if (0 == dev->open_count.counter && !dev->keep_format) { ++ free_buffers(dev); ++ dev->ready_for_capture = 0; ++ dev->buffer_size = 0; ++ dev->write_position = 0; ++ } ++} ++/* allocates buffers, if buffer_size is set */ ++static int allocate_buffers(struct v4l2_loopback_device *dev) ++{ ++ int err; ++ ++ MARK(); ++ /* vfree on close file operation in case no open handles left */ ++ ++ if (dev->buffer_size < 1 || dev->buffers_number < 1) ++ return -EINVAL; ++ ++ if ((__LONG_MAX__ / dev->buffer_size) < dev->buffers_number) ++ return -ENOSPC; ++ ++ if (dev->image) { ++ dprintk("allocating buffers again: %ld %ld\n", ++ dev->buffer_size * dev->buffers_number, dev->imagesize); ++ /* FIXME: prevent double allocation more intelligently! 
*/ ++ if (dev->buffer_size * dev->buffers_number == dev->imagesize) ++ return 0; ++ ++ /* check whether the total number of readers/writers is <=1 */ ++ if ((dev->ready_for_capture + dev->active_readers) <= 1) ++ free_buffers(dev); ++ else ++ return -EINVAL; ++ } ++ ++ dev->imagesize = (unsigned long)dev->buffer_size * ++ (unsigned long)dev->buffers_number; ++ ++ dprintk("allocating %ld = %ldx%d\n", dev->imagesize, dev->buffer_size, ++ dev->buffers_number); ++ err = -ENOMEM; ++ ++ if (dev->timeout_jiffies > 0) { ++ err = allocate_timeout_image(dev); ++ if (err < 0) ++ goto error; ++ } ++ ++ dev->image = vmalloc(dev->imagesize); ++ if (dev->image == NULL) ++ goto error; ++ ++ dprintk("vmallocated %ld bytes\n", dev->imagesize); ++ MARK(); ++ ++ init_buffers(dev); ++ return 0; ++ ++error: ++ free_buffers(dev); ++ return err; ++} ++ ++/* init inner buffers, they are capture mode and flags are set as ++ * for capture mod buffers */ ++static void init_buffers(struct v4l2_loopback_device *dev) ++{ ++ int i; ++ int buffer_size; ++ int bytesused; ++ MARK(); ++ ++ buffer_size = dev->buffer_size; ++ bytesused = dev->pix_format.sizeimage; ++ for (i = 0; i < dev->buffers_number; ++i) { ++ struct v4l2_buffer *b = &dev->buffers[i].buffer; ++ b->index = i; ++ b->bytesused = bytesused; ++ b->length = buffer_size; ++ b->field = V4L2_FIELD_NONE; ++ b->flags = 0; ++ b->m.offset = i * buffer_size; ++ b->memory = V4L2_MEMORY_MMAP; ++ b->sequence = 0; ++ b->timestamp.tv_sec = 0; ++ b->timestamp.tv_usec = 0; ++ b->type = V4L2_BUF_TYPE_VIDEO_CAPTURE; ++ ++ v4l2l_get_timestamp(b); ++ } ++ dev->timeout_image_buffer = dev->buffers[0]; ++ dev->timeout_image_buffer.buffer.m.offset = MAX_BUFFERS * buffer_size; ++ MARK(); ++} ++ ++static int allocate_timeout_image(struct v4l2_loopback_device *dev) ++{ ++ MARK(); ++ if (dev->buffer_size <= 0) { ++ dev->timeout_image_io = 0; ++ return -EINVAL; ++ } ++ ++ if (dev->timeout_image == NULL) { ++ dev->timeout_image = vzalloc(dev->buffer_size); ++ if 
(dev->timeout_image == NULL) { ++ dev->timeout_image_io = 0; ++ return -ENOMEM; ++ } ++ } ++ return 0; ++} ++ ++/* fills and register video device */ ++static void init_vdev(struct video_device *vdev, int nr) ++{ ++ MARK(); ++ ++#ifdef V4L2LOOPBACK_WITH_STD ++ vdev->tvnorms = V4L2_STD_ALL; ++#endif /* V4L2LOOPBACK_WITH_STD */ ++ ++ vdev->vfl_type = VFL_TYPE_VIDEO; ++ vdev->fops = &v4l2_loopback_fops; ++ vdev->ioctl_ops = &v4l2_loopback_ioctl_ops; ++ vdev->release = &video_device_release; ++ vdev->minor = -1; ++#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 7, 0) ++ vdev->device_caps = V4L2_CAP_DEVICE_CAPS | V4L2_CAP_VIDEO_CAPTURE | ++ V4L2_CAP_VIDEO_OUTPUT | V4L2_CAP_READWRITE | ++ V4L2_CAP_STREAMING; ++#endif ++ ++ if (debug > 1) ++ vdev->dev_debug = V4L2_DEV_DEBUG_IOCTL | ++ V4L2_DEV_DEBUG_IOCTL_ARG; ++ ++ vdev->vfl_dir = VFL_DIR_M2M; ++ ++ MARK(); ++} ++ ++/* init default capture parameters, only fps may be changed in future */ ++static void init_capture_param(struct v4l2_captureparm *capture_param) ++{ ++ MARK(); ++ capture_param->capability = 0; ++ capture_param->capturemode = 0; ++ capture_param->extendedmode = 0; ++ capture_param->readbuffers = max_buffers; ++ capture_param->timeperframe.numerator = 1; ++ capture_param->timeperframe.denominator = 30; ++} ++ ++static void check_timers(struct v4l2_loopback_device *dev) ++{ ++ if (!dev->ready_for_capture) ++ return; ++ ++ if (dev->timeout_jiffies > 0 && !timer_pending(&dev->timeout_timer)) ++ mod_timer(&dev->timeout_timer, jiffies + dev->timeout_jiffies); ++ if (dev->sustain_framerate && !timer_pending(&dev->sustain_timer)) ++ mod_timer(&dev->sustain_timer, ++ jiffies + dev->frame_jiffies * 3 / 2); ++} ++#ifdef HAVE_TIMER_SETUP ++static void sustain_timer_clb(struct timer_list *t) ++{ ++ struct v4l2_loopback_device *dev = from_timer(dev, t, sustain_timer); ++#else ++static void sustain_timer_clb(unsigned long nr) ++{ ++ struct v4l2_loopback_device *dev = ++ idr_find(&v4l2loopback_index_idr, nr); ++#endif ++ 
spin_lock(&dev->lock); ++ if (dev->sustain_framerate) { ++ dev->reread_count++; ++ dprintkrw("reread: %lld %d\n", (long long)dev->write_position, ++ dev->reread_count); ++ if (dev->reread_count == 1) ++ mod_timer(&dev->sustain_timer, ++ jiffies + max(1UL, dev->frame_jiffies / 2)); ++ else ++ mod_timer(&dev->sustain_timer, ++ jiffies + dev->frame_jiffies); ++ wake_up_all(&dev->read_event); ++ } ++ spin_unlock(&dev->lock); ++} ++#ifdef HAVE_TIMER_SETUP ++static void timeout_timer_clb(struct timer_list *t) ++{ ++ struct v4l2_loopback_device *dev = from_timer(dev, t, timeout_timer); ++#else ++static void timeout_timer_clb(unsigned long nr) ++{ ++ struct v4l2_loopback_device *dev = ++ idr_find(&v4l2loopback_index_idr, nr); ++#endif ++ spin_lock(&dev->lock); ++ if (dev->timeout_jiffies > 0) { ++ dev->timeout_happened = 1; ++ mod_timer(&dev->timeout_timer, jiffies + dev->timeout_jiffies); ++ wake_up_all(&dev->read_event); ++ } ++ spin_unlock(&dev->lock); ++} ++ ++/* init loopback main structure */ ++#define DEFAULT_FROM_CONF(confmember, default_condition, default_value) \ ++ ((conf) ? \ ++ ((conf->confmember default_condition) ? 
(default_value) : \ ++ (conf->confmember)) : \ ++ default_value) ++ ++static int v4l2_loopback_add(struct v4l2_loopback_config *conf, int *ret_nr) ++{ ++ struct v4l2_loopback_device *dev; ++ struct v4l2_ctrl_handler *hdl; ++ struct v4l2loopback_private *vdev_priv = NULL; ++ ++ int err = -ENOMEM; ++ ++ u32 _width = V4L2LOOPBACK_SIZE_DEFAULT_WIDTH; ++ u32 _height = V4L2LOOPBACK_SIZE_DEFAULT_HEIGHT; ++ ++ u32 _min_width = DEFAULT_FROM_CONF(min_width, ++ < V4L2LOOPBACK_SIZE_MIN_WIDTH, ++ V4L2LOOPBACK_SIZE_MIN_WIDTH); ++ u32 _min_height = DEFAULT_FROM_CONF(min_height, ++ < V4L2LOOPBACK_SIZE_MIN_HEIGHT, ++ V4L2LOOPBACK_SIZE_MIN_HEIGHT); ++ u32 _max_width = DEFAULT_FROM_CONF(max_width, < _min_width, max_width); ++ u32 _max_height = ++ DEFAULT_FROM_CONF(max_height, < _min_height, max_height); ++ bool _announce_all_caps = (conf && conf->announce_all_caps >= 0) ? ++ (conf->announce_all_caps) : ++ V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS; ++ int _max_buffers = DEFAULT_FROM_CONF(max_buffers, <= 0, max_buffers); ++ int _max_openers = DEFAULT_FROM_CONF(max_openers, <= 0, max_openers); ++ ++ int nr = -1; ++ ++ _announce_all_caps = (!!_announce_all_caps); ++ ++ if (conf) { ++ const int output_nr = conf->output_nr; ++#ifdef SPLIT_DEVICES ++ const int capture_nr = conf->capture_nr; ++#else ++ const int capture_nr = output_nr; ++#endif ++ if (capture_nr >= 0 && output_nr == capture_nr) { ++ nr = output_nr; ++ } else if (capture_nr < 0 && output_nr < 0) { ++ nr = -1; ++ } else if (capture_nr < 0) { ++ nr = output_nr; ++ } else if (output_nr < 0) { ++ nr = capture_nr; ++ } else { ++ printk(KERN_ERR ++ "split OUTPUT and CAPTURE devices not yet supported."); ++ printk(KERN_INFO ++ "both devices must have the same number (%d != %d).", ++ output_nr, capture_nr); ++ return -EINVAL; ++ } ++ } ++ ++ if (idr_find(&v4l2loopback_index_idr, nr)) ++ return -EEXIST; ++ ++ dprintk("creating v4l2loopback-device #%d\n", nr); ++ dev = kzalloc(sizeof(*dev), GFP_KERNEL); ++ if (!dev) ++ return -ENOMEM; ++ ++ 
/* allocate id, if @id >= 0, we're requesting that specific id */ ++ if (nr >= 0) { ++ err = idr_alloc(&v4l2loopback_index_idr, dev, nr, nr + 1, ++ GFP_KERNEL); ++ if (err == -ENOSPC) ++ err = -EEXIST; ++ } else { ++ err = idr_alloc(&v4l2loopback_index_idr, dev, 0, 0, GFP_KERNEL); ++ } ++ if (err < 0) ++ goto out_free_dev; ++ nr = err; ++ err = -ENOMEM; ++ ++ if (conf && conf->card_label[0]) { ++ snprintf(dev->card_label, sizeof(dev->card_label), "%s", ++ conf->card_label); ++ } else { ++ snprintf(dev->card_label, sizeof(dev->card_label), ++ "Dummy video device (0x%04X)", nr); ++ } ++ snprintf(dev->v4l2_dev.name, sizeof(dev->v4l2_dev.name), ++ "v4l2loopback-%03d", nr); ++ ++ err = v4l2_device_register(NULL, &dev->v4l2_dev); ++ if (err) ++ goto out_free_idr; ++ MARK(); ++ ++ dev->vdev = video_device_alloc(); ++ if (dev->vdev == NULL) { ++ err = -ENOMEM; ++ goto out_unregister; ++ } ++ ++ vdev_priv = kzalloc(sizeof(struct v4l2loopback_private), GFP_KERNEL); ++ if (vdev_priv == NULL) { ++ err = -ENOMEM; ++ goto out_unregister; ++ } ++ ++ video_set_drvdata(dev->vdev, vdev_priv); ++ if (video_get_drvdata(dev->vdev) == NULL) { ++ err = -ENOMEM; ++ goto out_unregister; ++ } ++ ++ MARK(); ++ snprintf(dev->vdev->name, sizeof(dev->vdev->name), "%s", ++ dev->card_label); ++ ++ vdev_priv->device_nr = nr; ++ ++ init_vdev(dev->vdev, nr); ++ dev->vdev->v4l2_dev = &dev->v4l2_dev; ++ init_capture_param(&dev->capture_param); ++ err = set_timeperframe(dev, &dev->capture_param.timeperframe); ++ if (err) ++ goto out_unregister; ++ dev->keep_format = 0; ++ dev->sustain_framerate = 0; ++ ++ dev->announce_all_caps = _announce_all_caps; ++ dev->min_width = _min_width; ++ dev->min_height = _min_height; ++ dev->max_width = _max_width; ++ dev->max_height = _max_height; ++ dev->max_openers = _max_openers; ++ dev->buffers_number = dev->used_buffers = _max_buffers; ++ ++ dev->write_position = 0; ++ ++ MARK(); ++ spin_lock_init(&dev->lock); ++ spin_lock_init(&dev->list_lock); ++ 
INIT_LIST_HEAD(&dev->outbufs_list); ++ if (list_empty(&dev->outbufs_list)) { ++ int i; ++ ++ for (i = 0; i < dev->used_buffers; ++i) ++ list_add_tail(&dev->buffers[i].list_head, ++ &dev->outbufs_list); ++ } ++ memset(dev->bufpos2index, 0, sizeof(dev->bufpos2index)); ++ atomic_set(&dev->open_count, 0); ++ dev->ready_for_capture = 0; ++ dev->ready_for_output = 1; ++ ++ dev->buffer_size = 0; ++ dev->image = NULL; ++ dev->imagesize = 0; ++#ifdef HAVE_TIMER_SETUP ++ timer_setup(&dev->sustain_timer, sustain_timer_clb, 0); ++ timer_setup(&dev->timeout_timer, timeout_timer_clb, 0); ++#else ++ setup_timer(&dev->sustain_timer, sustain_timer_clb, nr); ++ setup_timer(&dev->timeout_timer, timeout_timer_clb, nr); ++#endif ++ dev->reread_count = 0; ++ dev->timeout_jiffies = 0; ++ dev->timeout_image = NULL; ++ dev->timeout_happened = 0; ++ ++ hdl = &dev->ctrl_handler; ++ err = v4l2_ctrl_handler_init(hdl, 4); ++ if (err) ++ goto out_unregister; ++ v4l2_ctrl_new_custom(hdl, &v4l2loopback_ctrl_keepformat, NULL); ++ v4l2_ctrl_new_custom(hdl, &v4l2loopback_ctrl_sustainframerate, NULL); ++ v4l2_ctrl_new_custom(hdl, &v4l2loopback_ctrl_timeout, NULL); ++ v4l2_ctrl_new_custom(hdl, &v4l2loopback_ctrl_timeoutimageio, NULL); ++ if (hdl->error) { ++ err = hdl->error; ++ goto out_free_handler; ++ } ++ dev->v4l2_dev.ctrl_handler = hdl; ++ ++ err = v4l2_ctrl_handler_setup(hdl); ++ if (err) ++ goto out_free_handler; ++ ++ /* FIXME set buffers to 0 */ ++ ++ /* Set initial format */ ++ if (_width < _min_width) ++ _width = _min_width; ++ if (_width > _max_width) ++ _width = _max_width; ++ if (_height < _min_height) ++ _height = _min_height; ++ if (_height > _max_height) ++ _height = _max_height; ++ ++ dev->pix_format.width = _width; ++ dev->pix_format.height = _height; ++ dev->pix_format.pixelformat = formats[0].fourcc; ++ dev->pix_format.colorspace = ++ V4L2_COLORSPACE_DEFAULT; /* do we need to set this ? 
*/ ++ dev->pix_format.field = V4L2_FIELD_NONE; ++ ++ dev->buffer_size = PAGE_ALIGN(dev->pix_format.sizeimage); ++ dprintk("buffer_size = %ld (=%d)\n", dev->buffer_size, ++ dev->pix_format.sizeimage); ++ ++ if (dev->buffer_size && ((err = allocate_buffers(dev)) < 0)) ++ goto out_free_handler; ++ ++ init_waitqueue_head(&dev->read_event); ++ ++ /* register the device -> it creates /dev/video* */ ++ if (video_register_device(dev->vdev, VFL_TYPE_VIDEO, nr) < 0) { ++ printk(KERN_ERR ++ "v4l2loopback: failed video_register_device()\n"); ++ err = -EFAULT; ++ goto out_free_device; ++ } ++ v4l2loopback_create_sysfs(dev->vdev); ++ ++ MARK(); ++ if (ret_nr) ++ *ret_nr = dev->vdev->num; ++ return 0; ++ ++out_free_device: ++ video_device_release(dev->vdev); ++out_free_handler: ++ v4l2_ctrl_handler_free(&dev->ctrl_handler); ++out_unregister: ++ video_set_drvdata(dev->vdev, NULL); ++ if (vdev_priv != NULL) ++ kfree(vdev_priv); ++ v4l2_device_unregister(&dev->v4l2_dev); ++out_free_idr: ++ idr_remove(&v4l2loopback_index_idr, nr); ++out_free_dev: ++ kfree(dev); ++ return err; ++} ++ ++static void v4l2_loopback_remove(struct v4l2_loopback_device *dev) ++{ ++ free_buffers(dev); ++ v4l2loopback_remove_sysfs(dev->vdev); ++ kfree(video_get_drvdata(dev->vdev)); ++ video_unregister_device(dev->vdev); ++ v4l2_device_unregister(&dev->v4l2_dev); ++ v4l2_ctrl_handler_free(&dev->ctrl_handler); ++ kfree(dev); ++} ++ ++static long v4l2loopback_control_ioctl(struct file *file, unsigned int cmd, ++ unsigned long parm) ++{ ++ struct v4l2_loopback_device *dev; ++ struct v4l2_loopback_config conf; ++ struct v4l2_loopback_config *confptr = &conf; ++ int device_nr, capture_nr, output_nr; ++ int ret; ++ ++ ret = mutex_lock_killable(&v4l2loopback_ctl_mutex); ++ if (ret) ++ return ret; ++ ++ ret = -EINVAL; ++ switch (cmd) { ++ default: ++ ret = -ENOSYS; ++ break; ++ /* add a v4l2loopback device (pair), based on the user-provided specs */ ++ case V4L2LOOPBACK_CTL_ADD: ++ if (parm) { ++ if ((ret = 
copy_from_user(&conf, (void *)parm, ++ sizeof(conf))) < 0) ++ break; ++ } else ++ confptr = NULL; ++ ret = v4l2_loopback_add(confptr, &device_nr); ++ if (ret >= 0) ++ ret = device_nr; ++ break; ++ /* remove a v4l2loopback device (both capture and output) */ ++ case V4L2LOOPBACK_CTL_REMOVE: ++ ret = v4l2loopback_lookup((int)parm, &dev); ++ if (ret >= 0 && dev) { ++ int nr = ret; ++ ret = -EBUSY; ++ if (dev->open_count.counter > 0) ++ break; ++ idr_remove(&v4l2loopback_index_idr, nr); ++ v4l2_loopback_remove(dev); ++ ret = 0; ++ }; ++ break; ++ /* get information for a loopback device. ++ * this is mostly about limits (which cannot be queried directly with VIDIOC_G_FMT and friends ++ */ ++ case V4L2LOOPBACK_CTL_QUERY: ++ if (!parm) ++ break; ++ if ((ret = copy_from_user(&conf, (void *)parm, sizeof(conf))) < ++ 0) ++ break; ++ capture_nr = output_nr = conf.output_nr; ++#ifdef SPLIT_DEVICES ++ capture_nr = conf.capture_nr; ++#endif ++ device_nr = (output_nr < 0) ? capture_nr : output_nr; ++ MARK(); ++ /* get the device from either capture_nr or output_nr (whatever is valid) */ ++ if ((ret = v4l2loopback_lookup(device_nr, &dev)) < 0) ++ break; ++ MARK(); ++ /* if we got the device from output_nr and there is a valid capture_nr, ++ * make sure that both refer to the same device (or bail out) ++ */ ++ if ((device_nr != capture_nr) && (capture_nr >= 0) && ++ ((ret = v4l2loopback_lookup(capture_nr, 0)) < 0)) ++ break; ++ MARK(); ++ /* if otoh, we got the device from capture_nr and there is a valid output_nr, ++ * make sure that both refer to the same device (or bail out) ++ */ ++ if ((device_nr != output_nr) && (output_nr >= 0) && ++ ((ret = v4l2loopback_lookup(output_nr, 0)) < 0)) ++ break; ++ MARK(); ++ ++ /* v4l2_loopback_config identified a single device, so fetch the data */ ++ snprintf(conf.card_label, sizeof(conf.card_label), "%s", ++ dev->card_label); ++ MARK(); ++ conf.output_nr = dev->vdev->num; ++#ifdef SPLIT_DEVICES ++ conf.capture_nr = dev->vdev->num; ++#endif 
++ conf.min_width = dev->min_width; ++ conf.min_height = dev->min_height; ++ conf.max_width = dev->max_width; ++ conf.max_height = dev->max_height; ++ conf.announce_all_caps = dev->announce_all_caps; ++ conf.max_buffers = dev->buffers_number; ++ conf.max_openers = dev->max_openers; ++ conf.debug = debug; ++ MARK(); ++ if (copy_to_user((void *)parm, &conf, sizeof(conf))) { ++ ret = -EFAULT; ++ break; ++ } ++ MARK(); ++ ret = 0; ++ ; ++ break; ++ } ++ ++ MARK(); ++ mutex_unlock(&v4l2loopback_ctl_mutex); ++ MARK(); ++ return ret; ++} ++ ++/* LINUX KERNEL */ ++ ++static const struct file_operations v4l2loopback_ctl_fops = { ++ // clang-format off ++ .owner = THIS_MODULE, ++ .open = nonseekable_open, ++ .unlocked_ioctl = v4l2loopback_control_ioctl, ++ .compat_ioctl = v4l2loopback_control_ioctl, ++ .llseek = noop_llseek, ++ // clang-format on ++}; ++ ++static struct miscdevice v4l2loopback_misc = { ++ // clang-format off ++ .minor = MISC_DYNAMIC_MINOR, ++ .name = "v4l2loopback", ++ .fops = &v4l2loopback_ctl_fops, ++ // clang-format on ++}; ++ ++static const struct v4l2_file_operations v4l2_loopback_fops = { ++ // clang-format off ++ .owner = THIS_MODULE, ++ .open = v4l2_loopback_open, ++ .release = v4l2_loopback_close, ++ .read = v4l2_loopback_read, ++ .write = v4l2_loopback_write, ++ .poll = v4l2_loopback_poll, ++ .mmap = v4l2_loopback_mmap, ++ .unlocked_ioctl = video_ioctl2, ++ // clang-format on ++}; ++ ++static const struct v4l2_ioctl_ops v4l2_loopback_ioctl_ops = { ++ // clang-format off ++ .vidioc_querycap = &vidioc_querycap, ++ .vidioc_enum_framesizes = &vidioc_enum_framesizes, ++ .vidioc_enum_frameintervals = &vidioc_enum_frameintervals, ++ ++ .vidioc_enum_output = &vidioc_enum_output, ++ .vidioc_g_output = &vidioc_g_output, ++ .vidioc_s_output = &vidioc_s_output, ++ ++ .vidioc_enum_input = &vidioc_enum_input, ++ .vidioc_g_input = &vidioc_g_input, ++ .vidioc_s_input = &vidioc_s_input, ++ ++ .vidioc_enum_fmt_vid_cap = &vidioc_enum_fmt_cap, ++ .vidioc_g_fmt_vid_cap 
= &vidioc_g_fmt_cap, ++ .vidioc_s_fmt_vid_cap = &vidioc_s_fmt_cap, ++ .vidioc_try_fmt_vid_cap = &vidioc_try_fmt_cap, ++ ++ .vidioc_enum_fmt_vid_out = &vidioc_enum_fmt_out, ++ .vidioc_s_fmt_vid_out = &vidioc_s_fmt_out, ++ .vidioc_g_fmt_vid_out = &vidioc_g_fmt_out, ++ .vidioc_try_fmt_vid_out = &vidioc_try_fmt_out, ++ ++#ifdef V4L2L_OVERLAY ++ .vidioc_s_fmt_vid_overlay = &vidioc_s_fmt_overlay, ++ .vidioc_g_fmt_vid_overlay = &vidioc_g_fmt_overlay, ++#endif ++ ++#ifdef V4L2LOOPBACK_WITH_STD ++ .vidioc_s_std = &vidioc_s_std, ++ .vidioc_g_std = &vidioc_g_std, ++ .vidioc_querystd = &vidioc_querystd, ++#endif /* V4L2LOOPBACK_WITH_STD */ ++ ++ .vidioc_g_parm = &vidioc_g_parm, ++ .vidioc_s_parm = &vidioc_s_parm, ++ ++ .vidioc_reqbufs = &vidioc_reqbufs, ++ .vidioc_querybuf = &vidioc_querybuf, ++ .vidioc_qbuf = &vidioc_qbuf, ++ .vidioc_dqbuf = &vidioc_dqbuf, ++ ++ .vidioc_streamon = &vidioc_streamon, ++ .vidioc_streamoff = &vidioc_streamoff, ++ ++#ifdef CONFIG_VIDEO_V4L1_COMPAT ++ .vidiocgmbuf = &vidiocgmbuf, ++#endif ++ ++ .vidioc_subscribe_event = &vidioc_subscribe_event, ++ .vidioc_unsubscribe_event = &v4l2_event_unsubscribe, ++ // clang-format on ++}; ++ ++static int free_device_cb(int id, void *ptr, void *data) ++{ ++ struct v4l2_loopback_device *dev = ptr; ++ v4l2_loopback_remove(dev); ++ return 0; ++} ++static void free_devices(void) ++{ ++ idr_for_each(&v4l2loopback_index_idr, &free_device_cb, NULL); ++ idr_destroy(&v4l2loopback_index_idr); ++} ++ ++static int __init v4l2loopback_init_module(void) ++{ ++ const u32 min_width = V4L2LOOPBACK_SIZE_MIN_WIDTH; ++ const u32 min_height = V4L2LOOPBACK_SIZE_MIN_HEIGHT; ++ int err; ++ int i; ++ MARK(); ++ ++ err = misc_register(&v4l2loopback_misc); ++ if (err < 0) ++ return err; ++ ++ if (devices < 0) { ++ devices = 1; ++ ++ /* try guessing the devices from the "video_nr" parameter */ ++ for (i = MAX_DEVICES - 1; i >= 0; i--) { ++ if (video_nr[i] >= 0) { ++ devices = i + 1; ++ break; ++ } ++ } ++ } ++ ++ if (devices > MAX_DEVICES) 
{ ++ devices = MAX_DEVICES; ++ printk(KERN_INFO ++ "v4l2loopback: number of initial devices is limited to: %d\n", ++ MAX_DEVICES); ++ } ++ ++ if (max_buffers > MAX_BUFFERS) { ++ max_buffers = MAX_BUFFERS; ++ printk(KERN_INFO ++ "v4l2loopback: number of buffers is limited to: %d\n", ++ MAX_BUFFERS); ++ } ++ ++ if (max_openers < 0) { ++ printk(KERN_INFO ++ "v4l2loopback: allowing %d openers rather than %d\n", ++ 2, max_openers); ++ max_openers = 2; ++ } ++ ++ if (max_width < min_width) { ++ max_width = V4L2LOOPBACK_SIZE_DEFAULT_MAX_WIDTH; ++ printk(KERN_INFO "v4l2loopback: using max_width %d\n", ++ max_width); ++ } ++ if (max_height < min_height) { ++ max_height = V4L2LOOPBACK_SIZE_DEFAULT_MAX_HEIGHT; ++ printk(KERN_INFO "v4l2loopback: using max_height %d\n", ++ max_height); ++ } ++ ++ for (i = 0; i < devices; i++) { ++ struct v4l2_loopback_config cfg = { ++ // clang-format off ++ .output_nr = video_nr[i], ++#ifdef SPLIT_DEVICES ++ .capture_nr = video_nr[i], ++#endif ++ .min_width = min_width, ++ .min_height = min_height, ++ .max_width = max_width, ++ .max_height = max_height, ++ .announce_all_caps = (!exclusive_caps[i]), ++ .max_buffers = max_buffers, ++ .max_openers = max_openers, ++ .debug = debug, ++ // clang-format on ++ }; ++ cfg.card_label[0] = 0; ++ if (card_label[i]) ++ snprintf(cfg.card_label, sizeof(cfg.card_label), "%s", ++ card_label[i]); ++ err = v4l2_loopback_add(&cfg, 0); ++ if (err) { ++ free_devices(); ++ goto error; ++ } ++ } ++ ++ dprintk("module installed\n"); ++ ++ printk(KERN_INFO "v4l2loopback driver version %d.%d.%d%s loaded\n", ++ // clang-format off ++ (V4L2LOOPBACK_VERSION_CODE >> 16) & 0xff, ++ (V4L2LOOPBACK_VERSION_CODE >> 8) & 0xff, ++ (V4L2LOOPBACK_VERSION_CODE ) & 0xff, ++#ifdef SNAPSHOT_VERSION ++ " (" __stringify(SNAPSHOT_VERSION) ")" ++#else ++ "" ++#endif ++ ); ++ // clang-format on ++ ++ return 0; ++error: ++ misc_deregister(&v4l2loopback_misc); ++ return err; ++} ++ ++static void v4l2loopback_cleanup_module(void) ++{ ++ MARK(); 
++ /* unregister the device -> it deletes /dev/video* */ ++ free_devices(); ++ /* and get rid of /dev/v4l2loopback */ ++ misc_deregister(&v4l2loopback_misc); ++ dprintk("module removed\n"); ++} ++ ++MODULE_ALIAS_MISCDEV(MISC_DYNAMIC_MINOR); ++ ++module_init(v4l2loopback_init_module); ++module_exit(v4l2loopback_cleanup_module); +diff --git a/drivers/media/v4l2-core/v4l2loopback.h b/drivers/media/v4l2-core/v4l2loopback.h +new file mode 100644 +index 000000000000..1bc7e6b747a4 +--- /dev/null ++++ b/drivers/media/v4l2-core/v4l2loopback.h +@@ -0,0 +1,98 @@ ++/* SPDX-License-Identifier: GPL-2.0+ WITH Linux-syscall-note */ ++/* ++ * v4l2loopback.h ++ * ++ * Written by IOhannes m zmölnig, 7/1/20. ++ * ++ * Copyright 2020 by IOhannes m zmölnig. Redistribution of this file is ++ * permitted under the GNU General Public License. ++ */ ++#ifndef _V4L2LOOPBACK_H ++#define _V4L2LOOPBACK_H ++ ++#define V4L2LOOPBACK_VERSION_MAJOR 0 ++#define V4L2LOOPBACK_VERSION_MINOR 13 ++#define V4L2LOOPBACK_VERSION_BUGFIX 1 ++ ++/* /dev/v4l2loopback interface */ ++ ++struct v4l2_loopback_config { ++ /** ++ * the device-number (/dev/video) ++ * V4L2LOOPBACK_CTL_ADD: ++ * setting this to a value<0, will allocate an available one ++ * if nr>=0 and the device already exists, the ioctl will EEXIST ++ * if output_nr and capture_nr are the same, only a single device will be created ++ * NOTE: currently split-devices (where output_nr and capture_nr differ) ++ * are not implemented yet. ++ * until then, requesting different device-IDs will result in EINVAL. 
++ * ++ * V4L2LOOPBACK_CTL_QUERY: ++ * either both output_nr and capture_nr must refer to the same loopback, ++ * or one (and only one) of them must be -1 ++ * ++ */ ++ int output_nr; ++ int unused; /*capture_nr;*/ ++ ++ /** ++ * a nice name for your device ++ * if (*card_label)==0, an automatic name is assigned ++ */ ++ char card_label[32]; ++ ++ /** ++ * allowed frame size ++ * if too low, default values are used ++ */ ++ unsigned int min_width; ++ unsigned int max_width; ++ unsigned int min_height; ++ unsigned int max_height; ++ ++ /** ++ * number of buffers to allocate for the queue ++ * if set to <=0, default values are used ++ */ ++ int max_buffers; ++ ++ /** ++ * how many consumers are allowed to open this device concurrently ++ * if set to <=0, default values are used ++ */ ++ int max_openers; ++ ++ /** ++ * set the debugging level for this device ++ */ ++ int debug; ++ ++ /** ++ * whether to announce OUTPUT/CAPTURE capabilities exclusively ++ * for this device or not ++ * (!exclusive_caps) ++ * NOTE: this is going to be removed once separate output/capture ++ * devices are implemented ++ */ ++ int announce_all_caps; ++}; ++ ++/* a pointer to a (struct v4l2_loopback_config) that has all values you wish to impose on the ++ * to-be-created device set. ++ * if the ptr is NULL, a new device is created with default values at the driver's discretion. 
++ * ++ * returns the device_nr of the OUTPUT device (which can be used with V4L2LOOPBACK_CTL_QUERY, ++ * to get more information on the device) ++ */ ++#define V4L2LOOPBACK_CTL_ADD 0x4C80 ++ ++/* a pointer to a (struct v4l2_loopback_config) that has output_nr and/or capture_nr set ++ * (the two values must either refer to video-devices associated with the same loopback device ++ * or exactly one of them must be <0 ++ */ ++#define V4L2LOOPBACK_CTL_QUERY 0x4C82 ++ ++/* the device-number (either CAPTURE or OUTPUT) associated with the loopback-device */ ++#define V4L2LOOPBACK_CTL_REMOVE 0x4C81 ++ ++#endif /* _V4L2LOOPBACK_H */ +diff --git a/drivers/media/v4l2-core/v4l2loopback_formats.h b/drivers/media/v4l2-core/v4l2loopback_formats.h +new file mode 100644 +index 000000000000..d855a3796554 +--- /dev/null ++++ b/drivers/media/v4l2-core/v4l2loopback_formats.h +@@ -0,0 +1,445 @@ ++static const struct v4l2l_format formats[] = { ++#ifndef V4L2_PIX_FMT_VP9 ++#define V4L2_PIX_FMT_VP9 v4l2_fourcc('V', 'P', '9', '0') ++#endif ++#ifndef V4L2_PIX_FMT_HEVC ++#define V4L2_PIX_FMT_HEVC v4l2_fourcc('H', 'E', 'V', 'C') ++#endif ++ ++ /* here come the packed formats */ ++ { ++ .name = "32 bpp RGB, le", ++ .fourcc = V4L2_PIX_FMT_BGR32, ++ .depth = 32, ++ .flags = 0, ++ }, ++ { ++ .name = "32 bpp RGB, be", ++ .fourcc = V4L2_PIX_FMT_RGB32, ++ .depth = 32, ++ .flags = 0, ++ }, ++ { ++ .name = "24 bpp RGB, le", ++ .fourcc = V4L2_PIX_FMT_BGR24, ++ .depth = 24, ++ .flags = 0, ++ }, ++ { ++ .name = "24 bpp RGB, be", ++ .fourcc = V4L2_PIX_FMT_RGB24, ++ .depth = 24, ++ .flags = 0, ++ }, ++#ifdef V4L2_PIX_FMT_ABGR32 ++ { ++ .name = "32 bpp RGBA, le", ++ .fourcc = V4L2_PIX_FMT_ABGR32, ++ .depth = 32, ++ .flags = 0, ++ }, ++#endif ++#ifdef V4L2_PIX_FMT_RGBA32 ++ { ++ .name = "32 bpp RGBA", ++ .fourcc = V4L2_PIX_FMT_RGBA32, ++ .depth = 32, ++ .flags = 0, ++ }, ++#endif ++#ifdef V4L2_PIX_FMT_RGB332 ++ { ++ .name = "8 bpp RGB-3-3-2", ++ .fourcc = V4L2_PIX_FMT_RGB332, ++ .depth = 8, ++ .flags = 0, ++ 
}, ++#endif /* V4L2_PIX_FMT_RGB332 */ ++#ifdef V4L2_PIX_FMT_RGB444 ++ { ++ .name = "16 bpp RGB (xxxxrrrr ggggbbbb)", ++ .fourcc = V4L2_PIX_FMT_RGB444, ++ .depth = 16, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_RGB444 */ ++#ifdef V4L2_PIX_FMT_RGB555 ++ { ++ .name = "16 bpp RGB-5-5-5", ++ .fourcc = V4L2_PIX_FMT_RGB555, ++ .depth = 16, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_RGB555 */ ++#ifdef V4L2_PIX_FMT_RGB565 ++ { ++ .name = "16 bpp RGB-5-6-5", ++ .fourcc = V4L2_PIX_FMT_RGB565, ++ .depth = 16, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_RGB565 */ ++#ifdef V4L2_PIX_FMT_RGB555X ++ { ++ .name = "16 bpp RGB-5-5-5 BE", ++ .fourcc = V4L2_PIX_FMT_RGB555X, ++ .depth = 16, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_RGB555X */ ++#ifdef V4L2_PIX_FMT_RGB565X ++ { ++ .name = "16 bpp RGB-5-6-5 BE", ++ .fourcc = V4L2_PIX_FMT_RGB565X, ++ .depth = 16, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_RGB565X */ ++#ifdef V4L2_PIX_FMT_BGR666 ++ { ++ .name = "18 bpp BGR-6-6-6", ++ .fourcc = V4L2_PIX_FMT_BGR666, ++ .depth = 18, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_BGR666 */ ++ { ++ .name = "4:2:2, packed, YUYV", ++ .fourcc = V4L2_PIX_FMT_YUYV, ++ .depth = 16, ++ .flags = 0, ++ }, ++ { ++ .name = "4:2:2, packed, UYVY", ++ .fourcc = V4L2_PIX_FMT_UYVY, ++ .depth = 16, ++ .flags = 0, ++ }, ++#ifdef V4L2_PIX_FMT_YVYU ++ { ++ .name = "4:2:2, packed YVYU", ++ .fourcc = V4L2_PIX_FMT_YVYU, ++ .depth = 16, ++ .flags = 0, ++ }, ++#endif ++#ifdef V4L2_PIX_FMT_VYUY ++ { ++ .name = "4:2:2, packed VYUY", ++ .fourcc = V4L2_PIX_FMT_VYUY, ++ .depth = 16, ++ .flags = 0, ++ }, ++#endif ++ { ++ .name = "4:2:2, packed YYUV", ++ .fourcc = V4L2_PIX_FMT_YYUV, ++ .depth = 16, ++ .flags = 0, ++ }, ++ { ++ .name = "YUV-8-8-8-8", ++ .fourcc = V4L2_PIX_FMT_YUV32, ++ .depth = 32, ++ .flags = 0, ++ }, ++ { ++ .name = "8 bpp, Greyscale", ++ .fourcc = V4L2_PIX_FMT_GREY, ++ .depth = 8, ++ .flags = 0, ++ }, ++#ifdef V4L2_PIX_FMT_Y4 ++ { ++ .name = "4 bpp Greyscale", ++ .fourcc = V4L2_PIX_FMT_Y4, ++ 
.depth = 4, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_Y4 */ ++#ifdef V4L2_PIX_FMT_Y6 ++ { ++ .name = "6 bpp Greyscale", ++ .fourcc = V4L2_PIX_FMT_Y6, ++ .depth = 6, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_Y6 */ ++#ifdef V4L2_PIX_FMT_Y10 ++ { ++ .name = "10 bpp Greyscale", ++ .fourcc = V4L2_PIX_FMT_Y10, ++ .depth = 10, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_Y10 */ ++#ifdef V4L2_PIX_FMT_Y12 ++ { ++ .name = "12 bpp Greyscale", ++ .fourcc = V4L2_PIX_FMT_Y12, ++ .depth = 12, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_Y12 */ ++ { ++ .name = "16 bpp, Greyscale", ++ .fourcc = V4L2_PIX_FMT_Y16, ++ .depth = 16, ++ .flags = 0, ++ }, ++#ifdef V4L2_PIX_FMT_YUV444 ++ { ++ .name = "16 bpp xxxxyyyy uuuuvvvv", ++ .fourcc = V4L2_PIX_FMT_YUV444, ++ .depth = 16, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_YUV444 */ ++#ifdef V4L2_PIX_FMT_YUV555 ++ { ++ .name = "16 bpp YUV-5-5-5", ++ .fourcc = V4L2_PIX_FMT_YUV555, ++ .depth = 16, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_YUV555 */ ++#ifdef V4L2_PIX_FMT_YUV565 ++ { ++ .name = "16 bpp YUV-5-6-5", ++ .fourcc = V4L2_PIX_FMT_YUV565, ++ .depth = 16, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_YUV565 */ ++ ++/* bayer formats */ ++#ifdef V4L2_PIX_FMT_SRGGB8 ++ { ++ .name = "Bayer RGGB 8bit", ++ .fourcc = V4L2_PIX_FMT_SRGGB8, ++ .depth = 8, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_SRGGB8 */ ++#ifdef V4L2_PIX_FMT_SGRBG8 ++ { ++ .name = "Bayer GRBG 8bit", ++ .fourcc = V4L2_PIX_FMT_SGRBG8, ++ .depth = 8, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_SGRBG8 */ ++#ifdef V4L2_PIX_FMT_SGBRG8 ++ { ++ .name = "Bayer GBRG 8bit", ++ .fourcc = V4L2_PIX_FMT_SGBRG8, ++ .depth = 8, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_SGBRG8 */ ++#ifdef V4L2_PIX_FMT_SBGGR8 ++ { ++ .name = "Bayer BA81 8bit", ++ .fourcc = V4L2_PIX_FMT_SBGGR8, ++ .depth = 8, ++ .flags = 0, ++ }, ++#endif /* V4L2_PIX_FMT_SBGGR8 */ ++ ++ /* here come the planar formats */ ++ { ++ .name = "4:1:0, planar, Y-Cr-Cb", ++ .fourcc = V4L2_PIX_FMT_YVU410, ++ .depth 
= 9, ++ .flags = FORMAT_FLAGS_PLANAR, ++ }, ++ { ++ .name = "4:2:0, planar, Y-Cr-Cb", ++ .fourcc = V4L2_PIX_FMT_YVU420, ++ .depth = 12, ++ .flags = FORMAT_FLAGS_PLANAR, ++ }, ++ { ++ .name = "4:1:0, planar, Y-Cb-Cr", ++ .fourcc = V4L2_PIX_FMT_YUV410, ++ .depth = 9, ++ .flags = FORMAT_FLAGS_PLANAR, ++ }, ++ { ++ .name = "4:2:0, planar, Y-Cb-Cr", ++ .fourcc = V4L2_PIX_FMT_YUV420, ++ .depth = 12, ++ .flags = FORMAT_FLAGS_PLANAR, ++ }, ++#ifdef V4L2_PIX_FMT_YUV422P ++ { ++ .name = "16 bpp YVU422 planar", ++ .fourcc = V4L2_PIX_FMT_YUV422P, ++ .depth = 16, ++ .flags = FORMAT_FLAGS_PLANAR, ++ }, ++#endif /* V4L2_PIX_FMT_YUV422P */ ++#ifdef V4L2_PIX_FMT_YUV411P ++ { ++ .name = "16 bpp YVU411 planar", ++ .fourcc = V4L2_PIX_FMT_YUV411P, ++ .depth = 16, ++ .flags = FORMAT_FLAGS_PLANAR, ++ }, ++#endif /* V4L2_PIX_FMT_YUV411P */ ++#ifdef V4L2_PIX_FMT_Y41P ++ { ++ .name = "12 bpp YUV 4:1:1", ++ .fourcc = V4L2_PIX_FMT_Y41P, ++ .depth = 12, ++ .flags = FORMAT_FLAGS_PLANAR, ++ }, ++#endif /* V4L2_PIX_FMT_Y41P */ ++#ifdef V4L2_PIX_FMT_NV12 ++ { ++ .name = "12 bpp Y/CbCr 4:2:0 ", ++ .fourcc = V4L2_PIX_FMT_NV12, ++ .depth = 12, ++ .flags = FORMAT_FLAGS_PLANAR, ++ }, ++#endif /* V4L2_PIX_FMT_NV12 */ ++ ++/* here come the compressed formats */ ++ ++#ifdef V4L2_PIX_FMT_MJPEG ++ { ++ .name = "Motion-JPEG", ++ .fourcc = V4L2_PIX_FMT_MJPEG, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_MJPEG */ ++#ifdef V4L2_PIX_FMT_JPEG ++ { ++ .name = "JFIF JPEG", ++ .fourcc = V4L2_PIX_FMT_JPEG, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_JPEG */ ++#ifdef V4L2_PIX_FMT_DV ++ { ++ .name = "DV1394", ++ .fourcc = V4L2_PIX_FMT_DV, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_DV */ ++#ifdef V4L2_PIX_FMT_MPEG ++ { ++ .name = "MPEG-1/2/4 Multiplexed", ++ .fourcc = V4L2_PIX_FMT_MPEG, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_MPEG */ ++#ifdef 
V4L2_PIX_FMT_H264 ++ { ++ .name = "H264 with start codes", ++ .fourcc = V4L2_PIX_FMT_H264, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_H264 */ ++#ifdef V4L2_PIX_FMT_H264_NO_SC ++ { ++ .name = "H264 without start codes", ++ .fourcc = V4L2_PIX_FMT_H264_NO_SC, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_H264_NO_SC */ ++#ifdef V4L2_PIX_FMT_H264_MVC ++ { ++ .name = "H264 MVC", ++ .fourcc = V4L2_PIX_FMT_H264_MVC, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_H264_MVC */ ++#ifdef V4L2_PIX_FMT_H263 ++ { ++ .name = "H263", ++ .fourcc = V4L2_PIX_FMT_H263, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_H263 */ ++#ifdef V4L2_PIX_FMT_MPEG1 ++ { ++ .name = "MPEG-1 ES", ++ .fourcc = V4L2_PIX_FMT_MPEG1, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_MPEG1 */ ++#ifdef V4L2_PIX_FMT_MPEG2 ++ { ++ .name = "MPEG-2 ES", ++ .fourcc = V4L2_PIX_FMT_MPEG2, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_MPEG2 */ ++#ifdef V4L2_PIX_FMT_MPEG4 ++ { ++ .name = "MPEG-4 part 2 ES", ++ .fourcc = V4L2_PIX_FMT_MPEG4, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_MPEG4 */ ++#ifdef V4L2_PIX_FMT_XVID ++ { ++ .name = "Xvid", ++ .fourcc = V4L2_PIX_FMT_XVID, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_XVID */ ++#ifdef V4L2_PIX_FMT_VC1_ANNEX_G ++ { ++ .name = "SMPTE 421M Annex G compliant stream", ++ .fourcc = V4L2_PIX_FMT_VC1_ANNEX_G, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_VC1_ANNEX_G */ ++#ifdef V4L2_PIX_FMT_VC1_ANNEX_L ++ { ++ .name = "SMPTE 421M Annex L compliant stream", ++ .fourcc = V4L2_PIX_FMT_VC1_ANNEX_L, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_VC1_ANNEX_L */ ++#ifdef V4L2_PIX_FMT_VP8 ++ { ++ .name = "VP8", ++ 
.fourcc = V4L2_PIX_FMT_VP8, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_VP8 */ ++#ifdef V4L2_PIX_FMT_VP9 ++ { ++ .name = "VP9", ++ .fourcc = V4L2_PIX_FMT_VP9, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_VP9 */ ++#ifdef V4L2_PIX_FMT_HEVC ++ { ++ .name = "HEVC", ++ .fourcc = V4L2_PIX_FMT_HEVC, ++ .depth = 32, ++ .flags = FORMAT_FLAGS_COMPRESSED, ++ }, ++#endif /* V4L2_PIX_FMT_HEVC */ ++}; +diff --git a/drivers/pci/controller/Makefile b/drivers/pci/controller/Makefile +index f2b19e6174af..4fef4b174321 100644 +--- a/drivers/pci/controller/Makefile ++++ b/drivers/pci/controller/Makefile +@@ -1,4 +1,10 @@ + # SPDX-License-Identifier: GPL-2.0 ++ifdef CONFIG_X86_64 ++ifdef CONFIG_SATA_AHCI ++obj-y += intel-nvme-remap.o ++endif ++endif ++ + obj-$(CONFIG_PCIE_CADENCE) += cadence/ + obj-$(CONFIG_PCI_FTPCI100) += pci-ftpci100.o + obj-$(CONFIG_PCI_IXP4XX) += pci-ixp4xx.o +diff --git a/drivers/pci/controller/intel-nvme-remap.c b/drivers/pci/controller/intel-nvme-remap.c +new file mode 100644 +index 000000000000..e105e6f5cc91 +--- /dev/null ++++ b/drivers/pci/controller/intel-nvme-remap.c +@@ -0,0 +1,462 @@ ++// SPDX-License-Identifier: GPL-2.0 ++/* ++ * Intel remapped NVMe device support. ++ * ++ * Copyright (c) 2019 Endless Mobile, Inc. ++ * Author: Daniel Drake ++ * ++ * Some products ship by default with the SATA controller in "RAID" or ++ * "Intel RST Premium With Intel Optane System Acceleration" mode. Under this ++ * mode, which we refer to as "remapped NVMe" mode, any installed NVMe ++ * devices disappear from the PCI bus, and instead their I/O memory becomes ++ * available within the AHCI device BARs. ++ * ++ * This scheme is understood to be a way of avoiding usage of the standard ++ * Windows NVMe driver under that OS, instead mandating usage of Intel's ++ * driver instead, which has better power management, and presumably offers ++ * some RAID/disk-caching solutions too. 
++ * ++ * Here in this driver, we support the remapped NVMe mode by claiming the ++ * AHCI device and creating a fake PCIe root port. On the new bus, the ++ * original AHCI device is exposed with only minor tweaks. Then, fake PCI ++ * devices corresponding to the remapped NVMe devices are created. The usual ++ * ahci and nvme drivers are then expected to bind to these devices and ++ * operate as normal. ++ * ++ * The PCI configuration space for the NVMe devices is completely ++ * unavailable, so we fake a minimal one and hope for the best. ++ * ++ * Interrupts are shared between the AHCI and NVMe devices. For simplicity, ++ * we only support the legacy interrupt here, although MSI support ++ * could potentially be added later. ++ */ ++ ++#define MODULE_NAME "intel-nvme-remap" ++ ++#include ++#include ++#include ++#include ++#include ++ ++#define AHCI_PCI_BAR_STANDARD 5 ++ ++struct nvme_remap_dev { ++ struct pci_dev *dev; /* AHCI device */ ++ struct pci_bus *bus; /* our fake PCI bus */ ++ struct pci_sysdata sysdata; ++ int irq_base; /* our fake interrupts */ ++ ++ /* ++ * When we detect an all-ones write to a BAR register, this flag ++ * is set, so that we return the BAR size on the next read (a ++ * standard PCI behaviour). ++ * This includes the assumption that an all-ones BAR write is ++ * immediately followed by a read of the same register. ++ */ ++ bool bar_sizing; ++ ++ /* ++ * Resources copied from the AHCI device, to be regarded as ++ * resources on our fake bus. ++ */ ++ struct resource ahci_resources[PCI_NUM_RESOURCES]; ++ ++ /* Resources corresponding to the NVMe devices. */ ++ struct resource remapped_dev_mem[AHCI_MAX_REMAP]; ++ ++ /* Number of remapped NVMe devices found. 
*/ ++ int num_remapped_devices; ++}; ++ ++static inline struct nvme_remap_dev *nrdev_from_bus(struct pci_bus *bus) ++{ ++ return container_of(bus->sysdata, struct nvme_remap_dev, sysdata); ++} ++ ++ ++/******** PCI configuration space **********/ ++ ++/* ++ * Helper macros for tweaking returned contents of PCI configuration space. ++ * ++ * value contains len bytes of data read from reg. ++ * If fixup_reg is included in that range, fix up the contents of that ++ * register to fixed_value. ++ */ ++#define NR_FIX8(fixup_reg, fixed_value) do { \ ++ if (reg <= fixup_reg && fixup_reg < reg + len) \ ++ ((u8 *) value)[fixup_reg - reg] = (u8) (fixed_value); \ ++ } while (0) ++ ++#define NR_FIX16(fixup_reg, fixed_value) do { \ ++ NR_FIX8(fixup_reg, fixed_value); \ ++ NR_FIX8(fixup_reg + 1, fixed_value >> 8); \ ++ } while (0) ++ ++#define NR_FIX24(fixup_reg, fixed_value) do { \ ++ NR_FIX8(fixup_reg, fixed_value); \ ++ NR_FIX8(fixup_reg + 1, fixed_value >> 8); \ ++ NR_FIX8(fixup_reg + 2, fixed_value >> 16); \ ++ } while (0) ++ ++#define NR_FIX32(fixup_reg, fixed_value) do { \ ++ NR_FIX16(fixup_reg, (u16) fixed_value); \ ++ NR_FIX16(fixup_reg + 2, fixed_value >> 16); \ ++ } while (0) ++ ++/* ++ * Read PCI config space of the slot 0 (AHCI) device. ++ * We pass through the read request to the underlying device, but ++ * tweak the results in some cases. ++ */ ++static int nvme_remap_pci_read_slot0(struct pci_bus *bus, int reg, ++ int len, u32 *value) ++{ ++ struct nvme_remap_dev *nrdev = nrdev_from_bus(bus); ++ struct pci_bus *ahci_dev_bus = nrdev->dev->bus; ++ int ret; ++ ++ ret = ahci_dev_bus->ops->read(ahci_dev_bus, nrdev->dev->devfn, ++ reg, len, value); ++ if (ret) ++ return ret; ++ ++ /* ++ * Adjust the device class, to prevent this driver from attempting to ++ * additionally probe the device we're simulating here. 
++ */ ++ NR_FIX24(PCI_CLASS_PROG, PCI_CLASS_STORAGE_SATA_AHCI); ++ ++ /* ++ * Unset interrupt pin, otherwise ACPI tries to find routing ++ * info for our virtual IRQ, fails, and complains. ++ */ ++ NR_FIX8(PCI_INTERRUPT_PIN, 0); ++ ++ /* ++ * Truncate the AHCI BAR to not include the region that covers the ++ * hidden devices. This will cause the ahci driver to successfully ++ * probe the new device (instead of handing it over to this driver). ++ */ ++ if (nrdev->bar_sizing) { ++ NR_FIX32(PCI_BASE_ADDRESS_5, ~(SZ_16K - 1)); ++ nrdev->bar_sizing = false; ++ } ++ ++ return PCIBIOS_SUCCESSFUL; ++} ++ ++/* ++ * Read PCI config space of a remapped device. ++ * Since the original PCI config space is inaccessible, we provide a minimal, ++ * fake config space instead. ++ */ ++static int nvme_remap_pci_read_remapped(struct pci_bus *bus, unsigned int port, ++ int reg, int len, u32 *value) ++{ ++ struct nvme_remap_dev *nrdev = nrdev_from_bus(bus); ++ struct resource *remapped_mem; ++ ++ if (port > nrdev->num_remapped_devices) ++ return PCIBIOS_DEVICE_NOT_FOUND; ++ ++ *value = 0; ++ remapped_mem = &nrdev->remapped_dev_mem[port - 1]; /* port is the 1-based slot; slot 0 is the AHCI device */ ++ ++ /* Set a Vendor ID, otherwise Linux assumes no device is present */ ++ NR_FIX16(PCI_VENDOR_ID, PCI_VENDOR_ID_INTEL); ++ ++ /* Always report memory access and bus mastering as enabled */ ++ NR_FIX16(PCI_COMMAND, PCI_COMMAND_MEMORY | PCI_COMMAND_MASTER); ++ ++ /* Set class so that nvme driver probes us */ ++ NR_FIX24(PCI_CLASS_PROG, PCI_CLASS_STORAGE_EXPRESS); ++ ++ if (nrdev->bar_sizing) { ++ NR_FIX32(PCI_BASE_ADDRESS_0, ++ ~(resource_size(remapped_mem) - 1)); ++ nrdev->bar_sizing = false; ++ } else { ++ resource_size_t mem_start = remapped_mem->start; ++ ++ mem_start |= PCI_BASE_ADDRESS_MEM_TYPE_64; ++ NR_FIX32(PCI_BASE_ADDRESS_0, mem_start); ++ mem_start >>= 32; ++ NR_FIX32(PCI_BASE_ADDRESS_1, mem_start); ++ } ++ ++ return PCIBIOS_SUCCESSFUL; ++} ++ ++/* Read PCI configuration space.
*/ ++static int nvme_remap_pci_read(struct pci_bus *bus, unsigned int devfn, ++ int reg, int len, u32 *value) ++{ ++ if (PCI_SLOT(devfn) == 0) ++ return nvme_remap_pci_read_slot0(bus, reg, len, value); ++ else ++ return nvme_remap_pci_read_remapped(bus, PCI_SLOT(devfn), ++ reg, len, value); ++} ++ ++/* ++ * Write PCI config space of the slot 0 (AHCI) device. ++ * Apart from the special case of BAR sizing, we disable all writes. ++ * Otherwise, the ahci driver could make changes (e.g. unset PCI bus master) ++ * that would affect the operation of the NVMe devices. ++ */ ++static int nvme_remap_pci_write_slot0(struct pci_bus *bus, int reg, ++ int len, u32 value) ++{ ++ struct nvme_remap_dev *nrdev = nrdev_from_bus(bus); ++ struct pci_bus *ahci_dev_bus = nrdev->dev->bus; ++ ++ if (reg >= PCI_BASE_ADDRESS_0 && reg <= PCI_BASE_ADDRESS_5) { ++ /* ++ * Writing all-ones to a BAR means that the size of the ++ * memory region is being checked. Flag this so that we can ++ * reply with an appropriate size on the next read. ++ */ ++ if (value == ~0) ++ nrdev->bar_sizing = true; ++ ++ return ahci_dev_bus->ops->write(ahci_dev_bus, ++ nrdev->dev->devfn, ++ reg, len, value); ++ } ++ ++ return PCIBIOS_SET_FAILED; ++} ++ ++/* ++ * Write PCI config space of a remapped device. ++ * Since the original PCI config space is inaccessible, we reject all ++ * writes, except for the special case of BAR probing. ++ */ ++static int nvme_remap_pci_write_remapped(struct pci_bus *bus, ++ unsigned int port, ++ int reg, int len, u32 value) ++{ ++ struct nvme_remap_dev *nrdev = nrdev_from_bus(bus); ++ ++ if (port > nrdev->num_remapped_devices) ++ return PCIBIOS_DEVICE_NOT_FOUND; ++ ++ /* ++ * Writing all-ones to a BAR means that the size of the memory ++ * region is being checked. Flag this so that we can reply with ++ * an appropriate size on the next read. 
++ */ ++ if (value == ~0 && reg >= PCI_BASE_ADDRESS_0 ++ && reg <= PCI_BASE_ADDRESS_5) { ++ nrdev->bar_sizing = true; ++ return PCIBIOS_SUCCESSFUL; ++ } ++ ++ return PCIBIOS_SET_FAILED; ++} ++ ++/* Write PCI configuration space. */ ++static int nvme_remap_pci_write(struct pci_bus *bus, unsigned int devfn, ++ int reg, int len, u32 value) ++{ ++ if (PCI_SLOT(devfn) == 0) ++ return nvme_remap_pci_write_slot0(bus, reg, len, value); ++ else ++ return nvme_remap_pci_write_remapped(bus, PCI_SLOT(devfn), ++ reg, len, value); ++} ++ ++static struct pci_ops nvme_remap_pci_ops = { ++ .read = nvme_remap_pci_read, ++ .write = nvme_remap_pci_write, ++}; ++ ++ ++/******** Initialization & exit **********/ ++ ++/* ++ * Find a PCI domain ID to use for our fake bus. ++ * Start at 0x10000 to not clash with ACPI _SEG domains (16 bits). ++ */ ++static int find_free_domain(void) ++{ ++ int domain = 0xffff; ++ struct pci_bus *bus = NULL; ++ ++ while ((bus = pci_find_next_bus(bus)) != NULL) ++ domain = max_t(int, domain, pci_domain_nr(bus)); ++ ++ return domain + 1; ++} ++ ++static int find_remapped_devices(struct nvme_remap_dev *nrdev, ++ struct list_head *resources) ++{ ++ void __iomem *mmio; ++ int i, count = 0; ++ u32 cap; ++ ++ mmio = pcim_iomap(nrdev->dev, AHCI_PCI_BAR_STANDARD, ++ pci_resource_len(nrdev->dev, ++ AHCI_PCI_BAR_STANDARD)); ++ if (!mmio) ++ return -ENODEV; ++ ++ /* Check if this device might have remapped nvme devices. 
*/ ++ if (pci_resource_len(nrdev->dev, AHCI_PCI_BAR_STANDARD) < SZ_512K || ++ !(readl(mmio + AHCI_VSCAP) & 1)) ++ return -ENODEV; ++ ++ cap = readq(mmio + AHCI_REMAP_CAP); ++ for (i = AHCI_MAX_REMAP-1; i >= 0; i--) { ++ struct resource *remapped_mem; ++ ++ if ((cap & (1 << i)) == 0) ++ continue; ++ if (readl(mmio + ahci_remap_dcc(i)) ++ != PCI_CLASS_STORAGE_EXPRESS) ++ continue; ++ ++ /* We've found a remapped device */ ++ remapped_mem = &nrdev->remapped_dev_mem[count++]; ++ remapped_mem->start = ++ pci_resource_start(nrdev->dev, AHCI_PCI_BAR_STANDARD) ++ + ahci_remap_base(i); ++ remapped_mem->end = remapped_mem->start ++ + AHCI_REMAP_N_SIZE - 1; ++ remapped_mem->flags = IORESOURCE_MEM | IORESOURCE_PCI_FIXED; ++ pci_add_resource(resources, remapped_mem); ++ } ++ ++ pcim_iounmap(nrdev->dev, mmio); ++ ++ if (count == 0) ++ return -ENODEV; ++ ++ nrdev->num_remapped_devices = count; ++ dev_info(&nrdev->dev->dev, "Found %d remapped NVMe devices\n", ++ nrdev->num_remapped_devices); ++ return 0; ++} ++ ++static void nvme_remap_remove_root_bus(void *data) ++{ ++ struct pci_bus *bus = data; ++ ++ pci_stop_root_bus(bus); ++ pci_remove_root_bus(bus); ++} ++ ++static int nvme_remap_probe(struct pci_dev *dev, const struct pci_device_id *id) ++{ ++ struct nvme_remap_dev *nrdev; ++ LIST_HEAD(resources); ++ int i, ret; ++ struct pci_dev *child; ++ ++ nrdev = devm_kzalloc(&dev->dev, sizeof(*nrdev), GFP_KERNEL); ++ if (!nrdev) /* allocation failed: bail out before nrdev is dereferenced */ ++ return -ENOMEM; ++ nrdev->sysdata.domain = find_free_domain(); ++ nrdev->sysdata.nvme_remap_dev = dev; ++ nrdev->dev = dev; ++ pci_set_drvdata(dev, nrdev); ++ ++ ret = pcim_enable_device(dev); ++ if (ret < 0) ++ return ret; ++ ++ pci_set_master(dev); ++ ++ ret = find_remapped_devices(nrdev, &resources); ++ if (ret) ++ return ret; ++ ++ /* Add resources from the original AHCI device */ ++ for (i = 0; i < PCI_NUM_RESOURCES; i++) { ++ struct resource *res = &dev->resource[i]; ++ ++ if (res->start) { ++ struct resource *nr_res = &nrdev->ahci_resources[i]; ++ ++ nr_res->start =
res->start; ++ nr_res->end = res->end; ++ nr_res->flags = res->flags; ++ pci_add_resource(&resources, nr_res); ++ } ++ } ++ ++ /* Create virtual interrupts */ ++ nrdev->irq_base = devm_irq_alloc_descs(&dev->dev, -1, 0, ++ nrdev->num_remapped_devices + 1, ++ 0); ++ if (nrdev->irq_base < 0) ++ return nrdev->irq_base; ++ ++ /* Create and populate PCI bus */ ++ nrdev->bus = pci_create_root_bus(&dev->dev, 0, &nvme_remap_pci_ops, ++ &nrdev->sysdata, &resources); ++ if (!nrdev->bus) ++ return -ENODEV; ++ ++ if (devm_add_action_or_reset(&dev->dev, nvme_remap_remove_root_bus, ++ nrdev->bus)) ++ return -ENOMEM; ++ ++ /* We don't support sharing MSI interrupts between these devices */ ++ nrdev->bus->bus_flags |= PCI_BUS_FLAGS_NO_MSI; ++ ++ pci_scan_child_bus(nrdev->bus); ++ ++ list_for_each_entry(child, &nrdev->bus->devices, bus_list) { ++ /* ++ * Prevent PCI core from trying to move memory BARs around. ++ * The hidden NVMe devices are at fixed locations. ++ */ ++ for (i = 0; i < PCI_NUM_RESOURCES; i++) { ++ struct resource *res = &child->resource[i]; ++ ++ if (res->flags & IORESOURCE_MEM) ++ res->flags |= IORESOURCE_PCI_FIXED; ++ } ++ ++ /* Share the legacy IRQ between all devices */ ++ child->irq = dev->irq; ++ } ++ ++ pci_assign_unassigned_bus_resources(nrdev->bus); ++ pci_bus_add_devices(nrdev->bus); ++ ++ return 0; ++} ++ ++static const struct pci_device_id nvme_remap_ids[] = { ++ /* ++ * Match all Intel RAID controllers. ++ * ++ * There's overlap here with the set of devices detected by the ahci ++ * driver, but ahci will only successfully probe when there ++ * *aren't* any remapped NVMe devices, and this driver will only ++ * successfully probe when there *are* remapped NVMe devices that ++ * need handling. 
++ */ ++ { ++ PCI_VDEVICE(INTEL, PCI_ANY_ID), ++ .class = PCI_CLASS_STORAGE_RAID << 8, ++ .class_mask = 0xffffff00, ++ }, ++ {0,} ++}; ++MODULE_DEVICE_TABLE(pci, nvme_remap_ids); ++ ++static struct pci_driver nvme_remap_drv = { ++ .name = MODULE_NAME, ++ .id_table = nvme_remap_ids, ++ .probe = nvme_remap_probe, ++}; ++module_pci_driver(nvme_remap_drv); ++ ++MODULE_AUTHOR("Daniel Drake "); ++MODULE_LICENSE("GPL v2"); +diff --git a/drivers/pci/quirks.c b/drivers/pci/quirks.c +index 568410e64ce6..192d0557fb05 100644 +--- a/drivers/pci/quirks.c ++++ b/drivers/pci/quirks.c +@@ -3732,6 +3732,106 @@ static void quirk_no_bus_reset(struct pci_dev *dev) + dev->dev_flags |= PCI_DEV_FLAGS_NO_BUS_RESET; + } + ++static bool acs_on_downstream; ++static bool acs_on_multifunction; ++ ++#define NUM_ACS_IDS 16 ++struct acs_on_id { ++ unsigned short vendor; ++ unsigned short device; ++}; ++static struct acs_on_id acs_on_ids[NUM_ACS_IDS]; ++static u8 max_acs_id; ++ ++static __init int pcie_acs_override_setup(char *p) ++{ ++ if (!p) ++ return -EINVAL; ++ ++ while (*p) { ++ if (!strncmp(p, "downstream", 10)) ++ acs_on_downstream = true; ++ if (!strncmp(p, "multifunction", 13)) ++ acs_on_multifunction = true; ++ if (!strncmp(p, "id:", 3)) { ++ char opt[5]; ++ int ret; ++ long val; ++ ++ if (max_acs_id >= NUM_ACS_IDS - 1) { ++ pr_warn("Out of PCIe ACS override slots (%d)\n", ++ NUM_ACS_IDS); ++ goto next; ++ } ++ ++ p += 3; ++ snprintf(opt, 5, "%s", p); ++ ret = kstrtol(opt, 16, &val); ++ if (ret) { ++ pr_warn("PCIe ACS ID parse error %d\n", ret); ++ goto next; ++ } ++ acs_on_ids[max_acs_id].vendor = val; ++ ++ p += strcspn(p, ":"); ++ if (*p != ':') { ++ pr_warn("PCIe ACS invalid ID\n"); ++ goto next; ++ } ++ ++ p++; ++ snprintf(opt, 5, "%s", p); ++ ret = kstrtol(opt, 16, &val); ++ if (ret) { ++ pr_warn("PCIe ACS ID parse error %d\n", ret); ++ goto next; ++ } ++ acs_on_ids[max_acs_id].device = val; ++ max_acs_id++; ++ } ++next: ++ p += strcspn(p, ","); ++ if (*p == ',') ++ p++; ++ } ++ ++ 
if (acs_on_downstream || acs_on_multifunction || max_acs_id) ++ pr_warn("Warning: PCIe ACS overrides enabled; This may allow non-IOMMU protected peer-to-peer DMA\n"); ++ ++ return 0; ++} ++early_param("pcie_acs_override", pcie_acs_override_setup); ++ ++static int pcie_acs_overrides(struct pci_dev *dev, u16 acs_flags) ++{ ++ int i; ++ ++ /* Never override ACS for legacy devices or devices with ACS caps */ ++ if (!pci_is_pcie(dev) || ++ pci_find_ext_capability(dev, PCI_EXT_CAP_ID_ACS)) ++ return -ENOTTY; ++ ++ for (i = 0; i < max_acs_id; i++) ++ if (acs_on_ids[i].vendor == dev->vendor && ++ acs_on_ids[i].device == dev->device) ++ return 1; ++ ++ switch (pci_pcie_type(dev)) { ++ case PCI_EXP_TYPE_DOWNSTREAM: ++ case PCI_EXP_TYPE_ROOT_PORT: ++ if (acs_on_downstream) ++ return 1; ++ break; ++ case PCI_EXP_TYPE_ENDPOINT: ++ case PCI_EXP_TYPE_UPSTREAM: ++ case PCI_EXP_TYPE_LEG_END: ++ case PCI_EXP_TYPE_RC_END: ++ if (acs_on_multifunction && dev->multifunction) ++ return 1; ++ } ++ ++ return -ENOTTY; ++} + /* + * Some NVIDIA GPU devices do not work with bus reset, SBR needs to be + * prevented for those affected devices. +@@ -5143,6 +5243,7 @@ static const struct pci_dev_acs_enabled { + { PCI_VENDOR_ID_ZHAOXIN, PCI_ANY_ID, pci_quirk_zhaoxin_pcie_ports_acs }, + /* Wangxun nics */ + { PCI_VENDOR_ID_WANGXUN, PCI_ANY_ID, pci_quirk_wangxun_nic_acs }, ++ { PCI_ANY_ID, PCI_ANY_ID, pcie_acs_overrides }, + { 0 } + }; + +diff --git a/include/linux/cpufreq.h b/include/linux/cpufreq.h +index 20f7e98ee8af..0f5aad20ced7 100644 +--- a/include/linux/cpufreq.h ++++ b/include/linux/cpufreq.h +@@ -577,12 +577,6 @@ static inline unsigned long cpufreq_scale(unsigned long old, u_int div, + #define CPUFREQ_POLICY_POWERSAVE (1) + #define CPUFREQ_POLICY_PERFORMANCE (2) + +-/* +- * The polling frequency depends on the capability of the processor. Default +- * polling frequency is 1000 times the transition latency of the processor. 
+- */ +-#define LATENCY_MULTIPLIER (1000) +- + struct cpufreq_governor { + char name[CPUFREQ_NAME_LEN]; + int (*init)(struct cpufreq_policy *policy); +diff --git a/include/linux/minmax.h b/include/linux/minmax.h +index 2ec559284a9f..a7ef65f78933 100644 +--- a/include/linux/minmax.h ++++ b/include/linux/minmax.h +@@ -270,4 +270,11 @@ static inline bool in_range32(u32 val, u32 start, u32 len) + #define swap(a, b) \ + do { typeof(a) __tmp = (a); (a) = (b); (b) = __tmp; } while (0) + ++/* ++ * Use these carefully: no type checking, and uses the arguments ++ * multiple times. Use for obvious constants only. ++ */ ++#define MIN_T(type,a,b) __cmp(min,(type)(a),(type)(b)) ++#define MAX_T(type,a,b) __cmp(max,(type)(a),(type)(b)) ++ + #endif /* _LINUX_MINMAX_H */ +diff --git a/include/linux/pageblock-flags.h b/include/linux/pageblock-flags.h +index 547e82cdc89a..fc6b9c87cb0a 100644 +--- a/include/linux/pageblock-flags.h ++++ b/include/linux/pageblock-flags.h +@@ -41,13 +41,13 @@ extern unsigned int pageblock_order; + * Huge pages are a constant size, but don't exceed the maximum allocation + * granularity. 
+ */ +-#define pageblock_order min_t(unsigned int, HUGETLB_PAGE_ORDER, MAX_PAGE_ORDER) ++#define pageblock_order MIN_T(unsigned int, HUGETLB_PAGE_ORDER, MAX_PAGE_ORDER) + + #endif /* CONFIG_HUGETLB_PAGE_SIZE_VARIABLE */ + + #elif defined(CONFIG_TRANSPARENT_HUGEPAGE) + +-#define pageblock_order min_t(unsigned int, HPAGE_PMD_ORDER, MAX_PAGE_ORDER) ++#define pageblock_order MIN_T(unsigned int, HPAGE_PMD_ORDER, MAX_PAGE_ORDER) + + #else /* CONFIG_TRANSPARENT_HUGEPAGE */ + +diff --git a/include/linux/pagemap.h b/include/linux/pagemap.h +index a0a026d2d244..8bece21a8998 100644 +--- a/include/linux/pagemap.h ++++ b/include/linux/pagemap.h +@@ -1281,7 +1281,7 @@ struct readahead_control { + ._index = i, \ + } + +-#define VM_READAHEAD_PAGES (SZ_128K / PAGE_SIZE) ++#define VM_READAHEAD_PAGES (SZ_8M / PAGE_SIZE) + + void page_cache_ra_unbounded(struct readahead_control *, + unsigned long nr_to_read, unsigned long lookahead_count); +diff --git a/include/linux/user_namespace.h b/include/linux/user_namespace.h +index 6030a8235617..60b7fe5fa74a 100644 +--- a/include/linux/user_namespace.h ++++ b/include/linux/user_namespace.h +@@ -156,6 +156,8 @@ static inline void set_userns_rlimit_max(struct user_namespace *ns, + + #ifdef CONFIG_USER_NS + ++extern int unprivileged_userns_clone; ++ + static inline struct user_namespace *get_user_ns(struct user_namespace *ns) + { + if (ns) +@@ -189,6 +191,8 @@ extern bool current_in_userns(const struct user_namespace *target_ns); + struct ns_common *ns_get_owner(struct ns_common *ns); + #else + ++#define unprivileged_userns_clone 0 ++ + static inline struct user_namespace *get_user_ns(struct user_namespace *ns) + { + return &init_user_ns; +diff --git a/init/Kconfig b/init/Kconfig +index d8a971b804d3..bfc033b53242 100644 +--- a/init/Kconfig ++++ b/init/Kconfig +@@ -132,6 +132,10 @@ config THREAD_INFO_IN_TASK + + menu "General setup" + ++config CACHY ++ bool "Some kernel tweaks by CachyOS" ++ default y ++ + config BROKEN + bool + +@@ -1251,6 
+1255,22 @@ config USER_NS + + If unsure, say N. + ++config USER_NS_UNPRIVILEGED ++ bool "Allow unprivileged users to create namespaces" ++ default y ++ depends on USER_NS ++ help ++ When disabled, unprivileged users will not be able to create ++ new namespaces. Allowing users to create their own namespaces ++ has been part of several recent local privilege escalation ++ exploits, so if you need user namespaces but are ++ paranoid^Wsecurity-conscious you want to disable this. ++ ++ This setting can be overridden at runtime via the ++ kernel.unprivileged_userns_clone sysctl. ++ ++ If unsure, say Y. ++ + config PID_NS + bool "PID Namespaces" + default y +@@ -1393,6 +1413,12 @@ config CC_OPTIMIZE_FOR_PERFORMANCE + with the "-O2" compiler flag for best performance and most + helpful compile-time warnings. + ++config CC_OPTIMIZE_FOR_PERFORMANCE_O3 ++ bool "Optimize more for performance (-O3)" ++ help ++ Choosing this option will pass "-O3" to your compiler to optimize ++ the kernel yet more for performance. ++ + config CC_OPTIMIZE_FOR_SIZE + bool "Optimize for size (-Os)" + help +diff --git a/kernel/Kconfig.hz b/kernel/Kconfig.hz +index 38ef6d06888e..0f78364efd4f 100644 +--- a/kernel/Kconfig.hz ++++ b/kernel/Kconfig.hz +@@ -40,6 +40,27 @@ choice + on SMP and NUMA systems and exactly dividing by both PAL and + NTSC frame rates for video and multimedia work. + ++ config HZ_500 ++ bool "500 HZ" ++ help ++ 500 Hz is a balanced timer frequency. Provides fast interactivity ++ on desktops with good smoothness without increasing CPU power ++ consumption and sacrificing the battery life on laptops. ++ ++ config HZ_600 ++ bool "600 HZ" ++ help ++ 600 Hz is a balanced timer frequency. Provides fast interactivity ++ on desktops with good smoothness without increasing CPU power ++ consumption and sacrificing the battery life on laptops. ++ ++ config HZ_750 ++ bool "750 HZ" ++ help ++ 750 Hz is a balanced timer frequency. 
Provides fast interactivity ++ on desktops with good smoothness without increasing CPU power ++ consumption and sacrificing the battery life on laptops. ++ + config HZ_1000 + bool "1000 HZ" + help +@@ -53,6 +74,9 @@ config HZ + default 100 if HZ_100 + default 250 if HZ_250 + default 300 if HZ_300 ++ default 500 if HZ_500 ++ default 600 if HZ_600 ++ default 750 if HZ_750 + default 1000 if HZ_1000 + + config SCHED_HRTICK +diff --git a/kernel/fork.c b/kernel/fork.c +index 99076dbe27d8..18750b83c564 100644 +--- a/kernel/fork.c ++++ b/kernel/fork.c +@@ -104,6 +104,10 @@ + #include + #include + ++#ifdef CONFIG_USER_NS ++#include <linux/user_namespace.h> ++#endif ++ + #include + #include + #include +@@ -2154,6 +2158,10 @@ __latent_entropy struct task_struct *copy_process( + if ((clone_flags & (CLONE_NEWUSER|CLONE_FS)) == (CLONE_NEWUSER|CLONE_FS)) + return ERR_PTR(-EINVAL); + ++ if ((clone_flags & CLONE_NEWUSER) && !unprivileged_userns_clone) ++ if (!capable(CAP_SYS_ADMIN)) ++ return ERR_PTR(-EPERM); ++ + /* + * Thread groups must share signals as well, and detached threads + * can only be started up within the thread group.
+@@ -3301,6 +3309,12 @@ int ksys_unshare(unsigned long unshare_flags) + if (unshare_flags & CLONE_NEWNS) + unshare_flags |= CLONE_FS; + ++ if ((unshare_flags & CLONE_NEWUSER) && !unprivileged_userns_clone) { ++ err = -EPERM; ++ if (!capable(CAP_SYS_ADMIN)) ++ goto bad_unshare_out; ++ } ++ + err = check_unshare_flags(unshare_flags); + if (err) + goto bad_unshare_out; +diff --git a/kernel/sched/fair.c b/kernel/sched/fair.c +index 483c137b9d3d..1fee282d40aa 100644 +--- a/kernel/sched/fair.c ++++ b/kernel/sched/fair.c +@@ -73,10 +73,19 @@ unsigned int sysctl_sched_tunable_scaling = SCHED_TUNABLESCALING_LOG; + * + * (default: 0.75 msec * (1 + ilog(ncpus)), units: nanoseconds) + */ ++#ifdef CONFIG_CACHY ++unsigned int sysctl_sched_base_slice = 350000ULL; ++static unsigned int normalized_sysctl_sched_base_slice = 350000ULL; ++#else + unsigned int sysctl_sched_base_slice = 750000ULL; + static unsigned int normalized_sysctl_sched_base_slice = 750000ULL; ++#endif + ++#ifdef CONFIG_CACHY ++const_debug unsigned int sysctl_sched_migration_cost = 300000UL; ++#else + const_debug unsigned int sysctl_sched_migration_cost = 500000UL; ++#endif + + static int __init setup_sched_thermal_decay_shift(char *str) + { +@@ -121,8 +130,12 @@ int __weak arch_asym_cpu_priority(int cpu) + * + * (default: 5 msec, units: microseconds) + */ ++#ifdef CONFIG_CACHY ++static unsigned int sysctl_sched_cfs_bandwidth_slice = 3000UL; ++#else + static unsigned int sysctl_sched_cfs_bandwidth_slice = 5000UL; + #endif ++#endif + + #ifdef CONFIG_NUMA_BALANCING + /* Restrict the NUMA promotion throughput (MB/s) for each target node. 
*/ +diff --git a/kernel/sched/sched.h b/kernel/sched/sched.h +index 38aeedd8a6cc..556466836cd5 100644 +--- a/kernel/sched/sched.h ++++ b/kernel/sched/sched.h +@@ -2544,7 +2544,7 @@ extern void deactivate_task(struct rq *rq, struct task_struct *p, int flags); + + extern void wakeup_preempt(struct rq *rq, struct task_struct *p, int flags); + +-#ifdef CONFIG_PREEMPT_RT ++#if defined(CONFIG_PREEMPT_RT) || defined(CONFIG_CACHY) + #define SCHED_NR_MIGRATE_BREAK 8 + #else + #define SCHED_NR_MIGRATE_BREAK 32 +diff --git a/kernel/sysctl.c b/kernel/sysctl.c +index e0b917328cf9..e70ae9c11dea 100644 +--- a/kernel/sysctl.c ++++ b/kernel/sysctl.c +@@ -80,6 +80,9 @@ + #ifdef CONFIG_RT_MUTEXES + #include <linux/rtmutex.h> + #endif ++#ifdef CONFIG_USER_NS ++#include <linux/user_namespace.h> ++#endif + + /* shared constants to be used in various sysctls */ + const int sysctl_vals[] = { 0, 1, 2, 3, 4, 100, 200, 1000, 3000, INT_MAX, 65535, -1 }; +@@ -1623,6 +1626,15 @@ static struct ctl_table kern_table[] = { + .mode = 0644, + .proc_handler = proc_dointvec, + }, ++#ifdef CONFIG_USER_NS ++ { ++ .procname = "unprivileged_userns_clone", ++ .data = &unprivileged_userns_clone, ++ .maxlen = sizeof(int), ++ .mode = 0644, ++ .proc_handler = proc_dointvec, ++ }, ++#endif + #ifdef CONFIG_PROC_SYSCTL + { + .procname = "tainted", +diff --git a/kernel/user_namespace.c b/kernel/user_namespace.c +index 0b0b95418b16..c4b835b91fc0 100644 +--- a/kernel/user_namespace.c ++++ b/kernel/user_namespace.c +@@ -22,6 +22,13 @@ + #include + #include + ++/* sysctl */ ++#ifdef CONFIG_USER_NS_UNPRIVILEGED ++int unprivileged_userns_clone = 1; ++#else ++int unprivileged_userns_clone; ++#endif ++ + static struct kmem_cache *user_ns_cachep __ro_after_init; + static DEFINE_MUTEX(userns_state_mutex); + +diff --git a/mm/Kconfig b/mm/Kconfig +index b4cb45255a54..8635b3b24739 100644 +--- a/mm/Kconfig ++++ b/mm/Kconfig +@@ -613,7 +613,7 @@ config COMPACTION + config COMPACT_UNEVICTABLE_DEFAULT + int + depends on COMPACTION +- default 0 if PREEMPT_RT ++ default 0 if
PREEMPT_RT || CACHY + default 1 + + # +diff --git a/mm/compaction.c b/mm/compaction.c +index 739b1bf3d637..3a4269c02fb2 100644 +--- a/mm/compaction.c ++++ b/mm/compaction.c +@@ -1950,7 +1950,11 @@ static int sysctl_compact_unevictable_allowed __read_mostly = CONFIG_COMPACT_UNE + * aggressively the kernel should compact memory in the + * background. It takes values in the range [0, 100]. + */ ++#ifdef CONFIG_CACHY ++static unsigned int __read_mostly sysctl_compaction_proactiveness; ++#else + static unsigned int __read_mostly sysctl_compaction_proactiveness = 20; ++#endif + static int sysctl_extfrag_threshold = 500; + static int __read_mostly sysctl_compact_memory; + +diff --git a/mm/huge_memory.c b/mm/huge_memory.c +index 5f32a196a612..99832eb64739 100644 +--- a/mm/huge_memory.c ++++ b/mm/huge_memory.c +@@ -63,7 +63,11 @@ unsigned long transparent_hugepage_flags __read_mostly = + #ifdef CONFIG_TRANSPARENT_HUGEPAGE_MADVISE + (1<> (20 - PAGE_SHIFT); + + /* Use a smaller cluster for small-memory machines */ +@@ -1122,4 +1126,5 @@ void __init swap_setup(void) + * Right now other parts of the system means that we + * _really_ don't want to cluster much more + */ ++#endif + } +diff --git a/mm/vmpressure.c b/mm/vmpressure.c +index bd5183dfd879..3a410f53a07c 100644 +--- a/mm/vmpressure.c ++++ b/mm/vmpressure.c +@@ -43,7 +43,11 @@ static const unsigned long vmpressure_win = SWAP_CLUSTER_MAX * 16; + * essence, they are percents: the higher the value, the more number + * unsuccessful reclaims there were. + */ ++#ifdef CONFIG_CACHY ++static const unsigned int vmpressure_level_med = 65; ++#else + static const unsigned int vmpressure_level_med = 60; ++#endif + static const unsigned int vmpressure_level_critical = 95; + + /* +diff --git a/mm/vmscan.c b/mm/vmscan.c +index 68ac33bea3a3..9ede4f0c1c0e 100644 +--- a/mm/vmscan.c ++++ b/mm/vmscan.c +@@ -191,7 +191,11 @@ struct scan_control { + /* + * From 0 .. 200. Higher means more swappy. 
+ */ ++#ifdef CONFIG_CACHY ++int vm_swappiness = 20; ++#else + int vm_swappiness = 60; ++#endif + + #ifdef CONFIG_MEMCG + +@@ -3973,7 +3977,11 @@ static bool lruvec_is_reclaimable(struct lruvec *lruvec, struct scan_control *sc + } + + /* to protect the working set of the last N jiffies */ ++#ifdef CONFIG_CACHY ++static unsigned long lru_gen_min_ttl __read_mostly = 1000; ++#else + static unsigned long lru_gen_min_ttl __read_mostly; ++#endif + + static void lru_gen_age_node(struct pglist_data *pgdat, struct scan_control *sc) + { +diff --git a/scripts/Makefile.package b/scripts/Makefile.package +index bf016af8bf8a..4a80584ec771 100644 +--- a/scripts/Makefile.package ++++ b/scripts/Makefile.package +@@ -141,6 +141,19 @@ snap-pkg: + cd $(objtree)/snap && \ + snapcraft --target-arch=$(UTS_MACHINE) + ++# pacman-pkg ++# --------------------------------------------------------------------------- ++ ++PHONY += pacman-pkg ++pacman-pkg: ++ @ln -srf $(srctree)/scripts/package/PKGBUILD $(objtree)/PKGBUILD ++ +objtree="$(realpath $(objtree))" \ ++ BUILDDIR="$(realpath $(objtree))/pacman" \ ++ CARCH="$(UTS_MACHINE)" \ ++ KBUILD_MAKEFLAGS="$(MAKEFLAGS)" \ ++ KBUILD_REVISION="$(shell $(srctree)/scripts/build-version)" \ ++ makepkg $(MAKEPKGOPTS) ++ + # dir-pkg tar*-pkg - tarball targets + # --------------------------------------------------------------------------- + +@@ -221,6 +234,7 @@ help: + @echo ' bindeb-pkg - Build only the binary kernel deb package' + @echo ' snap-pkg - Build only the binary kernel snap package' + @echo ' (will connect to external hosts)' ++ @echo ' pacman-pkg - Build only the binary kernel pacman package' + @echo ' dir-pkg - Build the kernel as a plain directory structure' + @echo ' tar-pkg - Build the kernel as an uncompressed tarball' + @echo ' targz-pkg - Build the kernel as a gzip compressed tarball' +diff --git a/scripts/package/PKGBUILD b/scripts/package/PKGBUILD +new file mode 100644 +index 000000000000..663ce300dd06 +--- /dev/null ++++ 
b/scripts/package/PKGBUILD +@@ -0,0 +1,108 @@ ++# SPDX-License-Identifier: GPL-2.0-only ++# Maintainer: Thomas Weißschuh ++# Contributor: Jan Alexander Steffens (heftig) ++ ++pkgbase=${PACMAN_PKGBASE:-linux-upstream} ++pkgname=("${pkgbase}" "${pkgbase}-api-headers") ++if grep -q CONFIG_MODULES=y include/config/auto.conf; then ++ pkgname+=("${pkgbase}-headers") ++fi ++pkgver="${KERNELRELEASE//-/_}" ++# The PKGBUILD is evaluated multiple times. ++# Running scripts/build-version from here would introduce inconsistencies. ++pkgrel="${KBUILD_REVISION}" ++pkgdesc='Upstream Linux' ++url='https://www.kernel.org/' ++# Enable flexible cross-compilation ++arch=(${CARCH}) ++license=(GPL-2.0-only) ++makedepends=( ++ bc ++ bison ++ cpio ++ flex ++ gettext ++ kmod ++ libelf ++ openssl ++ pahole ++ perl ++ python ++ rsync ++ tar ++) ++options=(!debug !strip !buildflags !makeflags) ++ ++build() { ++ # MAKEFLAGS from makepkg.conf override the ones inherited from kbuild. ++ # Bypass this override with a custom variable. ++ export MAKEFLAGS="${KBUILD_MAKEFLAGS}" ++ cd "${objtree}" ++ ++ ${MAKE} KERNELRELEASE="${KERNELRELEASE}" KBUILD_BUILD_VERSION="${pkgrel}" ++} ++ ++_package() { ++ pkgdesc="The ${pkgdesc} kernel and modules" ++ ++ export MAKEFLAGS="${KBUILD_MAKEFLAGS}" ++ cd "${objtree}" ++ local modulesdir="${pkgdir}/usr/${MODLIB}" ++ ++ echo "Installing boot image..." ++ # systemd expects to find the kernel here to allow hibernation ++ # https://github.com/systemd/systemd/commit/edda44605f06a41fb86b7ab8128dcf99161d2344 ++ install -Dm644 "$(${MAKE} -s image_name)" "${modulesdir}/vmlinuz" ++ ++ # Used by mkinitcpio to name the kernel ++ echo "${pkgbase}" > "${modulesdir}/pkgbase" ++ ++ echo "Installing modules..." ++ ${MAKE} INSTALL_MOD_PATH="${pkgdir}/usr" INSTALL_MOD_STRIP=1 \ ++ DEPMOD=true modules_install ++ ++ if [ -d "${srctree}/arch/${SRCARCH}/boot/dts" ]; then ++ echo "Installing dtbs..." 
++ ${MAKE} INSTALL_DTBS_PATH="${modulesdir}/dtb" dtbs_install ++ fi ++ ++ # remove build link, will be part of -headers package ++ rm -f "${modulesdir}/build" ++} ++ ++_package-headers() { ++ pkgdesc="Headers and scripts for building modules for the ${pkgdesc} kernel" ++ ++ export MAKEFLAGS="${KBUILD_MAKEFLAGS}" ++ cd "${objtree}" ++ local builddir="${pkgdir}/usr/${MODLIB}/build" ++ ++ echo "Installing build files..." ++ "${srctree}/scripts/package/install-extmod-build" "${builddir}" ++ ++ echo "Installing System.map and config..." ++ cp System.map "${builddir}/System.map" ++ cp .config "${builddir}/.config" ++ ++ echo "Adding symlink..." ++ mkdir -p "${pkgdir}/usr/src" ++ ln -sr "${builddir}" "${pkgdir}/usr/src/${pkgbase}" ++} ++ ++_package-api-headers() { ++ pkgdesc="Kernel headers sanitized for use in userspace" ++ provides=(linux-api-headers) ++ conflicts=(linux-api-headers) ++ ++ export MAKEFLAGS="${KBUILD_MAKEFLAGS}" ++ cd "${objtree}" ++ ++ ${MAKE} headers_install INSTALL_HDR_PATH="${pkgdir}/usr" ++} ++ ++for _p in "${pkgname[@]}"; do ++ eval "package_$_p() { ++ $(declare -f "_package${_p#$pkgbase}") ++ _package${_p#$pkgbase} ++ }" ++done +-- +2.46.0 + +From 7f81e97ab2c94fec90c410ca66e5e5382769cb9f Mon Sep 17 00:00:00 2001 +From: Peter Jung +Date: Wed, 14 Aug 2024 16:19:12 +0200 +Subject: [PATCH 05/12] crypto + +Signed-off-by: Peter Jung +--- + arch/x86/crypto/Kconfig | 1 + + arch/x86/crypto/Makefile | 8 +- + arch/x86/crypto/aes-gcm-aesni-x86_64.S | 1128 +++++++++ + arch/x86/crypto/aes-gcm-avx10-x86_64.S | 1222 ++++++++++ + arch/x86/crypto/aesni-intel_asm.S | 1503 +----------- + arch/x86/crypto/aesni-intel_avx-x86_64.S | 2804 ---------------------- + arch/x86/crypto/aesni-intel_glue.c | 1269 ++++++---- + 7 files changed, 3125 insertions(+), 4810 deletions(-) + create mode 100644 arch/x86/crypto/aes-gcm-aesni-x86_64.S + create mode 100644 arch/x86/crypto/aes-gcm-avx10-x86_64.S + delete mode 100644 arch/x86/crypto/aesni-intel_avx-x86_64.S + +diff --git 
a/arch/x86/crypto/Kconfig b/arch/x86/crypto/Kconfig +index c9e59589a1ce..24875e6295f2 100644 +--- a/arch/x86/crypto/Kconfig ++++ b/arch/x86/crypto/Kconfig +@@ -18,6 +18,7 @@ config CRYPTO_AES_NI_INTEL + depends on X86 + select CRYPTO_AEAD + select CRYPTO_LIB_AES ++ select CRYPTO_LIB_GF128MUL + select CRYPTO_ALGAPI + select CRYPTO_SKCIPHER + select CRYPTO_SIMD +diff --git a/arch/x86/crypto/Makefile b/arch/x86/crypto/Makefile +index 9c5ce5613738..53b4a277809e 100644 +--- a/arch/x86/crypto/Makefile ++++ b/arch/x86/crypto/Makefile +@@ -48,8 +48,12 @@ chacha-x86_64-$(CONFIG_AS_AVX512) += chacha-avx512vl-x86_64.o + + obj-$(CONFIG_CRYPTO_AES_NI_INTEL) += aesni-intel.o + aesni-intel-y := aesni-intel_asm.o aesni-intel_glue.o +-aesni-intel-$(CONFIG_64BIT) += aesni-intel_avx-x86_64.o \ +- aes_ctrby8_avx-x86_64.o aes-xts-avx-x86_64.o ++aesni-intel-$(CONFIG_64BIT) += aes_ctrby8_avx-x86_64.o \ ++ aes-gcm-aesni-x86_64.o \ ++ aes-xts-avx-x86_64.o ++ifeq ($(CONFIG_AS_VAES)$(CONFIG_AS_VPCLMULQDQ),yy) ++aesni-intel-$(CONFIG_64BIT) += aes-gcm-avx10-x86_64.o ++endif + + obj-$(CONFIG_CRYPTO_SHA1_SSSE3) += sha1-ssse3.o + sha1-ssse3-y := sha1_avx2_x86_64_asm.o sha1_ssse3_asm.o sha1_ssse3_glue.o +diff --git a/arch/x86/crypto/aes-gcm-aesni-x86_64.S b/arch/x86/crypto/aes-gcm-aesni-x86_64.S +new file mode 100644 +index 000000000000..45940e2883a0 +--- /dev/null ++++ b/arch/x86/crypto/aes-gcm-aesni-x86_64.S +@@ -0,0 +1,1128 @@ ++/* SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause */ ++// ++// AES-NI optimized AES-GCM for x86_64 ++// ++// Copyright 2024 Google LLC ++// ++// Author: Eric Biggers ++// ++//------------------------------------------------------------------------------ ++// ++// This file is dual-licensed, meaning that you can use it under your choice of ++// either of the following two licenses: ++// ++// Licensed under the Apache License 2.0 (the "License"). 
You may obtain a copy ++// of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// or ++// ++// Redistribution and use in source and binary forms, with or without ++// modification, are permitted provided that the following conditions are met: ++// ++// 1. Redistributions of source code must retain the above copyright notice, ++// this list of conditions and the following disclaimer. ++// ++// 2. Redistributions in binary form must reproduce the above copyright ++// notice, this list of conditions and the following disclaimer in the ++// documentation and/or other materials provided with the distribution. ++// ++// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" ++// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE ++// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ++// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE ++// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR ++// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF ++// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS ++// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN ++// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ++// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE ++// POSSIBILITY OF SUCH DAMAGE. 
++// ++//------------------------------------------------------------------------------ ++// ++// This file implements AES-GCM (Galois/Counter Mode) for x86_64 CPUs that ++// support the original set of AES instructions, i.e. AES-NI. Two ++// implementations are provided, one that uses AVX and one that doesn't. They ++// are very similar, being generated by the same macros. The only difference is ++// that the AVX implementation takes advantage of VEX-coded instructions in some ++// places to avoid some 'movdqu' and 'movdqa' instructions. The AVX ++// implementation does *not* use 256-bit vectors, as AES is not supported on ++// 256-bit vectors until the VAES feature (which this file doesn't target). ++// ++// The specific CPU feature prerequisites are AES-NI and PCLMULQDQ, plus SSE4.1 ++// for the *_aesni functions or AVX for the *_aesni_avx ones. (But it seems ++// there are no CPUs that support AES-NI without also PCLMULQDQ and SSE4.1.) ++// ++// The design generally follows that of aes-gcm-avx10-x86_64.S, and that file is ++// more thoroughly commented. This file has the following notable changes: ++// ++// - The vector length is fixed at 128-bit, i.e. xmm registers. This means ++// there is only one AES block (and GHASH block) per register. ++// ++// - Without AVX512 / AVX10, only 16 SIMD registers are available instead of ++// 32. We work around this by being much more careful about using ++// registers, relying heavily on loads to load values as they are needed. ++// ++// - Masking is not available either. We work around this by implementing ++// partial block loads and stores using overlapping scalar loads and stores ++// combined with shifts and SSE4.1 insertion and extraction instructions. ++// ++// - The main loop is organized differently due to the different design ++// constraints. First, with just one AES block per SIMD register, on some ++// CPUs 4 registers don't saturate the 'aesenc' throughput. We therefore ++// do an 8-register wide loop. 
Considering that and the fact that we have ++// just 16 SIMD registers to work with, it's not feasible to cache AES ++// round keys and GHASH key powers in registers across loop iterations. ++// That's not ideal, but also not actually that bad, since loads can run in ++// parallel with other instructions. Significantly, this also makes it ++// possible to roll up the inner loops, relying on hardware loop unrolling ++// instead of software loop unrolling, greatly reducing code size. ++// ++// - We implement the GHASH multiplications in the main loop using Karatsuba ++// multiplication instead of schoolbook multiplication. This saves one ++// pclmulqdq instruction per block, at the cost of one 64-bit load, one ++// pshufd, and 0.25 pxors per block. (This is without the three-argument ++// XOR support that would be provided by AVX512 / AVX10, which would be ++// more beneficial to schoolbook than Karatsuba.) ++// ++// As a rough approximation, we can assume that Karatsuba multiplication is ++// faster than schoolbook multiplication in this context if one pshufd and ++// 0.25 pxors are cheaper than a pclmulqdq. (We assume that the 64-bit ++// load is "free" due to running in parallel with arithmetic instructions.) ++// This is true on AMD CPUs, including all that support pclmulqdq up to at ++// least Zen 3. It's also true on older Intel CPUs: Westmere through ++// Haswell on the Core side, and Silvermont through Goldmont Plus on the ++// low-power side. On some of these CPUs, pclmulqdq is quite slow, and the ++// benefit of Karatsuba should be substantial. On newer Intel CPUs, ++// schoolbook multiplication should be faster, but only marginally. ++// ++// Not all these CPUs were available to be tested. However, benchmarks on ++// available CPUs suggest that this approximation is plausible. Switching ++// to Karatsuba showed negligible change (< 1%) on Intel Broadwell, ++// Skylake, and Cascade Lake, but it improved AMD Zen 1-3 by 6-7%. 
++// Considering that and the fact that Karatsuba should be even more ++// beneficial on older Intel CPUs, it seems like the right choice here. ++// ++// An additional 0.25 pclmulqdq per block (2 per 8 blocks) could be ++// saved by using a multiplication-less reduction method. We don't do that ++// because it would require a large number of shift and xor instructions, ++// making it less worthwhile and likely harmful on newer CPUs. ++// ++// It does make sense to sometimes use a different reduction optimization ++// that saves a pclmulqdq, though: precompute the hash key times x^64, and ++// multiply the low half of the data block by the hash key with the extra ++// factor of x^64. This eliminates one step of the reduction. However, ++// this is incompatible with Karatsuba multiplication. Therefore, for ++// multi-block processing we use Karatsuba multiplication with a regular ++// reduction. For single-block processing, we use the x^64 optimization. ++ ++#include <linux/linkage.h> ++ ++.section .rodata ++.p2align 4 ++.Lbswap_mask: ++ .octa 0x000102030405060708090a0b0c0d0e0f ++.Lgfpoly: ++ .quad 0xc200000000000000 ++.Lone: ++ .quad 1 ++.Lgfpoly_and_internal_carrybit: ++ .octa 0xc2000000000000010000000000000001 ++ // Loading 16 bytes from '.Lzeropad_mask + 16 - len' produces a mask of ++ // 'len' 0xff bytes and the rest zeroes. ++.Lzeropad_mask: ++ .octa 0xffffffffffffffffffffffffffffffff ++ .octa 0 ++ ++// Offsets in struct aes_gcm_key_aesni ++#define OFFSETOF_AESKEYLEN 480 ++#define OFFSETOF_H_POWERS 496 ++#define OFFSETOF_H_POWERS_XORED 624 ++#define OFFSETOF_H_TIMES_X64 688 ++ ++.text ++ ++// Do a vpclmulqdq, or fall back to a movdqa and a pclmulqdq. The fallback ++// assumes that all operands are distinct and that any mem operand is aligned. ++.macro _vpclmulqdq imm, src1, src2, dst ++.if USE_AVX ++ vpclmulqdq \imm, \src1, \src2, \dst ++.else ++ movdqa \src2, \dst ++ pclmulqdq \imm, \src1, \dst ++.endif ++.endm ++ ++// Do a vpshufb, or fall back to a movdqa and a pshufb.
The fallback assumes ++// that all operands are distinct and that any mem operand is aligned. ++.macro _vpshufb src1, src2, dst ++.if USE_AVX ++ vpshufb \src1, \src2, \dst ++.else ++ movdqa \src2, \dst ++ pshufb \src1, \dst ++.endif ++.endm ++ ++// Do a vpand, or fall back to a movdqu and a pand. The fallback assumes that ++// all operands are distinct. ++.macro _vpand src1, src2, dst ++.if USE_AVX ++ vpand \src1, \src2, \dst ++.else ++ movdqu \src1, \dst ++ pand \src2, \dst ++.endif ++.endm ++ ++// XOR the unaligned memory operand \mem into the xmm register \reg. \tmp must ++// be a temporary xmm register. ++.macro _xor_mem_to_reg mem, reg, tmp ++.if USE_AVX ++ vpxor \mem, \reg, \reg ++.else ++ movdqu \mem, \tmp ++ pxor \tmp, \reg ++.endif ++.endm ++ ++// Test the unaligned memory operand \mem against the xmm register \reg. \tmp ++// must be a temporary xmm register. ++.macro _test_mem mem, reg, tmp ++.if USE_AVX ++ vptest \mem, \reg ++.else ++ movdqu \mem, \tmp ++ ptest \tmp, \reg ++.endif ++.endm ++ ++// Load 1 <= %ecx <= 15 bytes from the pointer \src into the xmm register \dst ++// and zeroize any remaining bytes. Clobbers %rax, %rcx, and \tmp{64,32}. ++.macro _load_partial_block src, dst, tmp64, tmp32 ++ sub $8, %ecx // LEN - 8 ++ jle .Lle8\@ ++ ++ // Load 9 <= LEN <= 15 bytes. ++ movq (\src), \dst // Load first 8 bytes ++ mov (\src, %rcx), %rax // Load last 8 bytes ++ neg %ecx ++ shl $3, %ecx ++ shr %cl, %rax // Discard overlapping bytes ++ pinsrq $1, %rax, \dst ++ jmp .Ldone\@ ++ ++.Lle8\@: ++ add $4, %ecx // LEN - 4 ++ jl .Llt4\@ ++ ++ // Load 4 <= LEN <= 8 bytes. ++ mov (\src), %eax // Load first 4 bytes ++ mov (\src, %rcx), \tmp32 // Load last 4 bytes ++ jmp .Lcombine\@ ++ ++.Llt4\@: ++ // Load 1 <= LEN <= 3 bytes. 
++ add $2, %ecx // LEN - 2 ++ movzbl (\src), %eax // Load first byte ++ jl .Lmovq\@ // LEN == 1; branches on the flags of the add above ++ movzwl (\src, %rcx), \tmp32 // Load last 2 bytes ++.Lcombine\@: ++ shl $3, %ecx // Byte offset of the last part -> bit offset ++ shl %cl, \tmp64 // Shift the last part into position ++ or \tmp64, %rax // Combine the two parts ++.Lmovq\@: ++ movq %rax, \dst // Also zeroes the upper 8 bytes of \dst ++.Ldone\@: ++.endm ++ ++// Store 1 <= %ecx <= 15 bytes from the xmm register \src to the pointer \dst. ++// Clobbers %rax, %rcx, and %rsi. ++.macro _store_partial_block src, dst ++ sub $8, %ecx // LEN - 8 ++ jl .Llt8\@ ++ ++ // Store 8 <= LEN <= 15 bytes. ++ pextrq $1, \src, %rax // %rax = bytes 8-15 of \src ++ mov %ecx, %esi // %rsi = LEN - 8, the offset of the overlapping store ++ shl $3, %ecx // 8 * (LEN - 8) bits ++ ror %cl, %rax // Rotate so bytes 8..LEN-1 of \src land at \dst+8..\dst+LEN-1 below ++ mov %rax, (\dst, %rsi) // Store last LEN - 8 bytes ++ movq \src, (\dst) // Store first 8 bytes (overwrites the wrapped-around bytes stored above) ++ jmp .Ldone\@ ++ ++.Llt8\@: ++ add $4, %ecx // LEN - 4 ++ jl .Llt4\@ ++ ++ // Store 4 <= LEN <= 7 bytes. ++ pextrd $1, \src, %eax // %eax = bytes 4-7 of \src ++ mov %ecx, %esi // %rsi = LEN - 4, the offset of the overlapping store ++ shl $3, %ecx // 8 * (LEN - 4) bits ++ ror %cl, %eax // Rotate so bytes 4..LEN-1 of \src land at \dst+4..\dst+LEN-1 below ++ mov %eax, (\dst, %rsi) // Store last LEN - 4 bytes ++ movd \src, (\dst) // Store first 4 bytes (overwrites the wrapped-around bytes stored above) ++ jmp .Ldone\@ ++ ++.Llt4\@: ++ // Store 1 <= LEN <= 3 bytes. ++ pextrb $0, \src, 0(\dst) ++ cmp $-2, %ecx // LEN - 4 == -2, i.e. LEN == 2? ++ jl .Ldone\@ // LEN == 1 ++ pextrb $1, \src, 1(\dst) ++ je .Ldone\@ // LEN == 2; reuses the flags of the cmp above ++ pextrb $2, \src, 2(\dst) ++.Ldone\@: ++.endm ++ ++// Do one step of GHASH-multiplying \a by \b and storing the reduced product in ++// \b. To complete all steps, this must be invoked with \i=0 through \i=9. ++// \a_times_x64 must contain \a * x^64 in reduced form, \gfpoly must contain the ++// .Lgfpoly constant, and \t0-\t1 must be temporary registers.
++.macro _ghash_mul_step i, a, a_times_x64, b, gfpoly, t0, t1 ++ ++ // MI = (a_L * b_H) + ((a*x^64)_L * b_L) ++.if \i == 0 ++ _vpclmulqdq $0x01, \a, \b, \t0 ++.elseif \i == 1 ++ _vpclmulqdq $0x00, \a_times_x64, \b, \t1 ++.elseif \i == 2 ++ pxor \t1, \t0 ++ ++ // HI = (a_H * b_H) + ((a*x^64)_H * b_L) ++.elseif \i == 3 ++ _vpclmulqdq $0x11, \a, \b, \t1 ++.elseif \i == 4 ++ pclmulqdq $0x10, \a_times_x64, \b ++.elseif \i == 5 ++ pxor \t1, \b ++.elseif \i == 6 ++ ++ // Fold MI into HI. ++ pshufd $0x4e, \t0, \t1 // Swap halves of MI ++.elseif \i == 7 ++ pclmulqdq $0x00, \gfpoly, \t0 // MI_L*(x^63 + x^62 + x^57) ++.elseif \i == 8 ++ pxor \t1, \b ++.elseif \i == 9 ++ pxor \t0, \b ++.endif ++.endm ++ ++// GHASH-multiply \a by \b and store the reduced product in \b. ++// See _ghash_mul_step for details. ++.macro _ghash_mul a, a_times_x64, b, gfpoly, t0, t1 ++.irp i, 0,1,2,3,4,5,6,7,8,9 ++ _ghash_mul_step \i, \a, \a_times_x64, \b, \gfpoly, \t0, \t1 ++.endr ++.endm ++ ++// GHASH-multiply \a by \b and add the unreduced product to \lo, \mi, and \hi. ++// This does Karatsuba multiplication and must be paired with _ghash_reduce. On ++// the first call, \lo, \mi, and \hi must be zero. \a_xored must contain the ++// two halves of \a XOR'd together, i.e. a_L + a_H. \b is clobbered. ++.macro _ghash_mul_noreduce a, a_xored, b, lo, mi, hi, t0 ++ ++ // LO += a_L * b_L ++ _vpclmulqdq $0x00, \a, \b, \t0 ++ pxor \t0, \lo ++ ++ // b_L + b_H ++ pshufd $0x4e, \b, \t0 ++ pxor \b, \t0 ++ ++ // HI += a_H * b_H ++ pclmulqdq $0x11, \a, \b ++ pxor \b, \hi ++ ++ // MI += (a_L + a_H) * (b_L + b_H) ++ pclmulqdq $0x00, \a_xored, \t0 ++ pxor \t0, \mi ++.endm ++ ++// Reduce the product from \lo, \mi, and \hi, and store the result in \dst. ++// This assumes that _ghash_mul_noreduce was used. ++.macro _ghash_reduce lo, mi, hi, dst, t0 ++ ++ movq .Lgfpoly(%rip), \t0 ++ ++ // MI += LO + HI (needed because we used Karatsuba multiplication) ++ pxor \lo, \mi ++ pxor \hi, \mi ++ ++ // Fold LO into MI. 
++ pshufd $0x4e, \lo, \dst ++ pclmulqdq $0x00, \t0, \lo ++ pxor \dst, \mi ++ pxor \lo, \mi ++ ++ // Fold MI into HI. ++ pshufd $0x4e, \mi, \dst ++ pclmulqdq $0x00, \t0, \mi ++ pxor \hi, \dst ++ pxor \mi, \dst ++.endm ++ ++// Do the first step of the GHASH update of a set of 8 ciphertext blocks. ++// ++// The whole GHASH update does: ++// ++// GHASH_ACC = (blk0+GHASH_ACC)*H^8 + blk1*H^7 + blk2*H^6 + blk3*H^5 + ++// blk4*H^4 + blk5*H^3 + blk6*H^2 + blk7*H^1 ++// ++// This macro just does the first step: it does the unreduced multiplication ++// (blk0+GHASH_ACC)*H^8 and starts gathering the unreduced product in the xmm ++// registers LO, MI, and GHASH_ACC a.k.a. HI. It also zero-initializes the ++// inner block counter in %rax, which is a value that counts up by 8 for each ++// block in the set of 8 and is used later to index by 8*blknum and 16*blknum. ++// ++// To reduce the number of pclmulqdq instructions required, both this macro and ++// _ghash_update_continue_8x use Karatsuba multiplication instead of schoolbook ++// multiplication. See the file comment for more details about this choice. ++// ++// Both macros expect the ciphertext blocks blk[0-7] to be available at DST if ++// encrypting, or SRC if decrypting. They also expect the precomputed hash key ++// powers H^i and their XOR'd-together halves to be available in the struct ++// pointed to by KEY. Both macros clobber TMP[0-2]. ++.macro _ghash_update_begin_8x enc ++ ++ // Initialize the inner block counter. ++ xor %eax, %eax ++ ++ // Load the highest hash key power, H^8. ++ movdqa OFFSETOF_H_POWERS(KEY), TMP0 ++ ++ // Load the first ciphertext block and byte-reflect it. ++.if \enc ++ movdqu (DST), TMP1 ++.else ++ movdqu (SRC), TMP1 ++.endif ++ pshufb BSWAP_MASK, TMP1 ++ ++ // Add the GHASH accumulator to the ciphertext block to get the block ++ // 'b' that needs to be multiplied with the hash key power 'a'. 
++ pxor TMP1, GHASH_ACC ++ ++ // b_L + b_H ++ pshufd $0x4e, GHASH_ACC, MI ++ pxor GHASH_ACC, MI ++ ++ // LO = a_L * b_L ++ _vpclmulqdq $0x00, TMP0, GHASH_ACC, LO ++ ++ // HI = a_H * b_H ++ pclmulqdq $0x11, TMP0, GHASH_ACC ++ ++ // MI = (a_L + a_H) * (b_L + b_H) ++ pclmulqdq $0x00, OFFSETOF_H_POWERS_XORED(KEY), MI ++.endm ++ ++// Continue the GHASH update of 8 ciphertext blocks as described above by doing ++// an unreduced multiplication of the next ciphertext block by the next lowest ++// key power and accumulating the result into LO, MI, and GHASH_ACC a.k.a. HI. ++.macro _ghash_update_continue_8x enc ++ add $8, %eax ++ ++ // Load the next lowest key power. ++ movdqa OFFSETOF_H_POWERS(KEY,%rax,2), TMP0 ++ ++ // Load the next ciphertext block and byte-reflect it. ++.if \enc ++ movdqu (DST,%rax,2), TMP1 ++.else ++ movdqu (SRC,%rax,2), TMP1 ++.endif ++ pshufb BSWAP_MASK, TMP1 ++ ++ // LO += a_L * b_L ++ _vpclmulqdq $0x00, TMP0, TMP1, TMP2 ++ pxor TMP2, LO ++ ++ // b_L + b_H ++ pshufd $0x4e, TMP1, TMP2 ++ pxor TMP1, TMP2 ++ ++ // HI += a_H * b_H ++ pclmulqdq $0x11, TMP0, TMP1 ++ pxor TMP1, GHASH_ACC ++ ++ // MI += (a_L + a_H) * (b_L + b_H) ++ movq OFFSETOF_H_POWERS_XORED(KEY,%rax), TMP1 ++ pclmulqdq $0x00, TMP1, TMP2 ++ pxor TMP2, MI ++.endm ++ ++// Reduce LO, MI, and GHASH_ACC a.k.a. HI into GHASH_ACC. This is similar to ++// _ghash_reduce, but it's hardcoded to use the registers of the main loop and ++// it uses the same register for HI and the destination. It's also divided into ++// two steps. TMP1 must be preserved across steps. ++// ++// One pshufd could be saved by shuffling MI and XOR'ing LO into it, instead of ++// shuffling LO, XOR'ing LO into MI, and shuffling MI. However, this would ++// increase the critical path length, and it seems to slightly hurt performance. 
++.macro _ghash_update_end_8x_step i ++.if \i == 0 ++ movq .Lgfpoly(%rip), TMP1 ++ pxor LO, MI ++ pxor GHASH_ACC, MI ++ pshufd $0x4e, LO, TMP2 ++ pclmulqdq $0x00, TMP1, LO ++ pxor TMP2, MI ++ pxor LO, MI ++.elseif \i == 1 ++ pshufd $0x4e, MI, TMP2 ++ pclmulqdq $0x00, TMP1, MI ++ pxor TMP2, GHASH_ACC ++ pxor MI, GHASH_ACC ++.endif ++.endm ++ ++// void aes_gcm_precompute_##suffix(struct aes_gcm_key_aesni *key); ++// ++// Given the expanded AES key, derive the GHASH subkey and initialize the GHASH ++// related fields in the key struct. ++.macro _aes_gcm_precompute ++ ++ // Function arguments ++ .set KEY, %rdi ++ ++ // Additional local variables. ++ // %xmm0-%xmm1 and %rax are used as temporaries. ++ .set RNDKEYLAST_PTR, %rsi ++ .set H_CUR, %xmm2 ++ .set H_POW1, %xmm3 // H^1 ++ .set H_POW1_X64, %xmm4 // H^1 * x^64 ++ .set GFPOLY, %xmm5 ++ ++ // Encrypt an all-zeroes block to get the raw hash subkey. ++ movl OFFSETOF_AESKEYLEN(KEY), %eax ++ lea 6*16(KEY,%rax,4), RNDKEYLAST_PTR ++ movdqa (KEY), H_POW1 // Zero-th round key XOR all-zeroes block ++ lea 16(KEY), %rax ++1: ++ aesenc (%rax), H_POW1 ++ add $16, %rax ++ cmp %rax, RNDKEYLAST_PTR ++ jne 1b ++ aesenclast (RNDKEYLAST_PTR), H_POW1 ++ ++ // Preprocess the raw hash subkey as needed to operate on GHASH's ++ // bit-reflected values directly: reflect its bytes, then multiply it by ++ // x^-1 (using the backwards interpretation of polynomial coefficients ++ // from the GCM spec) or equivalently x^1 (using the alternative, ++ // natural interpretation of polynomial coefficients). ++ pshufb .Lbswap_mask(%rip), H_POW1 ++ movdqa H_POW1, %xmm0 ++ pshufd $0xd3, %xmm0, %xmm0 ++ psrad $31, %xmm0 ++ paddq H_POW1, H_POW1 ++ pand .Lgfpoly_and_internal_carrybit(%rip), %xmm0 ++ pxor %xmm0, H_POW1 ++ ++ // Store H^1. ++ movdqa H_POW1, OFFSETOF_H_POWERS+7*16(KEY) ++ ++ // Compute and store H^1 * x^64. 
++ movq .Lgfpoly(%rip), GFPOLY ++ pshufd $0x4e, H_POW1, %xmm0 ++ _vpclmulqdq $0x00, H_POW1, GFPOLY, H_POW1_X64 ++ pxor %xmm0, H_POW1_X64 ++ movdqa H_POW1_X64, OFFSETOF_H_TIMES_X64(KEY) ++ ++ // Compute and store the halves of H^1 XOR'd together. ++ pxor H_POW1, %xmm0 ++ movq %xmm0, OFFSETOF_H_POWERS_XORED+7*8(KEY) ++ ++ // Compute and store the remaining key powers H^2 through H^8. ++ movdqa H_POW1, H_CUR ++ mov $6*8, %eax ++.Lprecompute_next\@: ++ // Compute H^i = H^{i-1} * H^1. ++ _ghash_mul H_POW1, H_POW1_X64, H_CUR, GFPOLY, %xmm0, %xmm1 ++ // Store H^i. ++ movdqa H_CUR, OFFSETOF_H_POWERS(KEY,%rax,2) ++ // Compute and store the halves of H^i XOR'd together. ++ pshufd $0x4e, H_CUR, %xmm0 ++ pxor H_CUR, %xmm0 ++ movq %xmm0, OFFSETOF_H_POWERS_XORED(KEY,%rax) ++ sub $8, %eax ++ jge .Lprecompute_next\@ ++ ++ RET ++.endm ++ ++// void aes_gcm_aad_update_aesni(const struct aes_gcm_key_aesni *key, ++// u8 ghash_acc[16], const u8 *aad, int aadlen); ++// ++// This function processes the AAD (Additional Authenticated Data) in GCM. ++// Using the key |key|, it updates the GHASH accumulator |ghash_acc| with the ++// data given by |aad| and |aadlen|. On the first call, |ghash_acc| must be all ++// zeroes. |aadlen| must be a multiple of 16, except on the last call where it ++// can be any length. The caller must do any buffering needed to ensure this. ++.macro _aes_gcm_aad_update ++ ++ // Function arguments ++ .set KEY, %rdi ++ .set GHASH_ACC_PTR, %rsi ++ .set AAD, %rdx ++ .set AADLEN, %ecx ++ // Note: _load_partial_block relies on AADLEN being in %ecx. ++ ++ // Additional local variables. ++ // %rax, %r10, and %xmm0-%xmm1 are used as temporary registers. 
++ .set BSWAP_MASK, %xmm2 ++ .set GHASH_ACC, %xmm3 ++ .set H_POW1, %xmm4 // H^1 ++ .set H_POW1_X64, %xmm5 // H^1 * x^64 ++ .set GFPOLY, %xmm6 ++ ++ movdqa .Lbswap_mask(%rip), BSWAP_MASK ++ movdqu (GHASH_ACC_PTR), GHASH_ACC ++ movdqa OFFSETOF_H_POWERS+7*16(KEY), H_POW1 ++ movdqa OFFSETOF_H_TIMES_X64(KEY), H_POW1_X64 ++ movq .Lgfpoly(%rip), GFPOLY ++ ++ // Process the AAD one full block at a time. ++ sub $16, AADLEN ++ jl .Laad_loop_1x_done\@ ++.Laad_loop_1x\@: ++ movdqu (AAD), %xmm0 ++ pshufb BSWAP_MASK, %xmm0 ++ pxor %xmm0, GHASH_ACC ++ _ghash_mul H_POW1, H_POW1_X64, GHASH_ACC, GFPOLY, %xmm0, %xmm1 ++ add $16, AAD ++ sub $16, AADLEN ++ jge .Laad_loop_1x\@ ++.Laad_loop_1x_done\@: ++ // Check whether there is a partial block at the end. ++ add $16, AADLEN ++ jz .Laad_done\@ ++ ++ // Process a partial block of length 1 <= AADLEN <= 15. ++ // _load_partial_block assumes that %ecx contains AADLEN. ++ _load_partial_block AAD, %xmm0, %r10, %r10d ++ pshufb BSWAP_MASK, %xmm0 ++ pxor %xmm0, GHASH_ACC ++ _ghash_mul H_POW1, H_POW1_X64, GHASH_ACC, GFPOLY, %xmm0, %xmm1 ++ ++.Laad_done\@: ++ movdqu GHASH_ACC, (GHASH_ACC_PTR) ++ RET ++.endm ++ ++// Increment LE_CTR eight times to generate eight little-endian counter blocks, ++// swap each to big-endian, and store them in AESDATA[0-7]. Also XOR them with ++// the zero-th AES round key. Clobbers TMP0 and TMP1. ++.macro _ctr_begin_8x ++ movq .Lone(%rip), TMP0 ++ movdqa (KEY), TMP1 // zero-th round key ++.irp i, 0,1,2,3,4,5,6,7 ++ _vpshufb BSWAP_MASK, LE_CTR, AESDATA\i ++ pxor TMP1, AESDATA\i ++ paddd TMP0, LE_CTR ++.endr ++.endm ++ ++// Do a non-last round of AES on AESDATA[0-7] using \round_key. ++.macro _aesenc_8x round_key ++.irp i, 0,1,2,3,4,5,6,7 ++ aesenc \round_key, AESDATA\i ++.endr ++.endm ++ ++// Do the last round of AES on AESDATA[0-7] using \round_key. 
++.macro _aesenclast_8x round_key ++.irp i, 0,1,2,3,4,5,6,7 ++ aesenclast \round_key, AESDATA\i ++.endr ++.endm ++ ++// XOR eight blocks from SRC with the keystream blocks in AESDATA[0-7], and ++// store the result to DST. Clobbers TMP0. ++.macro _xor_data_8x ++.irp i, 0,1,2,3,4,5,6,7 ++ _xor_mem_to_reg \i*16(SRC), AESDATA\i, tmp=TMP0 ++.endr ++.irp i, 0,1,2,3,4,5,6,7 ++ movdqu AESDATA\i, \i*16(DST) ++.endr ++.endm ++ ++// void aes_gcm_{enc,dec}_update_##suffix(const struct aes_gcm_key_aesni *key, ++// const u32 le_ctr[4], u8 ghash_acc[16], ++// const u8 *src, u8 *dst, int datalen); ++// ++// This macro generates a GCM encryption or decryption update function with the ++// above prototype (with \enc selecting which one). ++// ++// This function computes the next portion of the CTR keystream, XOR's it with ++// |datalen| bytes from |src|, and writes the resulting encrypted or decrypted ++// data to |dst|. It also updates the GHASH accumulator |ghash_acc| using the ++// next |datalen| ciphertext bytes. ++// ++// |datalen| must be a multiple of 16, except on the last call where it can be ++// any length. The caller must do any buffering needed to ensure this. Both ++// in-place and out-of-place en/decryption are supported. ++// ++// |le_ctr| must give the current counter in little-endian format. For a new ++// message, the low word of the counter must be 2. This function loads the ++// counter from |le_ctr| and increments the loaded counter as needed, but it ++// does *not* store the updated counter back to |le_ctr|. The caller must ++// update |le_ctr| if any more data segments follow. Internally, only the low ++// 32-bit word of the counter is incremented, following the GCM standard. ++.macro _aes_gcm_update enc ++ ++ // Function arguments ++ .set KEY, %rdi ++ .set LE_CTR_PTR, %rsi // Note: overlaps with usage as temp reg ++ .set GHASH_ACC_PTR, %rdx ++ .set SRC, %rcx ++ .set DST, %r8 ++ .set DATALEN, %r9d ++ .set DATALEN64, %r9 // Zero-extend DATALEN before using! 
++ // Note: the code setting up for _load_partial_block assumes that SRC is ++ // in %rcx (and that DATALEN is *not* in %rcx). ++ ++ // Additional local variables ++ ++ // %rax and %rsi are used as temporary registers. Note: %rsi overlaps ++ // with LE_CTR_PTR, which is used only at the beginning. ++ ++ .set AESKEYLEN, %r10d // AES key length in bytes ++ .set AESKEYLEN64, %r10 ++ .set RNDKEYLAST_PTR, %r11 // Pointer to last AES round key ++ ++ // Put the most frequently used values in %xmm0-%xmm7 to reduce code ++ // size. (%xmm0-%xmm7 take fewer bytes to encode than %xmm8-%xmm15.) ++ .set TMP0, %xmm0 ++ .set TMP1, %xmm1 ++ .set TMP2, %xmm2 ++ .set LO, %xmm3 // Low part of unreduced product ++ .set MI, %xmm4 // Middle part of unreduced product ++ .set GHASH_ACC, %xmm5 // GHASH accumulator; in main loop also ++ // the high part of unreduced product ++ .set BSWAP_MASK, %xmm6 // Shuffle mask for reflecting bytes ++ .set LE_CTR, %xmm7 // Little-endian counter value ++ .set AESDATA0, %xmm8 ++ .set AESDATA1, %xmm9 ++ .set AESDATA2, %xmm10 ++ .set AESDATA3, %xmm11 ++ .set AESDATA4, %xmm12 ++ .set AESDATA5, %xmm13 ++ .set AESDATA6, %xmm14 ++ .set AESDATA7, %xmm15 ++ ++ movdqa .Lbswap_mask(%rip), BSWAP_MASK ++ movdqu (GHASH_ACC_PTR), GHASH_ACC ++ movdqu (LE_CTR_PTR), LE_CTR ++ ++ movl OFFSETOF_AESKEYLEN(KEY), AESKEYLEN ++ lea 6*16(KEY,AESKEYLEN64,4), RNDKEYLAST_PTR ++ ++ // If there are at least 8*16 bytes of data, then continue into the main ++ // loop, which processes 8*16 bytes of data per iteration. ++ // ++ // The main loop interleaves AES and GHASH to improve performance on ++ // CPUs that can execute these instructions in parallel. When ++ // decrypting, the GHASH input (the ciphertext) is immediately ++ // available. When encrypting, we instead encrypt a set of 8 blocks ++ // first and then GHASH those blocks while encrypting the next set of 8, ++ // repeat that as needed, and finally GHASH the last set of 8 blocks. 
++ // ++ // Code size optimization: Prefer adding or subtracting -8*16 over 8*16, ++ // as this makes the immediate fit in a signed byte, saving 3 bytes. ++ add $-8*16, DATALEN ++ jl .Lcrypt_loop_8x_done\@ ++.if \enc ++ // Encrypt the first 8 plaintext blocks. ++ _ctr_begin_8x ++ lea 16(KEY), %rsi ++ .p2align 4 ++1: ++ movdqa (%rsi), TMP0 ++ _aesenc_8x TMP0 ++ add $16, %rsi ++ cmp %rsi, RNDKEYLAST_PTR ++ jne 1b ++ movdqa (%rsi), TMP0 ++ _aesenclast_8x TMP0 ++ _xor_data_8x ++ // Don't increment DST until the ciphertext blocks have been hashed. ++ sub $-8*16, SRC ++ add $-8*16, DATALEN ++ jl .Lghash_last_ciphertext_8x\@ ++.endif ++ ++ .p2align 4 ++.Lcrypt_loop_8x\@: ++ ++ // Generate the next set of 8 counter blocks and start encrypting them. ++ _ctr_begin_8x ++ lea 16(KEY), %rsi ++ ++ // Do a round of AES, and start the GHASH update of 8 ciphertext blocks ++ // by doing the unreduced multiplication for the first ciphertext block. ++ movdqa (%rsi), TMP0 ++ add $16, %rsi ++ _aesenc_8x TMP0 ++ _ghash_update_begin_8x \enc ++ ++ // Do 7 more rounds of AES, and continue the GHASH update by doing the ++ // unreduced multiplication for the remaining ciphertext blocks. ++ .p2align 4 ++1: ++ movdqa (%rsi), TMP0 ++ add $16, %rsi ++ _aesenc_8x TMP0 ++ _ghash_update_continue_8x \enc ++ cmp $7*8, %eax ++ jne 1b ++ ++ // Do the remaining AES rounds. ++ .p2align 4 ++1: ++ movdqa (%rsi), TMP0 ++ add $16, %rsi ++ _aesenc_8x TMP0 ++ cmp %rsi, RNDKEYLAST_PTR ++ jne 1b ++ ++ // Do the GHASH reduction and the last round of AES. ++ movdqa (RNDKEYLAST_PTR), TMP0 ++ _ghash_update_end_8x_step 0 ++ _aesenclast_8x TMP0 ++ _ghash_update_end_8x_step 1 ++ ++ // XOR the data with the AES-CTR keystream blocks. ++.if \enc ++ sub $-8*16, DST ++.endif ++ _xor_data_8x ++ sub $-8*16, SRC ++.if !\enc ++ sub $-8*16, DST ++.endif ++ add $-8*16, DATALEN ++ jge .Lcrypt_loop_8x\@ ++ ++.if \enc ++.Lghash_last_ciphertext_8x\@: ++ // Update GHASH with the last set of 8 ciphertext blocks. 
++ _ghash_update_begin_8x \enc ++ .p2align 4 ++1: ++ _ghash_update_continue_8x \enc ++ cmp $7*8, %eax ++ jne 1b ++ _ghash_update_end_8x_step 0 ++ _ghash_update_end_8x_step 1 ++ sub $-8*16, DST ++.endif ++ ++.Lcrypt_loop_8x_done\@: ++ ++ sub $-8*16, DATALEN ++ jz .Ldone\@ ++ ++ // Handle the remainder of length 1 <= DATALEN < 8*16 bytes. We keep ++ // things simple and keep the code size down by just going one block at ++ // a time, again taking advantage of hardware loop unrolling. Since ++ // there are enough key powers available for all remaining data, we do ++ // the GHASH multiplications unreduced, and only reduce at the very end. ++ ++ .set HI, TMP2 ++ .set H_POW, AESDATA0 ++ .set H_POW_XORED, AESDATA1 ++ .set ONE, AESDATA2 ++ ++ movq .Lone(%rip), ONE ++ ++ // Start collecting the unreduced GHASH intermediate value LO, MI, HI. ++ pxor LO, LO ++ pxor MI, MI ++ pxor HI, HI ++ ++ // Set up a block counter %rax to contain 8*(8-n), where n is the number ++ // of blocks that remain, counting any partial block. This will be used ++ // to access the key powers H^n through H^1. ++ mov DATALEN, %eax ++ neg %eax ++ and $~15, %eax ++ sar $1, %eax ++ add $64, %eax ++ ++ sub $16, DATALEN ++ jl .Lcrypt_loop_1x_done\@ ++ ++ // Process the data one full block at a time. ++.Lcrypt_loop_1x\@: ++ ++ // Encrypt the next counter block. ++ _vpshufb BSWAP_MASK, LE_CTR, TMP0 ++ paddd ONE, LE_CTR ++ pxor (KEY), TMP0 ++ lea -6*16(RNDKEYLAST_PTR), %rsi // Reduce code size ++ cmp $24, AESKEYLEN ++ jl 128f // AES-128? ++ je 192f // AES-192? ++ // AES-256 ++ aesenc -7*16(%rsi), TMP0 ++ aesenc -6*16(%rsi), TMP0 ++192: ++ aesenc -5*16(%rsi), TMP0 ++ aesenc -4*16(%rsi), TMP0 ++128: ++.irp i, -3,-2,-1,0,1,2,3,4,5 ++ aesenc \i*16(%rsi), TMP0 ++.endr ++ aesenclast (RNDKEYLAST_PTR), TMP0 ++ ++ // Load the next key power H^i. 
++ movdqa OFFSETOF_H_POWERS(KEY,%rax,2), H_POW ++ movq OFFSETOF_H_POWERS_XORED(KEY,%rax), H_POW_XORED ++ ++ // XOR the keystream block that was just generated in TMP0 with the next ++ // source data block and store the resulting en/decrypted data to DST. ++.if \enc ++ _xor_mem_to_reg (SRC), TMP0, tmp=TMP1 ++ movdqu TMP0, (DST) ++.else ++ movdqu (SRC), TMP1 ++ pxor TMP1, TMP0 ++ movdqu TMP0, (DST) ++.endif ++ ++ // Update GHASH with the ciphertext block. ++.if \enc ++ pshufb BSWAP_MASK, TMP0 ++ pxor TMP0, GHASH_ACC ++.else ++ pshufb BSWAP_MASK, TMP1 ++ pxor TMP1, GHASH_ACC ++.endif ++ _ghash_mul_noreduce H_POW, H_POW_XORED, GHASH_ACC, LO, MI, HI, TMP0 ++ pxor GHASH_ACC, GHASH_ACC ++ ++ add $8, %eax ++ add $16, SRC ++ add $16, DST ++ sub $16, DATALEN ++ jge .Lcrypt_loop_1x\@ ++.Lcrypt_loop_1x_done\@: ++ // Check whether there is a partial block at the end. ++ add $16, DATALEN ++ jz .Lghash_reduce\@ ++ ++ // Process a partial block of length 1 <= DATALEN <= 15. ++ ++ // Encrypt a counter block for the last time. ++ pshufb BSWAP_MASK, LE_CTR ++ pxor (KEY), LE_CTR ++ lea 16(KEY), %rsi ++1: ++ aesenc (%rsi), LE_CTR ++ add $16, %rsi ++ cmp %rsi, RNDKEYLAST_PTR ++ jne 1b ++ aesenclast (RNDKEYLAST_PTR), LE_CTR ++ ++ // Load the lowest key power, H^1. ++ movdqa OFFSETOF_H_POWERS(KEY,%rax,2), H_POW ++ movq OFFSETOF_H_POWERS_XORED(KEY,%rax), H_POW_XORED ++ ++ // Load and zero-pad 1 <= DATALEN <= 15 bytes of data from SRC. SRC is ++ // in %rcx, but _load_partial_block needs DATALEN in %rcx instead. ++ // RNDKEYLAST_PTR is no longer needed, so reuse it for SRC. ++ mov SRC, RNDKEYLAST_PTR ++ mov DATALEN, %ecx ++ _load_partial_block RNDKEYLAST_PTR, TMP0, %rsi, %esi ++ ++ // XOR the keystream block that was just generated in LE_CTR with the ++ // source data block and store the resulting en/decrypted data to DST. ++ pxor TMP0, LE_CTR ++ mov DATALEN, %ecx ++ _store_partial_block LE_CTR, DST ++ ++ // If encrypting, zero-pad the final ciphertext block for GHASH. 
(If ++ // decrypting, this was already done by _load_partial_block.) ++.if \enc ++ lea .Lzeropad_mask+16(%rip), %rax ++ sub DATALEN64, %rax ++ _vpand (%rax), LE_CTR, TMP0 ++.endif ++ ++ // Update GHASH with the final ciphertext block. ++ pshufb BSWAP_MASK, TMP0 ++ pxor TMP0, GHASH_ACC ++ _ghash_mul_noreduce H_POW, H_POW_XORED, GHASH_ACC, LO, MI, HI, TMP0 ++ ++.Lghash_reduce\@: ++ // Finally, do the GHASH reduction. ++ _ghash_reduce LO, MI, HI, GHASH_ACC, TMP0 ++ ++.Ldone\@: ++ // Store the updated GHASH accumulator back to memory. ++ movdqu GHASH_ACC, (GHASH_ACC_PTR) ++ ++ RET ++.endm ++ ++// void aes_gcm_enc_final_##suffix(const struct aes_gcm_key_aesni *key, ++// const u32 le_ctr[4], u8 ghash_acc[16], ++// u64 total_aadlen, u64 total_datalen); ++// bool aes_gcm_dec_final_##suffix(const struct aes_gcm_key_aesni *key, ++// const u32 le_ctr[4], const u8 ghash_acc[16], ++// u64 total_aadlen, u64 total_datalen, ++// const u8 tag[16], int taglen); ++// ++// This macro generates one of the above two functions (with \enc selecting ++// which one). Both functions finish computing the GCM authentication tag by ++// updating GHASH with the lengths block and encrypting the GHASH accumulator. ++// |total_aadlen| and |total_datalen| must be the total length of the additional ++// authenticated data and the en/decrypted data in bytes, respectively. ++// ++// The encryption function then stores the full-length (16-byte) computed ++// authentication tag to |ghash_acc|. The decryption function instead loads the ++// expected authentication tag (the one that was transmitted) from the 16-byte ++// buffer |tag|, compares the first 4 <= |taglen| <= 16 bytes of it to the ++// computed tag in constant time, and returns true if and only if they match. 
++.macro _aes_gcm_final enc ++ ++ // Function arguments ++ .set KEY, %rdi ++ .set LE_CTR_PTR, %rsi ++ .set GHASH_ACC_PTR, %rdx ++ .set TOTAL_AADLEN, %rcx ++ .set TOTAL_DATALEN, %r8 ++ .set TAG, %r9 ++ .set TAGLEN, %r10d // Originally at 8(%rsp) ++ .set TAGLEN64, %r10 ++ ++ // Additional local variables. ++ // %rax and %xmm0-%xmm2 are used as temporary registers. ++ .set AESKEYLEN, %r11d ++ .set AESKEYLEN64, %r11 ++ .set BSWAP_MASK, %xmm3 ++ .set GHASH_ACC, %xmm4 ++ .set H_POW1, %xmm5 // H^1 ++ .set H_POW1_X64, %xmm6 // H^1 * x^64 ++ .set GFPOLY, %xmm7 ++ ++ movdqa .Lbswap_mask(%rip), BSWAP_MASK ++ movl OFFSETOF_AESKEYLEN(KEY), AESKEYLEN ++ ++ // Set up a counter block with 1 in the low 32-bit word. This is the ++ // counter that produces the ciphertext needed to encrypt the auth tag. ++ movdqu (LE_CTR_PTR), %xmm0 ++ mov $1, %eax ++ pinsrd $0, %eax, %xmm0 ++ ++ // Build the lengths block and XOR it into the GHASH accumulator. ++ movq TOTAL_DATALEN, GHASH_ACC ++ pinsrq $1, TOTAL_AADLEN, GHASH_ACC ++ psllq $3, GHASH_ACC // Bytes to bits ++ _xor_mem_to_reg (GHASH_ACC_PTR), GHASH_ACC, %xmm1 ++ ++ movdqa OFFSETOF_H_POWERS+7*16(KEY), H_POW1 ++ movdqa OFFSETOF_H_TIMES_X64(KEY), H_POW1_X64 ++ movq .Lgfpoly(%rip), GFPOLY ++ ++ // Make %rax point to the 6th from last AES round key. (Using signed ++ // byte offsets -7*16 through 6*16 decreases code size.) ++ lea (KEY,AESKEYLEN64,4), %rax ++ ++ // AES-encrypt the counter block and also multiply GHASH_ACC by H^1. ++ // Interleave the AES and GHASH instructions to improve performance. ++ pshufb BSWAP_MASK, %xmm0 ++ pxor (KEY), %xmm0 ++ cmp $24, AESKEYLEN ++ jl 128f // AES-128? ++ je 192f // AES-192? 
++ // AES-256 ++ aesenc -7*16(%rax), %xmm0 ++ aesenc -6*16(%rax), %xmm0 ++192: ++ aesenc -5*16(%rax), %xmm0 ++ aesenc -4*16(%rax), %xmm0 ++128: ++.irp i, 0,1,2,3,4,5,6,7,8 ++ aesenc (\i-3)*16(%rax), %xmm0 ++ _ghash_mul_step \i, H_POW1, H_POW1_X64, GHASH_ACC, GFPOLY, %xmm1, %xmm2 ++.endr ++ aesenclast 6*16(%rax), %xmm0 ++ _ghash_mul_step 9, H_POW1, H_POW1_X64, GHASH_ACC, GFPOLY, %xmm1, %xmm2 ++ ++ // Undo the byte reflection of the GHASH accumulator. ++ pshufb BSWAP_MASK, GHASH_ACC ++ ++ // Encrypt the GHASH accumulator. ++ pxor %xmm0, GHASH_ACC ++ ++.if \enc ++ // Return the computed auth tag. ++ movdqu GHASH_ACC, (GHASH_ACC_PTR) ++.else ++ .set ZEROPAD_MASK_PTR, TOTAL_AADLEN // Reusing TOTAL_AADLEN! ++ ++ // Verify the auth tag in constant time by XOR'ing the transmitted and ++ // computed auth tags together and using the ptest instruction to check ++ // whether the first TAGLEN bytes of the result are zero. ++ _xor_mem_to_reg (TAG), GHASH_ACC, tmp=%xmm0 ++ movl 8(%rsp), TAGLEN ++ lea .Lzeropad_mask+16(%rip), ZEROPAD_MASK_PTR ++ sub TAGLEN64, ZEROPAD_MASK_PTR ++ xor %eax, %eax ++ _test_mem (ZEROPAD_MASK_PTR), GHASH_ACC, tmp=%xmm0 ++ sete %al ++.endif ++ RET ++.endm ++ ++.set USE_AVX, 0 ++SYM_FUNC_START(aes_gcm_precompute_aesni) ++ _aes_gcm_precompute ++SYM_FUNC_END(aes_gcm_precompute_aesni) ++SYM_FUNC_START(aes_gcm_aad_update_aesni) ++ _aes_gcm_aad_update ++SYM_FUNC_END(aes_gcm_aad_update_aesni) ++SYM_FUNC_START(aes_gcm_enc_update_aesni) ++ _aes_gcm_update 1 ++SYM_FUNC_END(aes_gcm_enc_update_aesni) ++SYM_FUNC_START(aes_gcm_dec_update_aesni) ++ _aes_gcm_update 0 ++SYM_FUNC_END(aes_gcm_dec_update_aesni) ++SYM_FUNC_START(aes_gcm_enc_final_aesni) ++ _aes_gcm_final 1 ++SYM_FUNC_END(aes_gcm_enc_final_aesni) ++SYM_FUNC_START(aes_gcm_dec_final_aesni) ++ _aes_gcm_final 0 ++SYM_FUNC_END(aes_gcm_dec_final_aesni) ++ ++.set USE_AVX, 1 ++SYM_FUNC_START(aes_gcm_precompute_aesni_avx) ++ _aes_gcm_precompute ++SYM_FUNC_END(aes_gcm_precompute_aesni_avx) 
++SYM_FUNC_START(aes_gcm_aad_update_aesni_avx) ++ _aes_gcm_aad_update ++SYM_FUNC_END(aes_gcm_aad_update_aesni_avx) ++SYM_FUNC_START(aes_gcm_enc_update_aesni_avx) ++ _aes_gcm_update 1 ++SYM_FUNC_END(aes_gcm_enc_update_aesni_avx) ++SYM_FUNC_START(aes_gcm_dec_update_aesni_avx) ++ _aes_gcm_update 0 ++SYM_FUNC_END(aes_gcm_dec_update_aesni_avx) ++SYM_FUNC_START(aes_gcm_enc_final_aesni_avx) ++ _aes_gcm_final 1 ++SYM_FUNC_END(aes_gcm_enc_final_aesni_avx) ++SYM_FUNC_START(aes_gcm_dec_final_aesni_avx) ++ _aes_gcm_final 0 ++SYM_FUNC_END(aes_gcm_dec_final_aesni_avx) +diff --git a/arch/x86/crypto/aes-gcm-avx10-x86_64.S b/arch/x86/crypto/aes-gcm-avx10-x86_64.S +new file mode 100644 +index 000000000000..97e0ee515fc5 +--- /dev/null ++++ b/arch/x86/crypto/aes-gcm-avx10-x86_64.S +@@ -0,0 +1,1222 @@ ++/* SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause */ ++// ++// VAES and VPCLMULQDQ optimized AES-GCM for x86_64 ++// ++// Copyright 2024 Google LLC ++// ++// Author: Eric Biggers ++// ++//------------------------------------------------------------------------------ ++// ++// This file is dual-licensed, meaning that you can use it under your choice of ++// either of the following two licenses: ++// ++// Licensed under the Apache License 2.0 (the "License"). You may obtain a copy ++// of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// or ++// ++// Redistribution and use in source and binary forms, with or without ++// modification, are permitted provided that the following conditions are met: ++// ++// 1. 
Redistributions of source code must retain the above copyright notice, ++// this list of conditions and the following disclaimer. ++// ++// 2. Redistributions in binary form must reproduce the above copyright ++// notice, this list of conditions and the following disclaimer in the ++// documentation and/or other materials provided with the distribution. ++// ++// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" ++// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE ++// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ++// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE ++// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR ++// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF ++// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS ++// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN ++// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ++// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE ++// POSSIBILITY OF SUCH DAMAGE. ++// ++//------------------------------------------------------------------------------ ++// ++// This file implements AES-GCM (Galois/Counter Mode) for x86_64 CPUs that ++// support VAES (vector AES), VPCLMULQDQ (vector carryless multiplication), and ++// either AVX512 or AVX10. Some of the functions, notably the encryption and ++// decryption update functions which are the most performance-critical, are ++// provided in two variants generated from a macro: one using 256-bit vectors ++// (suffix: vaes_avx10_256) and one using 512-bit vectors (vaes_avx10_512). The ++// other, "shared" functions (vaes_avx10) use at most 256-bit vectors. ++// ++// The functions that use 512-bit vectors are intended for CPUs that support ++// 512-bit vectors *and* where using them doesn't cause significant ++// downclocking. 
They require the following CPU features: ++// ++// VAES && VPCLMULQDQ && BMI2 && ((AVX512BW && AVX512VL) || AVX10/512) ++// ++// The other functions require the following CPU features: ++// ++// VAES && VPCLMULQDQ && BMI2 && ((AVX512BW && AVX512VL) || AVX10/256) ++// ++// All functions use the "System V" ABI. The Windows ABI is not supported. ++// ++// Note that we use "avx10" in the names of the functions as a shorthand to ++// really mean "AVX10 or a certain set of AVX512 features". Due to Intel's ++// introduction of AVX512 and then its replacement by AVX10, there doesn't seem ++// to be a simple way to name things that makes sense on all CPUs. ++// ++// Note that the macros that support both 256-bit and 512-bit vectors could ++// fairly easily be changed to support 128-bit too. However, this would *not* ++// be sufficient to allow the code to run on CPUs without AVX512 or AVX10, ++// because the code heavily uses several features of these extensions other than ++// the vector length: the increase in the number of SIMD registers from 16 to ++// 32, masking support, and new instructions such as vpternlogd (which can do a ++// three-argument XOR). These features are very useful for AES-GCM. ++ ++#include <linux/linkage.h> ++ ++.section .rodata ++.p2align 6 ++ ++ // A shuffle mask that reflects the bytes of 16-byte blocks ++.Lbswap_mask: ++ .octa 0x000102030405060708090a0b0c0d0e0f ++ ++ // This is the GHASH reducing polynomial without its constant term, i.e. ++ // x^128 + x^7 + x^2 + x, represented using the backwards mapping ++ // between bits and polynomial coefficients. ++ // ++ // Alternatively, it can be interpreted as the naturally-ordered ++ // representation of the polynomial x^127 + x^126 + x^121 + 1, i.e. the ++ // "reversed" GHASH reducing polynomial without its x^128 term. ++.Lgfpoly: ++ .octa 0xc2000000000000000000000000000001 ++ ++ // Same as above, but with the (1 << 64) bit set.
++.Lgfpoly_and_internal_carrybit: ++ .octa 0xc2000000000000010000000000000001 ++ ++ // The below constants are used for incrementing the counter blocks. ++ // ctr_pattern points to the four 128-bit values [0, 1, 2, 3]. ++ // inc_2blocks and inc_4blocks point to the single 128-bit values 2 and ++ // 4. Note that the same '2' is reused in ctr_pattern and inc_2blocks. ++.Lctr_pattern: ++ .octa 0 ++ .octa 1 ++.Linc_2blocks: ++ .octa 2 ++ .octa 3 ++.Linc_4blocks: ++ .octa 4 ++ ++// Number of powers of the hash key stored in the key struct. The powers are ++// stored from highest (H^NUM_H_POWERS) to lowest (H^1). ++#define NUM_H_POWERS 16 ++ ++// Offset to AES key length (in bytes) in the key struct ++#define OFFSETOF_AESKEYLEN 480 ++ ++// Offset to start of hash key powers array in the key struct ++#define OFFSETOF_H_POWERS 512 ++ ++// Offset to end of hash key powers array in the key struct. ++// ++// This is immediately followed by three zeroized padding blocks, which are ++// included so that partial vectors can be handled more easily. E.g. if VL=64 ++// and two blocks remain, we load the 4 values [H^2, H^1, 0, 0]. The most ++// padding blocks needed is 3, which occurs if [H^1, 0, 0, 0] is loaded. ++#define OFFSETOFEND_H_POWERS (OFFSETOF_H_POWERS + (NUM_H_POWERS * 16)) ++ ++.text ++ ++// Set the vector length in bytes. This sets the VL variable and defines ++// register aliases V0-V31 that map to the ymm or zmm registers. ++.macro _set_veclen vl ++ .set VL, \vl ++.irp i, 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15, \ ++ 16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31 ++.if VL == 32 ++ .set V\i, %ymm\i ++.elseif VL == 64 ++ .set V\i, %zmm\i ++.else ++ .error "Unsupported vector length" ++.endif ++.endr ++.endm ++ ++// The _ghash_mul_step macro does one step of GHASH multiplication of the ++// 128-bit lanes of \a by the corresponding 128-bit lanes of \b and storing the ++// reduced products in \dst. 
\t0, \t1, and \t2 are temporary registers of the ++// same size as \a and \b. To complete all steps, this must be invoked with \i=0 ++// through \i=9. The division into steps allows users of this macro to ++// optionally interleave the computation with other instructions. Users of this ++// macro must preserve the parameter registers across steps. ++// ++// The multiplications are done in GHASH's representation of the finite field ++// GF(2^128). Elements of GF(2^128) are represented as binary polynomials ++// (i.e. polynomials whose coefficients are bits) modulo a reducing polynomial ++// G. The GCM specification uses G = x^128 + x^7 + x^2 + x + 1. Addition is ++// just XOR, while multiplication is more complex and has two parts: (a) do ++// carryless multiplication of two 128-bit input polynomials to get a 256-bit ++// intermediate product polynomial, and (b) reduce the intermediate product to ++// 128 bits by adding multiples of G that cancel out terms in it. (Adding ++// multiples of G doesn't change which field element the polynomial represents.) ++// ++// Unfortunately, the GCM specification maps bits to/from polynomial ++// coefficients backwards from the natural order. In each byte it specifies the ++// highest bit to be the lowest order polynomial coefficient, *not* the highest! ++// This makes it nontrivial to work with the GHASH polynomials. We could ++// reflect the bits, but x86 doesn't have an instruction that does that. ++// ++// Instead, we operate on the values without bit-reflecting them. This *mostly* ++// just works, since XOR and carryless multiplication are symmetric with respect ++// to bit order, but it has some consequences. First, due to GHASH's byte ++// order, by skipping bit reflection, *byte* reflection becomes necessary to ++// give the polynomial terms a consistent order.
E.g., considering an N-bit ++// value interpreted using the G = x^128 + x^7 + x^2 + x + 1 convention, bits 0 ++// through N-1 of the byte-reflected value represent the coefficients of x^(N-1) ++// through x^0, whereas bits 0 through N-1 of the non-byte-reflected value ++// represent x^7...x^0, x^15...x^8, ..., x^(N-1)...x^(N-8) which can't be worked ++// with. Fortunately, x86's vpshufb instruction can do byte reflection. ++// ++// Second, forgoing the bit reflection causes an extra multiple of x (still ++// using the G = x^128 + x^7 + x^2 + x + 1 convention) to be introduced by each ++// multiplication. This is because an M-bit by N-bit carryless multiplication ++// really produces a (M+N-1)-bit product, but in practice it's zero-extended to ++// M+N bits. In the G = x^128 + x^7 + x^2 + x + 1 convention, which maps bits ++// to polynomial coefficients backwards, this zero-extension actually changes ++// the product by introducing an extra factor of x. Therefore, users of this ++// macro must ensure that one of the inputs has an extra factor of x^-1, i.e. ++// the multiplicative inverse of x, to cancel out the extra x. ++// ++// Third, the backwards coefficients convention is just confusing to work with, ++// since it makes "low" and "high" in the polynomial math mean the opposite of ++// their normal meaning in computer programming. This can be solved by using an ++// alternative interpretation: the polynomial coefficients are understood to be ++// in the natural order, and the multiplication is actually \a * \b * x^-128 mod ++// x^128 + x^127 + x^126 + x^121 + 1. This doesn't change the inputs, outputs, ++// or the implementation at all; it just changes the mathematical interpretation ++// of what each instruction is doing. Starting from here, we'll use this ++// alternative interpretation, as it's easier to understand the code that way. 
++// ++// Moving on to the implementation, the vpclmulqdq instruction does 64 x 64 => ++// 128-bit carryless multiplication, so we break the 128 x 128 multiplication ++// into parts as follows (the _L and _H suffixes denote low and high 64 bits): ++// ++// LO = a_L * b_L ++// MI = (a_L * b_H) + (a_H * b_L) ++// HI = a_H * b_H ++// ++// The 256-bit product is x^128*HI + x^64*MI + LO. LO, MI, and HI are 128-bit. ++// Note that MI "overlaps" with LO and HI. We don't consolidate MI into LO and ++// HI right away, since the way the reduction works makes that unnecessary. ++// ++// For the reduction, we cancel out the low 128 bits by adding multiples of G = ++// x^128 + x^127 + x^126 + x^121 + 1. This is done by two iterations, each of ++// which cancels out the next lowest 64 bits. Consider a value x^64*A + B, ++// where A and B are 128-bit. Adding B_L*G to that value gives: ++// ++// x^64*A + B + B_L*G ++// = x^64*A + x^64*B_H + B_L + B_L*(x^128 + x^127 + x^126 + x^121 + 1) ++// = x^64*A + x^64*B_H + B_L + x^128*B_L + x^64*B_L*(x^63 + x^62 + x^57) + B_L ++// = x^64*A + x^64*B_H + x^128*B_L + x^64*B_L*(x^63 + x^62 + x^57) + B_L + B_L ++// = x^64*(A + B_H + x^64*B_L + B_L*(x^63 + x^62 + x^57)) ++// ++// So: if we sum A, B with its halves swapped, and the low half of B times x^63 ++// + x^62 + x^57, we get a 128-bit value C where x^64*C is congruent to the ++// original value x^64*A + B. I.e., the low 64 bits got canceled out. ++// ++// We just need to apply this twice: first to fold LO into MI, and second to ++// fold the updated MI into HI. ++// ++// The needed three-argument XORs are done using the vpternlogd instruction with ++// immediate 0x96, since this is faster than two vpxord instructions.
++// ++// A potential optimization, assuming that b is fixed per-key (if a is fixed ++// per-key it would work the other way around), is to use one iteration of the ++// reduction described above to precompute a value c such that x^64*c = b mod G, ++// and then multiply a_L by c (and implicitly by x^64) instead of by b: ++// ++// MI = (a_L * c_L) + (a_H * b_L) ++// HI = (a_L * c_H) + (a_H * b_H) ++// ++// This would eliminate the LO part of the intermediate product, which would ++// eliminate the need to fold LO into MI. This would save two instructions, ++// including a vpclmulqdq. However, we currently don't use this optimization ++// because it would require twice as many per-key precomputed values. ++// ++// Using Karatsuba multiplication instead of "schoolbook" multiplication ++// similarly would save a vpclmulqdq but does not seem to be worth it. ++.macro _ghash_mul_step i, a, b, dst, gfpoly, t0, t1, t2 ++.if \i == 0 ++ vpclmulqdq $0x00, \a, \b, \t0 // LO = a_L * b_L ++ vpclmulqdq $0x01, \a, \b, \t1 // MI_0 = a_L * b_H ++.elseif \i == 1 ++ vpclmulqdq $0x10, \a, \b, \t2 // MI_1 = a_H * b_L ++.elseif \i == 2 ++ vpxord \t2, \t1, \t1 // MI = MI_0 + MI_1 ++.elseif \i == 3 ++ vpclmulqdq $0x01, \t0, \gfpoly, \t2 // LO_L*(x^63 + x^62 + x^57) ++.elseif \i == 4 ++ vpshufd $0x4e, \t0, \t0 // Swap halves of LO ++.elseif \i == 5 ++ vpternlogd $0x96, \t2, \t0, \t1 // Fold LO into MI ++.elseif \i == 6 ++ vpclmulqdq $0x11, \a, \b, \dst // HI = a_H * b_H ++.elseif \i == 7 ++ vpclmulqdq $0x01, \t1, \gfpoly, \t0 // MI_L*(x^63 + x^62 + x^57) ++.elseif \i == 8 ++ vpshufd $0x4e, \t1, \t1 // Swap halves of MI ++.elseif \i == 9 ++ vpternlogd $0x96, \t0, \t1, \dst // Fold MI into HI ++.endif ++.endm ++ ++// GHASH-multiply the 128-bit lanes of \a by the 128-bit lanes of \b and store ++// the reduced products in \dst. See _ghash_mul_step for full explanation. 
++.macro _ghash_mul a, b, dst, gfpoly, t0, t1, t2 ++.irp i, 0,1,2,3,4,5,6,7,8,9 ++ _ghash_mul_step \i, \a, \b, \dst, \gfpoly, \t0, \t1, \t2 ++.endr ++.endm ++ ++// GHASH-multiply the 128-bit lanes of \a by the 128-bit lanes of \b and add the ++// *unreduced* products to \lo, \mi, and \hi. ++.macro _ghash_mul_noreduce a, b, lo, mi, hi, t0, t1, t2, t3 ++ vpclmulqdq $0x00, \a, \b, \t0 // a_L * b_L ++ vpclmulqdq $0x01, \a, \b, \t1 // a_L * b_H ++ vpclmulqdq $0x10, \a, \b, \t2 // a_H * b_L ++ vpclmulqdq $0x11, \a, \b, \t3 // a_H * b_H ++ vpxord \t0, \lo, \lo ++ vpternlogd $0x96, \t2, \t1, \mi ++ vpxord \t3, \hi, \hi ++.endm ++ ++// Reduce the unreduced products from \lo, \mi, and \hi and store the 128-bit ++// reduced products in \hi. See _ghash_mul_step for explanation of reduction. ++.macro _ghash_reduce lo, mi, hi, gfpoly, t0 ++ vpclmulqdq $0x01, \lo, \gfpoly, \t0 ++ vpshufd $0x4e, \lo, \lo ++ vpternlogd $0x96, \t0, \lo, \mi ++ vpclmulqdq $0x01, \mi, \gfpoly, \t0 ++ vpshufd $0x4e, \mi, \mi ++ vpternlogd $0x96, \t0, \mi, \hi ++.endm ++ ++// void aes_gcm_precompute_##suffix(struct aes_gcm_key_avx10 *key); ++// ++// Given the expanded AES key |key->aes_key|, this function derives the GHASH ++// subkey and initializes |key->ghash_key_powers| with powers of it. ++// ++// The number of key powers initialized is NUM_H_POWERS, and they are stored in ++// the order H^NUM_H_POWERS to H^1. The zeroized padding blocks after the key ++// powers themselves are also initialized. ++// ++// This macro supports both VL=32 and VL=64. _set_veclen must have been invoked ++// with the desired length. In the VL=32 case, the function computes twice as ++// many key powers as are actually used by the VL=32 GCM update functions. ++// This is done to keep the key format the same regardless of vector length. ++.macro _aes_gcm_precompute ++ ++ // Function arguments ++ .set KEY, %rdi ++ ++ // Additional local variables. V0-V2 and %rax are used as temporaries.
++ .set POWERS_PTR, %rsi ++ .set RNDKEYLAST_PTR, %rdx ++ .set H_CUR, V3 ++ .set H_CUR_YMM, %ymm3 ++ .set H_CUR_XMM, %xmm3 ++ .set H_INC, V4 ++ .set H_INC_YMM, %ymm4 ++ .set H_INC_XMM, %xmm4 ++ .set GFPOLY, V5 ++ .set GFPOLY_YMM, %ymm5 ++ .set GFPOLY_XMM, %xmm5 ++ ++ // Get pointer to lowest set of key powers (located at end of array). ++ lea OFFSETOFEND_H_POWERS-VL(KEY), POWERS_PTR ++ ++ // Encrypt an all-zeroes block to get the raw hash subkey. ++ movl OFFSETOF_AESKEYLEN(KEY), %eax ++ lea 6*16(KEY,%rax,4), RNDKEYLAST_PTR ++ vmovdqu (KEY), %xmm0 // Zero-th round key XOR all-zeroes block ++ add $16, KEY ++1: ++ vaesenc (KEY), %xmm0, %xmm0 ++ add $16, KEY ++ cmp KEY, RNDKEYLAST_PTR ++ jne 1b ++ vaesenclast (RNDKEYLAST_PTR), %xmm0, %xmm0 ++ ++ // Reflect the bytes of the raw hash subkey. ++ vpshufb .Lbswap_mask(%rip), %xmm0, H_CUR_XMM ++ ++ // Zeroize the padding blocks. ++ vpxor %xmm0, %xmm0, %xmm0 ++ vmovdqu %ymm0, VL(POWERS_PTR) ++ vmovdqu %xmm0, VL+2*16(POWERS_PTR) ++ ++ // Finish preprocessing the first key power, H^1. Since this GHASH ++ // implementation operates directly on values with the backwards bit ++ // order specified by the GCM standard, it's necessary to preprocess the ++ // raw key as follows. First, reflect its bytes. Second, multiply it ++ // by x^-1 mod x^128 + x^7 + x^2 + x + 1 (if using the backwards ++ // interpretation of polynomial coefficients), which can also be ++ // interpreted as multiplication by x mod x^128 + x^127 + x^126 + x^121 ++ // + 1 using the alternative, natural interpretation of polynomial ++ // coefficients. For details, see the comment above _ghash_mul_step. ++ // ++ // Either way, for the multiplication the concrete operation performed ++ // is a left shift of the 128-bit value by 1 bit, then an XOR with (0xc2 ++ // << 120) | 1 if a 1 bit was carried out. 
However, there's no 128-bit ++ // wide shift instruction, so instead double each of the two 64-bit ++ // halves and incorporate the internal carry bit into the value XOR'd. ++ vpshufd $0xd3, H_CUR_XMM, %xmm0 ++ vpsrad $31, %xmm0, %xmm0 ++ vpaddq H_CUR_XMM, H_CUR_XMM, H_CUR_XMM ++ vpand .Lgfpoly_and_internal_carrybit(%rip), %xmm0, %xmm0 ++ vpxor %xmm0, H_CUR_XMM, H_CUR_XMM ++ ++ // Load the gfpoly constant. ++ vbroadcasti32x4 .Lgfpoly(%rip), GFPOLY ++ ++ // Square H^1 to get H^2. ++ // ++ // Note that as with H^1, all higher key powers also need an extra ++ // factor of x^-1 (or x using the natural interpretation). Nothing ++ // special needs to be done to make this happen, though: H^1 * H^1 would ++ // end up with two factors of x^-1, but the multiplication consumes one. ++ // So the product H^2 ends up with the desired one factor of x^-1. ++ _ghash_mul H_CUR_XMM, H_CUR_XMM, H_INC_XMM, GFPOLY_XMM, \ ++ %xmm0, %xmm1, %xmm2 ++ ++ // Create H_CUR_YMM = [H^2, H^1] and H_INC_YMM = [H^2, H^2]. ++ vinserti128 $1, H_CUR_XMM, H_INC_YMM, H_CUR_YMM ++ vinserti128 $1, H_INC_XMM, H_INC_YMM, H_INC_YMM ++ ++.if VL == 64 ++ // Create H_CUR = [H^4, H^3, H^2, H^1] and H_INC = [H^4, H^4, H^4, H^4]. ++ _ghash_mul H_INC_YMM, H_CUR_YMM, H_INC_YMM, GFPOLY_YMM, \ ++ %ymm0, %ymm1, %ymm2 ++ vinserti64x4 $1, H_CUR_YMM, H_INC, H_CUR ++ vshufi64x2 $0, H_INC, H_INC, H_INC ++.endif ++ ++ // Store the lowest set of key powers. ++ vmovdqu8 H_CUR, (POWERS_PTR) ++ ++ // Compute and store the remaining key powers. With VL=32, repeatedly ++ // multiply [H^(i+1), H^i] by [H^2, H^2] to get [H^(i+3), H^(i+2)]. ++ // With VL=64, repeatedly multiply [H^(i+3), H^(i+2), H^(i+1), H^i] by ++ // [H^4, H^4, H^4, H^4] to get [H^(i+7), H^(i+6), H^(i+5), H^(i+4)]. 
++ mov $(NUM_H_POWERS*16/VL) - 1, %eax ++.Lprecompute_next\@: ++ sub $VL, POWERS_PTR ++ _ghash_mul H_INC, H_CUR, H_CUR, GFPOLY, V0, V1, V2 ++ vmovdqu8 H_CUR, (POWERS_PTR) ++ dec %eax ++ jnz .Lprecompute_next\@ ++ ++ vzeroupper // This is needed after using ymm or zmm registers. ++ RET ++.endm ++ ++// XOR together the 128-bit lanes of \src (whose low lane is \src_xmm) and store ++// the result in \dst_xmm. This implicitly zeroizes the other lanes of dst. ++.macro _horizontal_xor src, src_xmm, dst_xmm, t0_xmm, t1_xmm, t2_xmm ++ vextracti32x4 $1, \src, \t0_xmm ++.if VL == 32 ++ vpxord \t0_xmm, \src_xmm, \dst_xmm ++.elseif VL == 64 ++ vextracti32x4 $2, \src, \t1_xmm ++ vextracti32x4 $3, \src, \t2_xmm ++ vpxord \t0_xmm, \src_xmm, \dst_xmm ++ vpternlogd $0x96, \t1_xmm, \t2_xmm, \dst_xmm ++.else ++ .error "Unsupported vector length" ++.endif ++.endm ++ ++// Do one step of the GHASH update of the data blocks given in the vector ++// registers GHASHDATA[0-3]. \i specifies the step to do, 0 through 9. The ++// division into steps allows users of this macro to optionally interleave the ++// computation with other instructions. This macro uses the vector register ++// GHASH_ACC as input/output; GHASHDATA[0-3] as inputs that are clobbered; ++// H_POW[4-1], GFPOLY, and BSWAP_MASK as inputs that aren't clobbered; and ++// GHASHTMP[0-2] as temporaries. This macro handles the byte-reflection of the ++// data blocks. The parameter registers must be preserved across steps. ++// ++// The GHASH update does: GHASH_ACC = H_POW4*(GHASHDATA0 + GHASH_ACC) + ++// H_POW3*GHASHDATA1 + H_POW2*GHASHDATA2 + H_POW1*GHASHDATA3, where the ++// operations are vectorized operations on vectors of 16-byte blocks. 
E.g., ++// with VL=32 there are 2 blocks per vector and the vectorized terms correspond ++// to the following non-vectorized terms: ++// ++// H_POW4*(GHASHDATA0 + GHASH_ACC) => H^8*(blk0 + GHASH_ACC_XMM) and H^7*(blk1 + 0) ++// H_POW3*GHASHDATA1 => H^6*blk2 and H^5*blk3 ++// H_POW2*GHASHDATA2 => H^4*blk4 and H^3*blk5 ++// H_POW1*GHASHDATA3 => H^2*blk6 and H^1*blk7 ++// ++// With VL=64, we use 4 blocks/vector, H^16 through H^1, and blk0 through blk15. ++// ++// More concretely, this code does: ++// - Do vectorized "schoolbook" multiplications to compute the intermediate ++// 256-bit product of each block and its corresponding hash key power. ++// There are 4*VL/16 of these intermediate products. ++// - Sum (XOR) the intermediate 256-bit products across vectors. This leaves ++// VL/16 256-bit intermediate values. ++// - Do a vectorized reduction of these 256-bit intermediate values to ++// 128-bits each. This leaves VL/16 128-bit intermediate values. ++// - Sum (XOR) these values and store the 128-bit result in GHASH_ACC_XMM. ++// ++// See _ghash_mul_step for the full explanation of the operations performed for ++// each individual finite field multiplication and reduction. 
++.macro _ghash_step_4x i ++.if \i == 0 ++ vpshufb BSWAP_MASK, GHASHDATA0, GHASHDATA0 ++ vpxord GHASH_ACC, GHASHDATA0, GHASHDATA0 ++ vpshufb BSWAP_MASK, GHASHDATA1, GHASHDATA1 ++ vpshufb BSWAP_MASK, GHASHDATA2, GHASHDATA2 ++.elseif \i == 1 ++ vpshufb BSWAP_MASK, GHASHDATA3, GHASHDATA3 ++ vpclmulqdq $0x00, H_POW4, GHASHDATA0, GHASH_ACC // LO_0 ++ vpclmulqdq $0x00, H_POW3, GHASHDATA1, GHASHTMP0 // LO_1 ++ vpclmulqdq $0x00, H_POW2, GHASHDATA2, GHASHTMP1 // LO_2 ++.elseif \i == 2 ++ vpxord GHASHTMP0, GHASH_ACC, GHASH_ACC // sum(LO_{1,0}) ++ vpclmulqdq $0x00, H_POW1, GHASHDATA3, GHASHTMP2 // LO_3 ++ vpternlogd $0x96, GHASHTMP2, GHASHTMP1, GHASH_ACC // LO = sum(LO_{3,2,1,0}) ++ vpclmulqdq $0x01, H_POW4, GHASHDATA0, GHASHTMP0 // MI_0 ++.elseif \i == 3 ++ vpclmulqdq $0x01, H_POW3, GHASHDATA1, GHASHTMP1 // MI_1 ++ vpclmulqdq $0x01, H_POW2, GHASHDATA2, GHASHTMP2 // MI_2 ++ vpternlogd $0x96, GHASHTMP2, GHASHTMP1, GHASHTMP0 // sum(MI_{2,1,0}) ++ vpclmulqdq $0x01, H_POW1, GHASHDATA3, GHASHTMP1 // MI_3 ++.elseif \i == 4 ++ vpclmulqdq $0x10, H_POW4, GHASHDATA0, GHASHTMP2 // MI_4 ++ vpternlogd $0x96, GHASHTMP2, GHASHTMP1, GHASHTMP0 // sum(MI_{4,3,2,1,0}) ++ vpclmulqdq $0x10, H_POW3, GHASHDATA1, GHASHTMP1 // MI_5 ++ vpclmulqdq $0x10, H_POW2, GHASHDATA2, GHASHTMP2 // MI_6 ++.elseif \i == 5 ++ vpternlogd $0x96, GHASHTMP2, GHASHTMP1, GHASHTMP0 // sum(MI_{6,5,4,3,2,1,0}) ++ vpclmulqdq $0x01, GHASH_ACC, GFPOLY, GHASHTMP2 // LO_L*(x^63 + x^62 + x^57) ++ vpclmulqdq $0x10, H_POW1, GHASHDATA3, GHASHTMP1 // MI_7 ++ vpxord GHASHTMP1, GHASHTMP0, GHASHTMP0 // MI = sum(MI_{7,6,5,4,3,2,1,0}) ++.elseif \i == 6 ++ vpshufd $0x4e, GHASH_ACC, GHASH_ACC // Swap halves of LO ++ vpclmulqdq $0x11, H_POW4, GHASHDATA0, GHASHDATA0 // HI_0 ++ vpclmulqdq $0x11, H_POW3, GHASHDATA1, GHASHDATA1 // HI_1 ++ vpclmulqdq $0x11, H_POW2, GHASHDATA2, GHASHDATA2 // HI_2 ++.elseif \i == 7 ++ vpternlogd $0x96, GHASHTMP2, GHASH_ACC, GHASHTMP0 // Fold LO into MI ++ vpclmulqdq $0x11, H_POW1, GHASHDATA3, GHASHDATA3 // HI_3 ++ 
vpternlogd $0x96, GHASHDATA2, GHASHDATA1, GHASHDATA0 // sum(HI_{2,1,0}) ++ vpclmulqdq $0x01, GHASHTMP0, GFPOLY, GHASHTMP1 // MI_L*(x^63 + x^62 + x^57) ++.elseif \i == 8 ++ vpxord GHASHDATA3, GHASHDATA0, GHASH_ACC // HI = sum(HI_{3,2,1,0}) ++ vpshufd $0x4e, GHASHTMP0, GHASHTMP0 // Swap halves of MI ++ vpternlogd $0x96, GHASHTMP1, GHASHTMP0, GHASH_ACC // Fold MI into HI ++.elseif \i == 9 ++ _horizontal_xor GHASH_ACC, GHASH_ACC_XMM, GHASH_ACC_XMM, \ ++ GHASHDATA0_XMM, GHASHDATA1_XMM, GHASHDATA2_XMM ++.endif ++.endm ++ ++// Do one non-last round of AES encryption on the counter blocks in V0-V3 using ++// the round key that has been broadcast to all 128-bit lanes of \round_key. ++.macro _vaesenc_4x round_key ++ vaesenc \round_key, V0, V0 ++ vaesenc \round_key, V1, V1 ++ vaesenc \round_key, V2, V2 ++ vaesenc \round_key, V3, V3 ++.endm ++ ++// Start the AES encryption of four vectors of counter blocks. ++.macro _ctr_begin_4x ++ ++ // Increment LE_CTR four times to generate four vectors of little-endian ++ // counter blocks, swap each to big-endian, and store them in V0-V3. ++ vpshufb BSWAP_MASK, LE_CTR, V0 ++ vpaddd LE_CTR_INC, LE_CTR, LE_CTR ++ vpshufb BSWAP_MASK, LE_CTR, V1 ++ vpaddd LE_CTR_INC, LE_CTR, LE_CTR ++ vpshufb BSWAP_MASK, LE_CTR, V2 ++ vpaddd LE_CTR_INC, LE_CTR, LE_CTR ++ vpshufb BSWAP_MASK, LE_CTR, V3 ++ vpaddd LE_CTR_INC, LE_CTR, LE_CTR ++ ++ // AES "round zero": XOR in the zero-th round key. ++ vpxord RNDKEY0, V0, V0 ++ vpxord RNDKEY0, V1, V1 ++ vpxord RNDKEY0, V2, V2 ++ vpxord RNDKEY0, V3, V3 ++.endm ++ ++// void aes_gcm_{enc,dec}_update_##suffix(const struct aes_gcm_key_avx10 *key, ++// const u32 le_ctr[4], u8 ghash_acc[16], ++// const u8 *src, u8 *dst, int datalen); ++// ++// This macro generates a GCM encryption or decryption update function with the ++// above prototype (with \enc selecting which one). This macro supports both ++// VL=32 and VL=64. _set_veclen must have been invoked with the desired length. 
++// ++// This function computes the next portion of the CTR keystream, XOR's it with ++// |datalen| bytes from |src|, and writes the resulting encrypted or decrypted ++// data to |dst|. It also updates the GHASH accumulator |ghash_acc| using the ++// next |datalen| ciphertext bytes. ++// ++// |datalen| must be a multiple of 16, except on the last call where it can be ++// any length. The caller must do any buffering needed to ensure this. Both ++// in-place and out-of-place en/decryption are supported. ++// ++// |le_ctr| must give the current counter in little-endian format. For a new ++// message, the low word of the counter must be 2. This function loads the ++// counter from |le_ctr| and increments the loaded counter as needed, but it ++// does *not* store the updated counter back to |le_ctr|. The caller must ++// update |le_ctr| if any more data segments follow. Internally, only the low ++// 32-bit word of the counter is incremented, following the GCM standard. ++.macro _aes_gcm_update enc ++ ++ // Function arguments ++ .set KEY, %rdi ++ .set LE_CTR_PTR, %rsi ++ .set GHASH_ACC_PTR, %rdx ++ .set SRC, %rcx ++ .set DST, %r8 ++ .set DATALEN, %r9d ++ .set DATALEN64, %r9 // Zero-extend DATALEN before using! ++ ++ // Additional local variables ++ ++ // %rax and %k1 are used as temporary registers. LE_CTR_PTR is also ++ // available as a temporary register after the counter is loaded. ++ ++ // AES key length in bytes ++ .set AESKEYLEN, %r10d ++ .set AESKEYLEN64, %r10 ++ ++ // Pointer to the last AES round key for the chosen AES variant ++ .set RNDKEYLAST_PTR, %r11 ++ ++ // In the main loop, V0-V3 are used as AES input and output. Elsewhere ++ // they are used as temporary registers. ++ ++ // GHASHDATA[0-3] hold the ciphertext blocks and GHASH input data. 
++ .set GHASHDATA0, V4 ++ .set GHASHDATA0_XMM, %xmm4 ++ .set GHASHDATA1, V5 ++ .set GHASHDATA1_XMM, %xmm5 ++ .set GHASHDATA2, V6 ++ .set GHASHDATA2_XMM, %xmm6 ++ .set GHASHDATA3, V7 ++ ++ // BSWAP_MASK is the shuffle mask for byte-reflecting 128-bit values ++ // using vpshufb, copied to all 128-bit lanes. ++ .set BSWAP_MASK, V8 ++ ++ // RNDKEY temporarily holds the next AES round key. ++ .set RNDKEY, V9 ++ ++ // GHASH_ACC is the accumulator variable for GHASH. When fully reduced, ++ // only the lowest 128-bit lane can be nonzero. When not fully reduced, ++ // more than one lane may be used, and they need to be XOR'd together. ++ .set GHASH_ACC, V10 ++ .set GHASH_ACC_XMM, %xmm10 ++ ++ // LE_CTR_INC is the vector of 32-bit words that need to be added to a ++ // vector of little-endian counter blocks to advance it forwards. ++ .set LE_CTR_INC, V11 ++ ++ // LE_CTR contains the next set of little-endian counter blocks. ++ .set LE_CTR, V12 ++ ++ // RNDKEY0, RNDKEYLAST, and RNDKEY_M[9-5] contain cached AES round keys, ++ // copied to all 128-bit lanes. RNDKEY0 is the zero-th round key, ++ // RNDKEYLAST the last, and RNDKEY_M\i the one \i-th from the last. ++ .set RNDKEY0, V13 ++ .set RNDKEYLAST, V14 ++ .set RNDKEY_M9, V15 ++ .set RNDKEY_M8, V16 ++ .set RNDKEY_M7, V17 ++ .set RNDKEY_M6, V18 ++ .set RNDKEY_M5, V19 ++ ++ // RNDKEYLAST[0-3] temporarily store the last AES round key XOR'd with ++ // the corresponding block of source data. This is useful because ++ // vaesenclast(key, a) ^ b == vaesenclast(key ^ b, a), and key ^ b can ++ // be computed in parallel with the AES rounds. ++ .set RNDKEYLAST0, V20 ++ .set RNDKEYLAST1, V21 ++ .set RNDKEYLAST2, V22 ++ .set RNDKEYLAST3, V23 ++ ++ // GHASHTMP[0-2] are temporary variables used by _ghash_step_4x. These ++ // cannot coincide with anything used for AES encryption, since for ++ // performance reasons GHASH and AES encryption are interleaved. 
++ .set GHASHTMP0, V24 ++ .set GHASHTMP1, V25 ++ .set GHASHTMP2, V26 ++ ++ // H_POW[4-1] contain the powers of the hash key H^(4*VL/16)...H^1. The ++ // descending numbering reflects the order of the key powers. ++ .set H_POW4, V27 ++ .set H_POW3, V28 ++ .set H_POW2, V29 ++ .set H_POW1, V30 ++ ++ // GFPOLY contains the .Lgfpoly constant, copied to all 128-bit lanes. ++ .set GFPOLY, V31 ++ ++ // Load some constants. ++ vbroadcasti32x4 .Lbswap_mask(%rip), BSWAP_MASK ++ vbroadcasti32x4 .Lgfpoly(%rip), GFPOLY ++ ++ // Load the GHASH accumulator and the starting counter. ++ vmovdqu (GHASH_ACC_PTR), GHASH_ACC_XMM ++ vbroadcasti32x4 (LE_CTR_PTR), LE_CTR ++ ++ // Load the AES key length in bytes. ++ movl OFFSETOF_AESKEYLEN(KEY), AESKEYLEN ++ ++ // Make RNDKEYLAST_PTR point to the last AES round key. This is the ++ // round key with index 10, 12, or 14 for AES-128, AES-192, or AES-256 ++ // respectively. Then load the zero-th and last round keys. ++ lea 6*16(KEY,AESKEYLEN64,4), RNDKEYLAST_PTR ++ vbroadcasti32x4 (KEY), RNDKEY0 ++ vbroadcasti32x4 (RNDKEYLAST_PTR), RNDKEYLAST ++ ++ // Finish initializing LE_CTR by adding [0, 1, ...] to its low words. ++ vpaddd .Lctr_pattern(%rip), LE_CTR, LE_CTR ++ ++ // Initialize LE_CTR_INC to contain VL/16 in all 128-bit lanes. ++.if VL == 32 ++ vbroadcasti32x4 .Linc_2blocks(%rip), LE_CTR_INC ++.elseif VL == 64 ++ vbroadcasti32x4 .Linc_4blocks(%rip), LE_CTR_INC ++.else ++ .error "Unsupported vector length" ++.endif ++ ++ // If there are at least 4*VL bytes of data, then continue into the loop ++ // that processes 4*VL bytes of data at a time. Otherwise skip it. ++ // ++ // Pre-subtracting 4*VL from DATALEN saves an instruction from the main ++ // loop and also ensures that at least one write always occurs to ++ // DATALEN, zero-extending it and allowing DATALEN64 to be used later. ++ sub $4*VL, DATALEN ++ jl .Lcrypt_loop_4x_done\@ ++ ++ // Load powers of the hash key. 
++ vmovdqu8 OFFSETOFEND_H_POWERS-4*VL(KEY), H_POW4 ++ vmovdqu8 OFFSETOFEND_H_POWERS-3*VL(KEY), H_POW3 ++ vmovdqu8 OFFSETOFEND_H_POWERS-2*VL(KEY), H_POW2 ++ vmovdqu8 OFFSETOFEND_H_POWERS-1*VL(KEY), H_POW1 ++ ++ // Main loop: en/decrypt and hash 4 vectors at a time. ++ // ++ // When possible, interleave the AES encryption of the counter blocks ++ // with the GHASH update of the ciphertext blocks. This improves ++ // performance on many CPUs because the execution ports used by the VAES ++ // instructions often differ from those used by vpclmulqdq and other ++ // instructions used in GHASH. For example, many Intel CPUs dispatch ++ // vaesenc to ports 0 and 1 and vpclmulqdq to port 5. ++ // ++ // The interleaving is easiest to do during decryption, since during ++ // decryption the ciphertext blocks are immediately available. For ++ // encryption, instead encrypt the first set of blocks, then hash those ++ // blocks while encrypting the next set of blocks, repeat that as ++ // needed, and finally hash the last set of blocks. ++ ++.if \enc ++ // Encrypt the first 4 vectors of plaintext blocks. Leave the resulting ++ // ciphertext in GHASHDATA[0-3] for GHASH. ++ _ctr_begin_4x ++ lea 16(KEY), %rax ++1: ++ vbroadcasti32x4 (%rax), RNDKEY ++ _vaesenc_4x RNDKEY ++ add $16, %rax ++ cmp %rax, RNDKEYLAST_PTR ++ jne 1b ++ vpxord 0*VL(SRC), RNDKEYLAST, RNDKEYLAST0 ++ vpxord 1*VL(SRC), RNDKEYLAST, RNDKEYLAST1 ++ vpxord 2*VL(SRC), RNDKEYLAST, RNDKEYLAST2 ++ vpxord 3*VL(SRC), RNDKEYLAST, RNDKEYLAST3 ++ vaesenclast RNDKEYLAST0, V0, GHASHDATA0 ++ vaesenclast RNDKEYLAST1, V1, GHASHDATA1 ++ vaesenclast RNDKEYLAST2, V2, GHASHDATA2 ++ vaesenclast RNDKEYLAST3, V3, GHASHDATA3 ++ vmovdqu8 GHASHDATA0, 0*VL(DST) ++ vmovdqu8 GHASHDATA1, 1*VL(DST) ++ vmovdqu8 GHASHDATA2, 2*VL(DST) ++ vmovdqu8 GHASHDATA3, 3*VL(DST) ++ add $4*VL, SRC ++ add $4*VL, DST ++ sub $4*VL, DATALEN ++ jl .Lghash_last_ciphertext_4x\@ ++.endif ++ ++ // Cache as many additional AES round keys as possible. 
++.irp i, 9,8,7,6,5 ++ vbroadcasti32x4 -\i*16(RNDKEYLAST_PTR), RNDKEY_M\i ++.endr ++ ++.Lcrypt_loop_4x\@: ++ ++ // If decrypting, load more ciphertext blocks into GHASHDATA[0-3]. If ++ // encrypting, GHASHDATA[0-3] already contain the previous ciphertext. ++.if !\enc ++ vmovdqu8 0*VL(SRC), GHASHDATA0 ++ vmovdqu8 1*VL(SRC), GHASHDATA1 ++ vmovdqu8 2*VL(SRC), GHASHDATA2 ++ vmovdqu8 3*VL(SRC), GHASHDATA3 ++.endif ++ ++ // Start the AES encryption of the counter blocks. ++ _ctr_begin_4x ++ cmp $24, AESKEYLEN ++ jl 128f // AES-128? ++ je 192f // AES-192? ++ // AES-256 ++ vbroadcasti32x4 -13*16(RNDKEYLAST_PTR), RNDKEY ++ _vaesenc_4x RNDKEY ++ vbroadcasti32x4 -12*16(RNDKEYLAST_PTR), RNDKEY ++ _vaesenc_4x RNDKEY ++192: ++ vbroadcasti32x4 -11*16(RNDKEYLAST_PTR), RNDKEY ++ _vaesenc_4x RNDKEY ++ vbroadcasti32x4 -10*16(RNDKEYLAST_PTR), RNDKEY ++ _vaesenc_4x RNDKEY ++128: ++ ++ // XOR the source data with the last round key, saving the result in ++ // RNDKEYLAST[0-3]. This reduces latency by taking advantage of the ++ // property vaesenclast(key, a) ^ b == vaesenclast(key ^ b, a). ++.if \enc ++ vpxord 0*VL(SRC), RNDKEYLAST, RNDKEYLAST0 ++ vpxord 1*VL(SRC), RNDKEYLAST, RNDKEYLAST1 ++ vpxord 2*VL(SRC), RNDKEYLAST, RNDKEYLAST2 ++ vpxord 3*VL(SRC), RNDKEYLAST, RNDKEYLAST3 ++.else ++ vpxord GHASHDATA0, RNDKEYLAST, RNDKEYLAST0 ++ vpxord GHASHDATA1, RNDKEYLAST, RNDKEYLAST1 ++ vpxord GHASHDATA2, RNDKEYLAST, RNDKEYLAST2 ++ vpxord GHASHDATA3, RNDKEYLAST, RNDKEYLAST3 ++.endif ++ ++ // Finish the AES encryption of the counter blocks in V0-V3, interleaved ++ // with the GHASH update of the ciphertext blocks in GHASHDATA[0-3]. ++.irp i, 9,8,7,6,5 ++ _vaesenc_4x RNDKEY_M\i ++ _ghash_step_4x (9 - \i) ++.endr ++.irp i, 4,3,2,1 ++ vbroadcasti32x4 -\i*16(RNDKEYLAST_PTR), RNDKEY ++ _vaesenc_4x RNDKEY ++ _ghash_step_4x (9 - \i) ++.endr ++ _ghash_step_4x 9 ++ ++ // Do the last AES round. This handles the XOR with the source data ++ // too, as per the optimization described above. 
++ vaesenclast RNDKEYLAST0, V0, GHASHDATA0 ++ vaesenclast RNDKEYLAST1, V1, GHASHDATA1 ++ vaesenclast RNDKEYLAST2, V2, GHASHDATA2 ++ vaesenclast RNDKEYLAST3, V3, GHASHDATA3 ++ ++ // Store the en/decrypted data to DST. ++ vmovdqu8 GHASHDATA0, 0*VL(DST) ++ vmovdqu8 GHASHDATA1, 1*VL(DST) ++ vmovdqu8 GHASHDATA2, 2*VL(DST) ++ vmovdqu8 GHASHDATA3, 3*VL(DST) ++ ++ add $4*VL, SRC ++ add $4*VL, DST ++ sub $4*VL, DATALEN ++ jge .Lcrypt_loop_4x\@ ++ ++.if \enc ++.Lghash_last_ciphertext_4x\@: ++ // Update GHASH with the last set of ciphertext blocks. ++.irp i, 0,1,2,3,4,5,6,7,8,9 ++ _ghash_step_4x \i ++.endr ++.endif ++ ++.Lcrypt_loop_4x_done\@: ++ ++ // Undo the extra subtraction by 4*VL and check whether data remains. ++ add $4*VL, DATALEN ++ jz .Ldone\@ ++ ++ // The data length isn't a multiple of 4*VL. Process the remaining data ++ // of length 1 <= DATALEN < 4*VL, up to one vector (VL bytes) at a time. ++ // Going one vector at a time may seem inefficient compared to having ++ // separate code paths for each possible number of vectors remaining. ++ // However, using a loop keeps the code size down, and it performs ++ // surprisingly well; modern CPUs will start executing the next iteration ++ // before the previous one finishes and also predict the number of loop ++ // iterations. For a similar reason, we roll up the AES rounds. ++ // ++ // On the last iteration, the remaining length may be less than VL. ++ // Handle this using masking. ++ // ++ // Since there are enough key powers available for all remaining data, ++ // there is no need to do a GHASH reduction after each iteration. ++ // Instead, multiply each remaining block by its own key power, and only ++ // do a GHASH reduction at the very end. ++ ++ // Make POWERS_PTR point to the key powers [H^N, H^(N-1), ...] where N ++ // is the number of blocks that remain. ++ .set POWERS_PTR, LE_CTR_PTR // LE_CTR_PTR is free to be reused.
++ mov DATALEN, %eax ++ neg %rax ++ and $~15, %rax // -round_up(DATALEN, 16) ++ lea OFFSETOFEND_H_POWERS(KEY,%rax), POWERS_PTR ++ ++ // Start collecting the unreduced GHASH intermediate value LO, MI, HI. ++ .set LO, GHASHDATA0 ++ .set LO_XMM, GHASHDATA0_XMM ++ .set MI, GHASHDATA1 ++ .set MI_XMM, GHASHDATA1_XMM ++ .set HI, GHASHDATA2 ++ .set HI_XMM, GHASHDATA2_XMM ++ vpxor LO_XMM, LO_XMM, LO_XMM ++ vpxor MI_XMM, MI_XMM, MI_XMM ++ vpxor HI_XMM, HI_XMM, HI_XMM ++ ++.Lcrypt_loop_1x\@: ++ ++ // Select the appropriate mask for this iteration: all 1's if ++ // DATALEN >= VL, otherwise DATALEN 1's. Do this branchlessly using the ++ // bzhi instruction from BMI2. (This relies on DATALEN <= 255.) ++.if VL < 64 ++ mov $-1, %eax ++ bzhi DATALEN, %eax, %eax ++ kmovd %eax, %k1 ++.else ++ mov $-1, %rax ++ bzhi DATALEN64, %rax, %rax ++ kmovq %rax, %k1 ++.endif ++ ++ // Encrypt a vector of counter blocks. This does not need to be masked. ++ vpshufb BSWAP_MASK, LE_CTR, V0 ++ vpaddd LE_CTR_INC, LE_CTR, LE_CTR ++ vpxord RNDKEY0, V0, V0 ++ lea 16(KEY), %rax ++1: ++ vbroadcasti32x4 (%rax), RNDKEY ++ vaesenc RNDKEY, V0, V0 ++ add $16, %rax ++ cmp %rax, RNDKEYLAST_PTR ++ jne 1b ++ vaesenclast RNDKEYLAST, V0, V0 ++ ++ // XOR the data with the appropriate number of keystream bytes. ++ vmovdqu8 (SRC), V1{%k1}{z} ++ vpxord V1, V0, V0 ++ vmovdqu8 V0, (DST){%k1} ++ ++ // Update GHASH with the ciphertext block(s), without reducing. ++ // ++ // In the case of DATALEN < VL, the ciphertext is zero-padded to VL. ++ // (If decrypting, it's done by the above masked load. If encrypting, ++ // it's done by the below masked register-to-register move.) Note that ++ // if DATALEN <= VL - 16, there will be additional padding beyond the ++ // padding of the last block specified by GHASH itself; i.e., there may ++ // be whole block(s) that get processed by the GHASH multiplication and ++ // reduction instructions but should not actually be included in the ++ // GHASH. 
However, any such blocks are all-zeroes, and the values that ++ // they're multiplied with are also all-zeroes. Therefore they just add ++ // 0 * 0 = 0 to the final GHASH result, which makes no difference. ++ vmovdqu8 (POWERS_PTR), H_POW1 ++.if \enc ++ vmovdqu8 V0, V1{%k1}{z} ++.endif ++ vpshufb BSWAP_MASK, V1, V0 ++ vpxord GHASH_ACC, V0, V0 ++ _ghash_mul_noreduce H_POW1, V0, LO, MI, HI, GHASHDATA3, V1, V2, V3 ++ vpxor GHASH_ACC_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM ++ ++ add $VL, POWERS_PTR ++ add $VL, SRC ++ add $VL, DST ++ sub $VL, DATALEN ++ jg .Lcrypt_loop_1x\@ ++ ++ // Finally, do the GHASH reduction. ++ _ghash_reduce LO, MI, HI, GFPOLY, V0 ++ _horizontal_xor HI, HI_XMM, GHASH_ACC_XMM, %xmm0, %xmm1, %xmm2 ++ ++.Ldone\@: ++ // Store the updated GHASH accumulator back to memory. ++ vmovdqu GHASH_ACC_XMM, (GHASH_ACC_PTR) ++ ++ vzeroupper // This is needed after using ymm or zmm registers. ++ RET ++.endm ++ ++// void aes_gcm_enc_final_vaes_avx10(const struct aes_gcm_key_avx10 *key, ++// const u32 le_ctr[4], u8 ghash_acc[16], ++// u64 total_aadlen, u64 total_datalen); ++// bool aes_gcm_dec_final_vaes_avx10(const struct aes_gcm_key_avx10 *key, ++// const u32 le_ctr[4], ++// const u8 ghash_acc[16], ++// u64 total_aadlen, u64 total_datalen, ++// const u8 tag[16], int taglen); ++// ++// This macro generates one of the above two functions (with \enc selecting ++// which one). Both functions finish computing the GCM authentication tag by ++// updating GHASH with the lengths block and encrypting the GHASH accumulator. ++// |total_aadlen| and |total_datalen| must be the total length of the additional ++// authenticated data and the en/decrypted data in bytes, respectively. ++// ++// The encryption function then stores the full-length (16-byte) computed ++// authentication tag to |ghash_acc|. 
The decryption function instead loads the ++// expected authentication tag (the one that was transmitted) from the 16-byte ++// buffer |tag|, compares the first 4 <= |taglen| <= 16 bytes of it to the ++// computed tag in constant time, and returns true if and only if they match. ++.macro _aes_gcm_final enc ++ ++ // Function arguments ++ .set KEY, %rdi ++ .set LE_CTR_PTR, %rsi ++ .set GHASH_ACC_PTR, %rdx ++ .set TOTAL_AADLEN, %rcx ++ .set TOTAL_DATALEN, %r8 ++ .set TAG, %r9 ++ .set TAGLEN, %r10d // 7th argument (dec_final only). Originally at 8(%rsp) ++ ++ // Additional local variables. ++ // %rax, %xmm0-%xmm3, and %k1 are used as temporary registers. ++ .set AESKEYLEN, %r11d ++ .set AESKEYLEN64, %r11 ++ .set GFPOLY, %xmm4 ++ .set BSWAP_MASK, %xmm5 ++ .set LE_CTR, %xmm6 ++ .set GHASH_ACC, %xmm7 ++ .set H_POW1, %xmm8 ++ ++ // Load some constants. ++ vmovdqa .Lgfpoly(%rip), GFPOLY ++ vmovdqa .Lbswap_mask(%rip), BSWAP_MASK ++ ++ // Load the AES key length in bytes. ++ movl OFFSETOF_AESKEYLEN(KEY), AESKEYLEN ++ ++ // Set up a counter block with 1 in the low 32-bit word. This is the ++ // counter that produces the ciphertext needed to encrypt the auth tag. ++ // GFPOLY has 1 in the low word, so grab the 1 from there using a blend. ++ vpblendd $0xe, (LE_CTR_PTR), GFPOLY, LE_CTR // Words 1-3 from le_ctr, word 0 from GFPOLY ++ ++ // Build the lengths block and XOR it with the GHASH accumulator. ++ // Although the lengths block is defined as the AAD length followed by ++ // the en/decrypted data length, both in big-endian byte order, a byte ++ // reflection of the full block is needed because of the way we compute ++ // GHASH (see _ghash_mul_step). By using little-endian values in the ++ // opposite order, we avoid having to reflect any bytes here. ++ vmovq TOTAL_DATALEN, %xmm0 ++ vpinsrq $1, TOTAL_AADLEN, %xmm0, %xmm0 ++ vpsllq $3, %xmm0, %xmm0 // Bytes to bits ++ vpxor (GHASH_ACC_PTR), %xmm0, GHASH_ACC ++ ++ // Load the first hash key power (H^1), which is stored last.
++ vmovdqu8 OFFSETOFEND_H_POWERS-16(KEY), H_POW1 ++ ++.if !\enc ++ // Prepare a mask of TAGLEN one bits, one bit per tag byte to be compared. ++ movl 8(%rsp), TAGLEN ++ mov $-1, %eax ++ bzhi TAGLEN, %eax, %eax ++ kmovd %eax, %k1 ++.endif ++ ++ // Make %rax point to the last AES round key for the chosen AES variant. ++ lea 6*16(KEY,AESKEYLEN64,4), %rax // KEY + 16*(6 + AESKEYLEN/4): round key 10, 12, or 14 ++ ++ // Start the AES encryption of the counter block by swapping the counter ++ // block to big-endian and XOR-ing it with the zero-th AES round key. ++ vpshufb BSWAP_MASK, LE_CTR, %xmm0 ++ vpxor (KEY), %xmm0, %xmm0 ++ ++ // Complete the AES encryption and multiply GHASH_ACC by H^1. ++ // Interleave the AES and GHASH instructions to improve performance. ++ cmp $24, AESKEYLEN ++ jl 128f // AES-128? ++ je 192f // AES-192? ++ // AES-256 ++ vaesenc -13*16(%rax), %xmm0, %xmm0 ++ vaesenc -12*16(%rax), %xmm0, %xmm0 ++192: ++ vaesenc -11*16(%rax), %xmm0, %xmm0 ++ vaesenc -10*16(%rax), %xmm0, %xmm0 ++128: ++.irp i, 0,1,2,3,4,5,6,7,8 ++ _ghash_mul_step \i, H_POW1, GHASH_ACC, GHASH_ACC, GFPOLY, \ ++ %xmm1, %xmm2, %xmm3 ++ vaesenc (\i-9)*16(%rax), %xmm0, %xmm0 ++.endr ++ _ghash_mul_step 9, H_POW1, GHASH_ACC, GHASH_ACC, GFPOLY, \ ++ %xmm1, %xmm2, %xmm3 ++ ++ // Undo the byte reflection of the GHASH accumulator. ++ vpshufb BSWAP_MASK, GHASH_ACC, GHASH_ACC ++ ++ // Do the last AES round and XOR the resulting keystream block with the ++ // GHASH accumulator to produce the full computed authentication tag. ++ // ++ // Reduce latency by taking advantage of the property vaesenclast(key, ++ // a) ^ b == vaesenclast(key ^ b, a). I.e., XOR GHASH_ACC into the last ++ // round key, instead of XOR'ing the final AES output with GHASH_ACC. ++ // ++ // enc_final then returns the computed auth tag, while dec_final ++ // compares it with the transmitted one and returns a bool. To compare ++ // the tags, dec_final XORs them together and uses vptest to check ++ // whether the result is all-zeroes. This should be constant-time.
++ // dec_final applies the vaesenclast optimization to this additional ++ // value XOR'd too, using vpternlogd to XOR the last round key, GHASH ++ // accumulator, and transmitted auth tag together in one instruction. ++.if \enc ++ vpxor (%rax), GHASH_ACC, %xmm1 // Last round key ^ GHASH_ACC ++ vaesenclast %xmm1, %xmm0, GHASH_ACC ++ vmovdqu GHASH_ACC, (GHASH_ACC_PTR) ++.else ++ vmovdqu (TAG), %xmm1 ++ vpternlogd $0x96, (%rax), GHASH_ACC, %xmm1 // 0x96 is the truth table of a three-way XOR ++ vaesenclast %xmm1, %xmm0, %xmm0 ++ xor %eax, %eax // Zero the return value before vptest, since xor overwrites ZF ++ vmovdqu8 %xmm0, %xmm0{%k1}{z} // Truncate to TAGLEN bytes ++ vptest %xmm0, %xmm0 ++ sete %al // Return true iff the first TAGLEN bytes of the tags match ++.endif ++ // No need for vzeroupper here, since only xmm registers were used. ++ RET ++.endm ++ ++_set_veclen 32 ++SYM_FUNC_START(aes_gcm_precompute_vaes_avx10_256) ++ _aes_gcm_precompute ++SYM_FUNC_END(aes_gcm_precompute_vaes_avx10_256) ++SYM_FUNC_START(aes_gcm_enc_update_vaes_avx10_256) ++ _aes_gcm_update 1 ++SYM_FUNC_END(aes_gcm_enc_update_vaes_avx10_256) ++SYM_FUNC_START(aes_gcm_dec_update_vaes_avx10_256) ++ _aes_gcm_update 0 ++SYM_FUNC_END(aes_gcm_dec_update_vaes_avx10_256) ++ ++_set_veclen 64 ++SYM_FUNC_START(aes_gcm_precompute_vaes_avx10_512) ++ _aes_gcm_precompute ++SYM_FUNC_END(aes_gcm_precompute_vaes_avx10_512) ++SYM_FUNC_START(aes_gcm_enc_update_vaes_avx10_512) ++ _aes_gcm_update 1 ++SYM_FUNC_END(aes_gcm_enc_update_vaes_avx10_512) ++SYM_FUNC_START(aes_gcm_dec_update_vaes_avx10_512) ++ _aes_gcm_update 0 ++SYM_FUNC_END(aes_gcm_dec_update_vaes_avx10_512) ++ ++// void aes_gcm_aad_update_vaes_avx10(const struct aes_gcm_key_avx10 *key, ++// u8 ghash_acc[16], ++// const u8 *aad, int aadlen); ++// ++// This function processes the AAD (Additional Authenticated Data) in GCM. ++// Using the key |key|, it updates the GHASH accumulator |ghash_acc| with the ++// data given by |aad| and |aadlen|. |key->ghash_key_powers| must have been ++// initialized. On the first call, |ghash_acc| must be all zeroes. |aadlen| ++// must be a multiple of 16, except on the last call where it can be any length.
++// The caller must do any buffering needed to ensure this. ++// ++// AES-GCM is almost always used with small amounts of AAD, less than 32 bytes. ++// Therefore, for AAD processing we currently only provide this implementation ++// which uses 256-bit vectors (ymm registers) and only has a 1x-wide loop. This ++// keeps the code size down, and it enables some micro-optimizations, e.g. using ++// VEX-coded instructions instead of EVEX-coded to save some instruction bytes. ++// To optimize for large amounts of AAD, we could implement a 4x-wide loop and ++// provide a version using 512-bit vectors, but that doesn't seem to be useful. ++SYM_FUNC_START(aes_gcm_aad_update_vaes_avx10) ++ ++ // Function arguments ++ .set KEY, %rdi ++ .set GHASH_ACC_PTR, %rsi ++ .set AAD, %rdx ++ .set AADLEN, %ecx ++ .set AADLEN64, %rcx // Zero-extend AADLEN before using! ++ ++ // Additional local variables. ++ // %rax, %ymm0-%ymm3, and %k1 are used as temporary registers. ++ .set BSWAP_MASK, %ymm4 ++ .set GFPOLY, %ymm5 ++ .set GHASH_ACC, %ymm6 ++ .set GHASH_ACC_XMM, %xmm6 ++ .set H_POW1, %ymm7 ++ ++ // Load some constants. ++ vbroadcasti128 .Lbswap_mask(%rip), BSWAP_MASK ++ vbroadcasti128 .Lgfpoly(%rip), GFPOLY ++ ++ // Load the GHASH accumulator into the low 128 bits of GHASH_ACC (the VEX-coded xmm load zeroes the upper 128 bits). ++ vmovdqu (GHASH_ACC_PTR), GHASH_ACC_XMM ++ ++ // Update GHASH with 32 bytes of AAD at a time. ++ // ++ // Pre-subtracting 32 from AADLEN saves an instruction from the loop and ++ // also ensures that at least one write always occurs to AADLEN, ++ // zero-extending it and allowing AADLEN64 to be used later.
++ sub $32, AADLEN ++ jl .Laad_loop_1x_done // Fewer than 32 bytes in total: skip the loop ++ vmovdqu8 OFFSETOFEND_H_POWERS-32(KEY), H_POW1 // [H^2, H^1] ++.Laad_loop_1x: ++ vmovdqu (AAD), %ymm0 ++ vpshufb BSWAP_MASK, %ymm0, %ymm0 ++ vpxor %ymm0, GHASH_ACC, GHASH_ACC ++ _ghash_mul H_POW1, GHASH_ACC, GHASH_ACC, GFPOLY, \ ++ %ymm0, %ymm1, %ymm2 ++ vextracti128 $1, GHASH_ACC, %xmm0 // Fold the two 128-bit lanes of the product together ++ vpxor %xmm0, GHASH_ACC_XMM, GHASH_ACC_XMM ++ add $32, AAD ++ sub $32, AADLEN ++ jge .Laad_loop_1x ++.Laad_loop_1x_done: ++ add $32, AADLEN // Undo the pre-subtraction: AADLEN is now the number of bytes left ++ jz .Laad_done ++ ++ // Update GHASH with the remaining 1 <= AADLEN < 32 bytes of AAD. ++ mov $-1, %eax ++ bzhi AADLEN, %eax, %eax ++ kmovd %eax, %k1 ++ vmovdqu8 (AAD), %ymm0{%k1}{z} ++ neg AADLEN64 ++ and $~15, AADLEN64 // -round_up(AADLEN, 16): -16 starts the load at H^1, -32 at [H^2, H^1] ++ vmovdqu8 OFFSETOFEND_H_POWERS(KEY,AADLEN64), H_POW1 ++ vpshufb BSWAP_MASK, %ymm0, %ymm0 ++ vpxor %ymm0, GHASH_ACC, GHASH_ACC ++ _ghash_mul H_POW1, GHASH_ACC, GHASH_ACC, GFPOLY, \ ++ %ymm0, %ymm1, %ymm2 ++ vextracti128 $1, GHASH_ACC, %xmm0 // Fold the two 128-bit lanes of the product together ++ vpxor %xmm0, GHASH_ACC_XMM, GHASH_ACC_XMM ++ ++.Laad_done: ++ // Store the updated GHASH accumulator back to memory. ++ vmovdqu GHASH_ACC_XMM, (GHASH_ACC_PTR) ++ ++ vzeroupper // This is needed after using ymm or zmm registers. ++ RET ++SYM_FUNC_END(aes_gcm_aad_update_vaes_avx10) ++ ++SYM_FUNC_START(aes_gcm_enc_final_vaes_avx10) ++ _aes_gcm_final 1 ++SYM_FUNC_END(aes_gcm_enc_final_vaes_avx10) ++SYM_FUNC_START(aes_gcm_dec_final_vaes_avx10) ++ _aes_gcm_final 0 ++SYM_FUNC_END(aes_gcm_dec_final_vaes_avx10) +diff --git a/arch/x86/crypto/aesni-intel_asm.S b/arch/x86/crypto/aesni-intel_asm.S +index 39066b57a70e..eb153eff9331 100644 +--- a/arch/x86/crypto/aesni-intel_asm.S ++++ b/arch/x86/crypto/aesni-intel_asm.S +@@ -10,16 +10,7 @@ + * Vinodh Gopal + * Kahraman Akdemir + * +- * Added RFC4106 AES-GCM support for 128-bit keys under the AEAD +- * interface for 64-bit kernels.
+- * Authors: Erdinc Ozturk (erdinc.ozturk@intel.com) +- * Aidan O'Mahony (aidan.o.mahony@intel.com) +- * Adrian Hoban +- * James Guilford (james.guilford@intel.com) +- * Gabriele Paoloni +- * Tadeusz Struk (tadeusz.struk@intel.com) +- * Wajdi Feghali (wajdi.k.feghali@intel.com) +- * Copyright (c) 2010, Intel Corporation. ++ * Copyright (c) 2010, Intel Corporation. + * + * Ported x86_64 version to x86: + * Author: Mathias Krause +@@ -27,95 +18,6 @@ + + #include + #include +-#include +- +-/* +- * The following macros are used to move an (un)aligned 16 byte value to/from +- * an XMM register. This can done for either FP or integer values, for FP use +- * movaps (move aligned packed single) or integer use movdqa (move double quad +- * aligned). It doesn't make a performance difference which instruction is used +- * since Nehalem (original Core i7) was released. However, the movaps is a byte +- * shorter, so that is the one we'll use for now. (same for unaligned). +- */ +-#define MOVADQ movaps +-#define MOVUDQ movups +- +-#ifdef __x86_64__ +- +-# constants in mergeable sections, linker can reorder and merge +-.section .rodata.cst16.POLY, "aM", @progbits, 16 +-.align 16 +-POLY: .octa 0xC2000000000000000000000000000001 +-.section .rodata.cst16.TWOONE, "aM", @progbits, 16 +-.align 16 +-TWOONE: .octa 0x00000001000000000000000000000001 +- +-.section .rodata.cst16.SHUF_MASK, "aM", @progbits, 16 +-.align 16 +-SHUF_MASK: .octa 0x000102030405060708090A0B0C0D0E0F +-.section .rodata.cst16.MASK1, "aM", @progbits, 16 +-.align 16 +-MASK1: .octa 0x0000000000000000ffffffffffffffff +-.section .rodata.cst16.MASK2, "aM", @progbits, 16 +-.align 16 +-MASK2: .octa 0xffffffffffffffff0000000000000000 +-.section .rodata.cst16.ONE, "aM", @progbits, 16 +-.align 16 +-ONE: .octa 0x00000000000000000000000000000001 +-.section .rodata.cst16.F_MIN_MASK, "aM", @progbits, 16 +-.align 16 +-F_MIN_MASK: .octa 0xf1f2f3f4f5f6f7f8f9fafbfcfdfeff0 +-.section .rodata.cst16.dec, "aM", @progbits, 16 +-.align 16 
+-dec: .octa 0x1 +-.section .rodata.cst16.enc, "aM", @progbits, 16 +-.align 16 +-enc: .octa 0x2 +- +-# order of these constants should not change. +-# more specifically, ALL_F should follow SHIFT_MASK, +-# and zero should follow ALL_F +-.section .rodata, "a", @progbits +-.align 16 +-SHIFT_MASK: .octa 0x0f0e0d0c0b0a09080706050403020100 +-ALL_F: .octa 0xffffffffffffffffffffffffffffffff +- .octa 0x00000000000000000000000000000000 +- +-.text +- +-#define AadHash 16*0 +-#define AadLen 16*1 +-#define InLen (16*1)+8 +-#define PBlockEncKey 16*2 +-#define OrigIV 16*3 +-#define CurCount 16*4 +-#define PBlockLen 16*5 +-#define HashKey 16*6 // store HashKey <<1 mod poly here +-#define HashKey_2 16*7 // store HashKey^2 <<1 mod poly here +-#define HashKey_3 16*8 // store HashKey^3 <<1 mod poly here +-#define HashKey_4 16*9 // store HashKey^4 <<1 mod poly here +-#define HashKey_k 16*10 // store XOR of High 64 bits and Low 64 +- // bits of HashKey <<1 mod poly here +- //(for Karatsuba purposes) +-#define HashKey_2_k 16*11 // store XOR of High 64 bits and Low 64 +- // bits of HashKey^2 <<1 mod poly here +- // (for Karatsuba purposes) +-#define HashKey_3_k 16*12 // store XOR of High 64 bits and Low 64 +- // bits of HashKey^3 <<1 mod poly here +- // (for Karatsuba purposes) +-#define HashKey_4_k 16*13 // store XOR of High 64 bits and Low 64 +- // bits of HashKey^4 <<1 mod poly here +- // (for Karatsuba purposes) +- +-#define arg1 rdi +-#define arg2 rsi +-#define arg3 rdx +-#define arg4 rcx +-#define arg5 r8 +-#define arg6 r9 +-#define keysize 2*15*16(%arg1) +-#endif +- + + #define STATE1 %xmm0 + #define STATE2 %xmm4 +@@ -162,1409 +64,6 @@ ALL_F: .octa 0xffffffffffffffffffffffffffffffff + #define TKEYP T1 + #endif + +-.macro FUNC_SAVE +- push %r12 +- push %r13 +- push %r14 +-# +-# states of %xmm registers %xmm6:%xmm15 not saved +-# all %xmm registers are clobbered +-# +-.endm +- +- +-.macro FUNC_RESTORE +- pop %r14 +- pop %r13 +- pop %r12 +-.endm +- +-# Precompute hashkeys. 
+-# Input: Hash subkey. +-# Output: HashKeys stored in gcm_context_data. Only needs to be called +-# once per key. +-# clobbers r12, and tmp xmm registers. +-.macro PRECOMPUTE SUBKEY TMP1 TMP2 TMP3 TMP4 TMP5 TMP6 TMP7 +- mov \SUBKEY, %r12 +- movdqu (%r12), \TMP3 +- movdqa SHUF_MASK(%rip), \TMP2 +- pshufb \TMP2, \TMP3 +- +- # precompute HashKey<<1 mod poly from the HashKey (required for GHASH) +- +- movdqa \TMP3, \TMP2 +- psllq $1, \TMP3 +- psrlq $63, \TMP2 +- movdqa \TMP2, \TMP1 +- pslldq $8, \TMP2 +- psrldq $8, \TMP1 +- por \TMP2, \TMP3 +- +- # reduce HashKey<<1 +- +- pshufd $0x24, \TMP1, \TMP2 +- pcmpeqd TWOONE(%rip), \TMP2 +- pand POLY(%rip), \TMP2 +- pxor \TMP2, \TMP3 +- movdqu \TMP3, HashKey(%arg2) +- +- movdqa \TMP3, \TMP5 +- pshufd $78, \TMP3, \TMP1 +- pxor \TMP3, \TMP1 +- movdqu \TMP1, HashKey_k(%arg2) +- +- GHASH_MUL \TMP5, \TMP3, \TMP1, \TMP2, \TMP4, \TMP6, \TMP7 +-# TMP5 = HashKey^2<<1 (mod poly) +- movdqu \TMP5, HashKey_2(%arg2) +-# HashKey_2 = HashKey^2<<1 (mod poly) +- pshufd $78, \TMP5, \TMP1 +- pxor \TMP5, \TMP1 +- movdqu \TMP1, HashKey_2_k(%arg2) +- +- GHASH_MUL \TMP5, \TMP3, \TMP1, \TMP2, \TMP4, \TMP6, \TMP7 +-# TMP5 = HashKey^3<<1 (mod poly) +- movdqu \TMP5, HashKey_3(%arg2) +- pshufd $78, \TMP5, \TMP1 +- pxor \TMP5, \TMP1 +- movdqu \TMP1, HashKey_3_k(%arg2) +- +- GHASH_MUL \TMP5, \TMP3, \TMP1, \TMP2, \TMP4, \TMP6, \TMP7 +-# TMP5 = HashKey^3<<1 (mod poly) +- movdqu \TMP5, HashKey_4(%arg2) +- pshufd $78, \TMP5, \TMP1 +- pxor \TMP5, \TMP1 +- movdqu \TMP1, HashKey_4_k(%arg2) +-.endm +- +-# GCM_INIT initializes a gcm_context struct to prepare for encoding/decoding. 
+-# Clobbers rax, r10-r13 and xmm0-xmm6, %xmm13 +-.macro GCM_INIT Iv SUBKEY AAD AADLEN +- mov \AADLEN, %r11 +- mov %r11, AadLen(%arg2) # ctx_data.aad_length = aad_length +- xor %r11d, %r11d +- mov %r11, InLen(%arg2) # ctx_data.in_length = 0 +- mov %r11, PBlockLen(%arg2) # ctx_data.partial_block_length = 0 +- mov %r11, PBlockEncKey(%arg2) # ctx_data.partial_block_enc_key = 0 +- mov \Iv, %rax +- movdqu (%rax), %xmm0 +- movdqu %xmm0, OrigIV(%arg2) # ctx_data.orig_IV = iv +- +- movdqa SHUF_MASK(%rip), %xmm2 +- pshufb %xmm2, %xmm0 +- movdqu %xmm0, CurCount(%arg2) # ctx_data.current_counter = iv +- +- PRECOMPUTE \SUBKEY, %xmm1, %xmm2, %xmm3, %xmm4, %xmm5, %xmm6, %xmm7 +- movdqu HashKey(%arg2), %xmm13 +- +- CALC_AAD_HASH %xmm13, \AAD, \AADLEN, %xmm0, %xmm1, %xmm2, %xmm3, \ +- %xmm4, %xmm5, %xmm6 +-.endm +- +-# GCM_ENC_DEC Encodes/Decodes given data. Assumes that the passed gcm_context +-# struct has been initialized by GCM_INIT. +-# Requires the input data be at least 1 byte long because of READ_PARTIAL_BLOCK +-# Clobbers rax, r10-r13, and xmm0-xmm15 +-.macro GCM_ENC_DEC operation +- movdqu AadHash(%arg2), %xmm8 +- movdqu HashKey(%arg2), %xmm13 +- add %arg5, InLen(%arg2) +- +- xor %r11d, %r11d # initialise the data pointer offset as zero +- PARTIAL_BLOCK %arg3 %arg4 %arg5 %r11 %xmm8 \operation +- +- sub %r11, %arg5 # sub partial block data used +- mov %arg5, %r13 # save the number of bytes +- +- and $-16, %r13 # %r13 = %r13 - (%r13 mod 16) +- mov %r13, %r12 +- # Encrypt/Decrypt first few blocks +- +- and $(3<<4), %r12 +- jz .L_initial_num_blocks_is_0_\@ +- cmp $(2<<4), %r12 +- jb .L_initial_num_blocks_is_1_\@ +- je .L_initial_num_blocks_is_2_\@ +-.L_initial_num_blocks_is_3_\@: +- INITIAL_BLOCKS_ENC_DEC %xmm9, %xmm10, %xmm13, %xmm11, %xmm12, %xmm0, \ +-%xmm1, %xmm2, %xmm3, %xmm4, %xmm8, %xmm5, %xmm6, 5, 678, \operation +- sub $48, %r13 +- jmp .L_initial_blocks_\@ +-.L_initial_num_blocks_is_2_\@: +- INITIAL_BLOCKS_ENC_DEC %xmm9, %xmm10, %xmm13, %xmm11, %xmm12, %xmm0, \ 
+-%xmm1, %xmm2, %xmm3, %xmm4, %xmm8, %xmm5, %xmm6, 6, 78, \operation +- sub $32, %r13 +- jmp .L_initial_blocks_\@ +-.L_initial_num_blocks_is_1_\@: +- INITIAL_BLOCKS_ENC_DEC %xmm9, %xmm10, %xmm13, %xmm11, %xmm12, %xmm0, \ +-%xmm1, %xmm2, %xmm3, %xmm4, %xmm8, %xmm5, %xmm6, 7, 8, \operation +- sub $16, %r13 +- jmp .L_initial_blocks_\@ +-.L_initial_num_blocks_is_0_\@: +- INITIAL_BLOCKS_ENC_DEC %xmm9, %xmm10, %xmm13, %xmm11, %xmm12, %xmm0, \ +-%xmm1, %xmm2, %xmm3, %xmm4, %xmm8, %xmm5, %xmm6, 8, 0, \operation +-.L_initial_blocks_\@: +- +- # Main loop - Encrypt/Decrypt remaining blocks +- +- test %r13, %r13 +- je .L_zero_cipher_left_\@ +- sub $64, %r13 +- je .L_four_cipher_left_\@ +-.L_crypt_by_4_\@: +- GHASH_4_ENCRYPT_4_PARALLEL_\operation %xmm9, %xmm10, %xmm11, %xmm12, \ +- %xmm13, %xmm14, %xmm0, %xmm1, %xmm2, %xmm3, %xmm4, %xmm5, %xmm6, \ +- %xmm7, %xmm8, enc +- add $64, %r11 +- sub $64, %r13 +- jne .L_crypt_by_4_\@ +-.L_four_cipher_left_\@: +- GHASH_LAST_4 %xmm9, %xmm10, %xmm11, %xmm12, %xmm13, %xmm14, \ +-%xmm15, %xmm1, %xmm2, %xmm3, %xmm4, %xmm8 +-.L_zero_cipher_left_\@: +- movdqu %xmm8, AadHash(%arg2) +- movdqu %xmm0, CurCount(%arg2) +- +- mov %arg5, %r13 +- and $15, %r13 # %r13 = arg5 (mod 16) +- je .L_multiple_of_16_bytes_\@ +- +- mov %r13, PBlockLen(%arg2) +- +- # Handle the last <16 Byte block separately +- paddd ONE(%rip), %xmm0 # INCR CNT to get Yn +- movdqu %xmm0, CurCount(%arg2) +- movdqa SHUF_MASK(%rip), %xmm10 +- pshufb %xmm10, %xmm0 +- +- ENCRYPT_SINGLE_BLOCK %xmm0, %xmm1 # Encrypt(K, Yn) +- movdqu %xmm0, PBlockEncKey(%arg2) +- +- cmp $16, %arg5 +- jge .L_large_enough_update_\@ +- +- lea (%arg4,%r11,1), %r10 +- mov %r13, %r12 +- READ_PARTIAL_BLOCK %r10 %r12 %xmm2 %xmm1 +- jmp .L_data_read_\@ +- +-.L_large_enough_update_\@: +- sub $16, %r11 +- add %r13, %r11 +- +- # receive the last <16 Byte block +- movdqu (%arg4, %r11, 1), %xmm1 +- +- sub %r13, %r11 +- add $16, %r11 +- +- lea SHIFT_MASK+16(%rip), %r12 +- # adjust the shuffle mask pointer to be able to 
shift 16-r13 bytes +- # (r13 is the number of bytes in plaintext mod 16) +- sub %r13, %r12 +- # get the appropriate shuffle mask +- movdqu (%r12), %xmm2 +- # shift right 16-r13 bytes +- pshufb %xmm2, %xmm1 +- +-.L_data_read_\@: +- lea ALL_F+16(%rip), %r12 +- sub %r13, %r12 +- +-.ifc \operation, dec +- movdqa %xmm1, %xmm2 +-.endif +- pxor %xmm1, %xmm0 # XOR Encrypt(K, Yn) +- movdqu (%r12), %xmm1 +- # get the appropriate mask to mask out top 16-r13 bytes of xmm0 +- pand %xmm1, %xmm0 # mask out top 16-r13 bytes of xmm0 +-.ifc \operation, dec +- pand %xmm1, %xmm2 +- movdqa SHUF_MASK(%rip), %xmm10 +- pshufb %xmm10 ,%xmm2 +- +- pxor %xmm2, %xmm8 +-.else +- movdqa SHUF_MASK(%rip), %xmm10 +- pshufb %xmm10,%xmm0 +- +- pxor %xmm0, %xmm8 +-.endif +- +- movdqu %xmm8, AadHash(%arg2) +-.ifc \operation, enc +- # GHASH computation for the last <16 byte block +- movdqa SHUF_MASK(%rip), %xmm10 +- # shuffle xmm0 back to output as ciphertext +- pshufb %xmm10, %xmm0 +-.endif +- +- # Output %r13 bytes +- movq %xmm0, %rax +- cmp $8, %r13 +- jle .L_less_than_8_bytes_left_\@ +- mov %rax, (%arg3 , %r11, 1) +- add $8, %r11 +- psrldq $8, %xmm0 +- movq %xmm0, %rax +- sub $8, %r13 +-.L_less_than_8_bytes_left_\@: +- mov %al, (%arg3, %r11, 1) +- add $1, %r11 +- shr $8, %rax +- sub $1, %r13 +- jne .L_less_than_8_bytes_left_\@ +-.L_multiple_of_16_bytes_\@: +-.endm +- +-# GCM_COMPLETE Finishes update of tag of last partial block +-# Output: Authorization Tag (AUTH_TAG) +-# Clobbers rax, r10-r12, and xmm0, xmm1, xmm5-xmm15 +-.macro GCM_COMPLETE AUTHTAG AUTHTAGLEN +- movdqu AadHash(%arg2), %xmm8 +- movdqu HashKey(%arg2), %xmm13 +- +- mov PBlockLen(%arg2), %r12 +- +- test %r12, %r12 +- je .L_partial_done\@ +- +- GHASH_MUL %xmm8, %xmm13, %xmm9, %xmm10, %xmm11, %xmm5, %xmm6 +- +-.L_partial_done\@: +- mov AadLen(%arg2), %r12 # %r13 = aadLen (number of bytes) +- shl $3, %r12 # convert into number of bits +- movd %r12d, %xmm15 # len(A) in %xmm15 +- mov InLen(%arg2), %r12 +- shl $3, %r12 # len(C) in bits 
(*128) +- movq %r12, %xmm1 +- +- pslldq $8, %xmm15 # %xmm15 = len(A)||0x0000000000000000 +- pxor %xmm1, %xmm15 # %xmm15 = len(A)||len(C) +- pxor %xmm15, %xmm8 +- GHASH_MUL %xmm8, %xmm13, %xmm9, %xmm10, %xmm11, %xmm5, %xmm6 +- # final GHASH computation +- movdqa SHUF_MASK(%rip), %xmm10 +- pshufb %xmm10, %xmm8 +- +- movdqu OrigIV(%arg2), %xmm0 # %xmm0 = Y0 +- ENCRYPT_SINGLE_BLOCK %xmm0, %xmm1 # E(K, Y0) +- pxor %xmm8, %xmm0 +-.L_return_T_\@: +- mov \AUTHTAG, %r10 # %r10 = authTag +- mov \AUTHTAGLEN, %r11 # %r11 = auth_tag_len +- cmp $16, %r11 +- je .L_T_16_\@ +- cmp $8, %r11 +- jl .L_T_4_\@ +-.L_T_8_\@: +- movq %xmm0, %rax +- mov %rax, (%r10) +- add $8, %r10 +- sub $8, %r11 +- psrldq $8, %xmm0 +- test %r11, %r11 +- je .L_return_T_done_\@ +-.L_T_4_\@: +- movd %xmm0, %eax +- mov %eax, (%r10) +- add $4, %r10 +- sub $4, %r11 +- psrldq $4, %xmm0 +- test %r11, %r11 +- je .L_return_T_done_\@ +-.L_T_123_\@: +- movd %xmm0, %eax +- cmp $2, %r11 +- jl .L_T_1_\@ +- mov %ax, (%r10) +- cmp $2, %r11 +- je .L_return_T_done_\@ +- add $2, %r10 +- sar $16, %eax +-.L_T_1_\@: +- mov %al, (%r10) +- jmp .L_return_T_done_\@ +-.L_T_16_\@: +- movdqu %xmm0, (%r10) +-.L_return_T_done_\@: +-.endm +- +-#ifdef __x86_64__ +-/* GHASH_MUL MACRO to implement: Data*HashKey mod (128,127,126,121,0) +-* +-* +-* Input: A and B (128-bits each, bit-reflected) +-* Output: C = A*B*x mod poly, (i.e. >>1 ) +-* To compute GH = GH*HashKey mod poly, give HK = HashKey<<1 mod poly as input +-* GH = GH * HK * x mod poly which is equivalent to GH*HashKey mod poly. 
+-* +-*/ +-.macro GHASH_MUL GH HK TMP1 TMP2 TMP3 TMP4 TMP5 +- movdqa \GH, \TMP1 +- pshufd $78, \GH, \TMP2 +- pshufd $78, \HK, \TMP3 +- pxor \GH, \TMP2 # TMP2 = a1+a0 +- pxor \HK, \TMP3 # TMP3 = b1+b0 +- pclmulqdq $0x11, \HK, \TMP1 # TMP1 = a1*b1 +- pclmulqdq $0x00, \HK, \GH # GH = a0*b0 +- pclmulqdq $0x00, \TMP3, \TMP2 # TMP2 = (a0+a1)*(b1+b0) +- pxor \GH, \TMP2 +- pxor \TMP1, \TMP2 # TMP2 = (a0*b0)+(a1*b0) +- movdqa \TMP2, \TMP3 +- pslldq $8, \TMP3 # left shift TMP3 2 DWs +- psrldq $8, \TMP2 # right shift TMP2 2 DWs +- pxor \TMP3, \GH +- pxor \TMP2, \TMP1 # TMP2:GH holds the result of GH*HK +- +- # first phase of the reduction +- +- movdqa \GH, \TMP2 +- movdqa \GH, \TMP3 +- movdqa \GH, \TMP4 # copy GH into TMP2,TMP3 and TMP4 +- # in in order to perform +- # independent shifts +- pslld $31, \TMP2 # packed right shift <<31 +- pslld $30, \TMP3 # packed right shift <<30 +- pslld $25, \TMP4 # packed right shift <<25 +- pxor \TMP3, \TMP2 # xor the shifted versions +- pxor \TMP4, \TMP2 +- movdqa \TMP2, \TMP5 +- psrldq $4, \TMP5 # right shift TMP5 1 DW +- pslldq $12, \TMP2 # left shift TMP2 3 DWs +- pxor \TMP2, \GH +- +- # second phase of the reduction +- +- movdqa \GH,\TMP2 # copy GH into TMP2,TMP3 and TMP4 +- # in in order to perform +- # independent shifts +- movdqa \GH,\TMP3 +- movdqa \GH,\TMP4 +- psrld $1,\TMP2 # packed left shift >>1 +- psrld $2,\TMP3 # packed left shift >>2 +- psrld $7,\TMP4 # packed left shift >>7 +- pxor \TMP3,\TMP2 # xor the shifted versions +- pxor \TMP4,\TMP2 +- pxor \TMP5, \TMP2 +- pxor \TMP2, \GH +- pxor \TMP1, \GH # result is in TMP1 +-.endm +- +-# Reads DLEN bytes starting at DPTR and stores in XMMDst +-# where 0 < DLEN < 16 +-# Clobbers %rax, DLEN and XMM1 +-.macro READ_PARTIAL_BLOCK DPTR DLEN XMM1 XMMDst +- cmp $8, \DLEN +- jl .L_read_lt8_\@ +- mov (\DPTR), %rax +- movq %rax, \XMMDst +- sub $8, \DLEN +- jz .L_done_read_partial_block_\@ +- xor %eax, %eax +-.L_read_next_byte_\@: +- shl $8, %rax +- mov 7(\DPTR, \DLEN, 1), %al +- dec \DLEN 
+- jnz .L_read_next_byte_\@ +- movq %rax, \XMM1 +- pslldq $8, \XMM1 +- por \XMM1, \XMMDst +- jmp .L_done_read_partial_block_\@ +-.L_read_lt8_\@: +- xor %eax, %eax +-.L_read_next_byte_lt8_\@: +- shl $8, %rax +- mov -1(\DPTR, \DLEN, 1), %al +- dec \DLEN +- jnz .L_read_next_byte_lt8_\@ +- movq %rax, \XMMDst +-.L_done_read_partial_block_\@: +-.endm +- +-# CALC_AAD_HASH: Calculates the hash of the data which will not be encrypted. +-# clobbers r10-11, xmm14 +-.macro CALC_AAD_HASH HASHKEY AAD AADLEN TMP1 TMP2 TMP3 TMP4 TMP5 \ +- TMP6 TMP7 +- MOVADQ SHUF_MASK(%rip), %xmm14 +- mov \AAD, %r10 # %r10 = AAD +- mov \AADLEN, %r11 # %r11 = aadLen +- pxor \TMP7, \TMP7 +- pxor \TMP6, \TMP6 +- +- cmp $16, %r11 +- jl .L_get_AAD_rest\@ +-.L_get_AAD_blocks\@: +- movdqu (%r10), \TMP7 +- pshufb %xmm14, \TMP7 # byte-reflect the AAD data +- pxor \TMP7, \TMP6 +- GHASH_MUL \TMP6, \HASHKEY, \TMP1, \TMP2, \TMP3, \TMP4, \TMP5 +- add $16, %r10 +- sub $16, %r11 +- cmp $16, %r11 +- jge .L_get_AAD_blocks\@ +- +- movdqu \TMP6, \TMP7 +- +- /* read the last <16B of AAD */ +-.L_get_AAD_rest\@: +- test %r11, %r11 +- je .L_get_AAD_done\@ +- +- READ_PARTIAL_BLOCK %r10, %r11, \TMP1, \TMP7 +- pshufb %xmm14, \TMP7 # byte-reflect the AAD data +- pxor \TMP6, \TMP7 +- GHASH_MUL \TMP7, \HASHKEY, \TMP1, \TMP2, \TMP3, \TMP4, \TMP5 +- movdqu \TMP7, \TMP6 +- +-.L_get_AAD_done\@: +- movdqu \TMP6, AadHash(%arg2) +-.endm +- +-# PARTIAL_BLOCK: Handles encryption/decryption and the tag partial blocks +-# between update calls. 
+-# Requires the input data be at least 1 byte long due to READ_PARTIAL_BLOCK +-# Outputs encrypted bytes, and updates hash and partial info in gcm_data_context +-# Clobbers rax, r10, r12, r13, xmm0-6, xmm9-13 +-.macro PARTIAL_BLOCK CYPH_PLAIN_OUT PLAIN_CYPH_IN PLAIN_CYPH_LEN DATA_OFFSET \ +- AAD_HASH operation +- mov PBlockLen(%arg2), %r13 +- test %r13, %r13 +- je .L_partial_block_done_\@ # Leave Macro if no partial blocks +- # Read in input data without over reading +- cmp $16, \PLAIN_CYPH_LEN +- jl .L_fewer_than_16_bytes_\@ +- movups (\PLAIN_CYPH_IN), %xmm1 # If more than 16 bytes, just fill xmm +- jmp .L_data_read_\@ +- +-.L_fewer_than_16_bytes_\@: +- lea (\PLAIN_CYPH_IN, \DATA_OFFSET, 1), %r10 +- mov \PLAIN_CYPH_LEN, %r12 +- READ_PARTIAL_BLOCK %r10 %r12 %xmm0 %xmm1 +- +- mov PBlockLen(%arg2), %r13 +- +-.L_data_read_\@: # Finished reading in data +- +- movdqu PBlockEncKey(%arg2), %xmm9 +- movdqu HashKey(%arg2), %xmm13 +- +- lea SHIFT_MASK(%rip), %r12 +- +- # adjust the shuffle mask pointer to be able to shift r13 bytes +- # r16-r13 is the number of bytes in plaintext mod 16) +- add %r13, %r12 +- movdqu (%r12), %xmm2 # get the appropriate shuffle mask +- pshufb %xmm2, %xmm9 # shift right r13 bytes +- +-.ifc \operation, dec +- movdqa %xmm1, %xmm3 +- pxor %xmm1, %xmm9 # Ciphertext XOR E(K, Yn) +- +- mov \PLAIN_CYPH_LEN, %r10 +- add %r13, %r10 +- # Set r10 to be the amount of data left in CYPH_PLAIN_IN after filling +- sub $16, %r10 +- # Determine if partial block is not being filled and +- # shift mask accordingly +- jge .L_no_extra_mask_1_\@ +- sub %r10, %r12 +-.L_no_extra_mask_1_\@: +- +- movdqu ALL_F-SHIFT_MASK(%r12), %xmm1 +- # get the appropriate mask to mask out bottom r13 bytes of xmm9 +- pand %xmm1, %xmm9 # mask out bottom r13 bytes of xmm9 +- +- pand %xmm1, %xmm3 +- movdqa SHUF_MASK(%rip), %xmm10 +- pshufb %xmm10, %xmm3 +- pshufb %xmm2, %xmm3 +- pxor %xmm3, \AAD_HASH +- +- test %r10, %r10 +- jl .L_partial_incomplete_1_\@ +- +- # GHASH computation for the 
last <16 Byte block +- GHASH_MUL \AAD_HASH, %xmm13, %xmm0, %xmm10, %xmm11, %xmm5, %xmm6 +- xor %eax, %eax +- +- mov %rax, PBlockLen(%arg2) +- jmp .L_dec_done_\@ +-.L_partial_incomplete_1_\@: +- add \PLAIN_CYPH_LEN, PBlockLen(%arg2) +-.L_dec_done_\@: +- movdqu \AAD_HASH, AadHash(%arg2) +-.else +- pxor %xmm1, %xmm9 # Plaintext XOR E(K, Yn) +- +- mov \PLAIN_CYPH_LEN, %r10 +- add %r13, %r10 +- # Set r10 to be the amount of data left in CYPH_PLAIN_IN after filling +- sub $16, %r10 +- # Determine if partial block is not being filled and +- # shift mask accordingly +- jge .L_no_extra_mask_2_\@ +- sub %r10, %r12 +-.L_no_extra_mask_2_\@: +- +- movdqu ALL_F-SHIFT_MASK(%r12), %xmm1 +- # get the appropriate mask to mask out bottom r13 bytes of xmm9 +- pand %xmm1, %xmm9 +- +- movdqa SHUF_MASK(%rip), %xmm1 +- pshufb %xmm1, %xmm9 +- pshufb %xmm2, %xmm9 +- pxor %xmm9, \AAD_HASH +- +- test %r10, %r10 +- jl .L_partial_incomplete_2_\@ +- +- # GHASH computation for the last <16 Byte block +- GHASH_MUL \AAD_HASH, %xmm13, %xmm0, %xmm10, %xmm11, %xmm5, %xmm6 +- xor %eax, %eax +- +- mov %rax, PBlockLen(%arg2) +- jmp .L_encode_done_\@ +-.L_partial_incomplete_2_\@: +- add \PLAIN_CYPH_LEN, PBlockLen(%arg2) +-.L_encode_done_\@: +- movdqu \AAD_HASH, AadHash(%arg2) +- +- movdqa SHUF_MASK(%rip), %xmm10 +- # shuffle xmm9 back to output as ciphertext +- pshufb %xmm10, %xmm9 +- pshufb %xmm2, %xmm9 +-.endif +- # output encrypted Bytes +- test %r10, %r10 +- jl .L_partial_fill_\@ +- mov %r13, %r12 +- mov $16, %r13 +- # Set r13 to be the number of bytes to write out +- sub %r12, %r13 +- jmp .L_count_set_\@ +-.L_partial_fill_\@: +- mov \PLAIN_CYPH_LEN, %r13 +-.L_count_set_\@: +- movdqa %xmm9, %xmm0 +- movq %xmm0, %rax +- cmp $8, %r13 +- jle .L_less_than_8_bytes_left_\@ +- +- mov %rax, (\CYPH_PLAIN_OUT, \DATA_OFFSET, 1) +- add $8, \DATA_OFFSET +- psrldq $8, %xmm0 +- movq %xmm0, %rax +- sub $8, %r13 +-.L_less_than_8_bytes_left_\@: +- movb %al, (\CYPH_PLAIN_OUT, \DATA_OFFSET, 1) +- add $1, \DATA_OFFSET +- 
shr $8, %rax +- sub $1, %r13 +- jne .L_less_than_8_bytes_left_\@ +-.L_partial_block_done_\@: +-.endm # PARTIAL_BLOCK +- +-/* +-* if a = number of total plaintext bytes +-* b = floor(a/16) +-* num_initial_blocks = b mod 4 +-* encrypt the initial num_initial_blocks blocks and apply ghash on +-* the ciphertext +-* %r10, %r11, %r12, %rax, %xmm5, %xmm6, %xmm7, %xmm8, %xmm9 registers +-* are clobbered +-* arg1, %arg2, %arg3 are used as a pointer only, not modified +-*/ +- +- +-.macro INITIAL_BLOCKS_ENC_DEC TMP1 TMP2 TMP3 TMP4 TMP5 XMM0 XMM1 \ +- XMM2 XMM3 XMM4 XMMDst TMP6 TMP7 i i_seq operation +- MOVADQ SHUF_MASK(%rip), %xmm14 +- +- movdqu AadHash(%arg2), %xmm\i # XMM0 = Y0 +- +- # start AES for num_initial_blocks blocks +- +- movdqu CurCount(%arg2), \XMM0 # XMM0 = Y0 +- +-.if (\i == 5) || (\i == 6) || (\i == 7) +- +- MOVADQ ONE(%RIP),\TMP1 +- MOVADQ 0(%arg1),\TMP2 +-.irpc index, \i_seq +- paddd \TMP1, \XMM0 # INCR Y0 +-.ifc \operation, dec +- movdqa \XMM0, %xmm\index +-.else +- MOVADQ \XMM0, %xmm\index +-.endif +- pshufb %xmm14, %xmm\index # perform a 16 byte swap +- pxor \TMP2, %xmm\index +-.endr +- lea 0x10(%arg1),%r10 +- mov keysize,%eax +- shr $2,%eax # 128->4, 192->6, 256->8 +- add $5,%eax # 128->9, 192->11, 256->13 +- +-.Laes_loop_initial_\@: +- MOVADQ (%r10),\TMP1 +-.irpc index, \i_seq +- aesenc \TMP1, %xmm\index +-.endr +- add $16,%r10 +- sub $1,%eax +- jnz .Laes_loop_initial_\@ +- +- MOVADQ (%r10), \TMP1 +-.irpc index, \i_seq +- aesenclast \TMP1, %xmm\index # Last Round +-.endr +-.irpc index, \i_seq +- movdqu (%arg4 , %r11, 1), \TMP1 +- pxor \TMP1, %xmm\index +- movdqu %xmm\index, (%arg3 , %r11, 1) +- # write back plaintext/ciphertext for num_initial_blocks +- add $16, %r11 +- +-.ifc \operation, dec +- movdqa \TMP1, %xmm\index +-.endif +- pshufb %xmm14, %xmm\index +- +- # prepare plaintext/ciphertext for GHASH computation +-.endr +-.endif +- +- # apply GHASH on num_initial_blocks blocks +- +-.if \i == 5 +- pxor %xmm5, %xmm6 +- GHASH_MUL %xmm6, \TMP3, \TMP1, 
\TMP2, \TMP4, \TMP5, \XMM1 +- pxor %xmm6, %xmm7 +- GHASH_MUL %xmm7, \TMP3, \TMP1, \TMP2, \TMP4, \TMP5, \XMM1 +- pxor %xmm7, %xmm8 +- GHASH_MUL %xmm8, \TMP3, \TMP1, \TMP2, \TMP4, \TMP5, \XMM1 +-.elseif \i == 6 +- pxor %xmm6, %xmm7 +- GHASH_MUL %xmm7, \TMP3, \TMP1, \TMP2, \TMP4, \TMP5, \XMM1 +- pxor %xmm7, %xmm8 +- GHASH_MUL %xmm8, \TMP3, \TMP1, \TMP2, \TMP4, \TMP5, \XMM1 +-.elseif \i == 7 +- pxor %xmm7, %xmm8 +- GHASH_MUL %xmm8, \TMP3, \TMP1, \TMP2, \TMP4, \TMP5, \XMM1 +-.endif +- cmp $64, %r13 +- jl .L_initial_blocks_done\@ +- # no need for precomputed values +-/* +-* +-* Precomputations for HashKey parallel with encryption of first 4 blocks. +-* Haskey_i_k holds XORed values of the low and high parts of the Haskey_i +-*/ +- MOVADQ ONE(%RIP),\TMP1 +- paddd \TMP1, \XMM0 # INCR Y0 +- MOVADQ \XMM0, \XMM1 +- pshufb %xmm14, \XMM1 # perform a 16 byte swap +- +- paddd \TMP1, \XMM0 # INCR Y0 +- MOVADQ \XMM0, \XMM2 +- pshufb %xmm14, \XMM2 # perform a 16 byte swap +- +- paddd \TMP1, \XMM0 # INCR Y0 +- MOVADQ \XMM0, \XMM3 +- pshufb %xmm14, \XMM3 # perform a 16 byte swap +- +- paddd \TMP1, \XMM0 # INCR Y0 +- MOVADQ \XMM0, \XMM4 +- pshufb %xmm14, \XMM4 # perform a 16 byte swap +- +- MOVADQ 0(%arg1),\TMP1 +- pxor \TMP1, \XMM1 +- pxor \TMP1, \XMM2 +- pxor \TMP1, \XMM3 +- pxor \TMP1, \XMM4 +-.irpc index, 1234 # do 4 rounds +- movaps 0x10*\index(%arg1), \TMP1 +- aesenc \TMP1, \XMM1 +- aesenc \TMP1, \XMM2 +- aesenc \TMP1, \XMM3 +- aesenc \TMP1, \XMM4 +-.endr +-.irpc index, 56789 # do next 5 rounds +- movaps 0x10*\index(%arg1), \TMP1 +- aesenc \TMP1, \XMM1 +- aesenc \TMP1, \XMM2 +- aesenc \TMP1, \XMM3 +- aesenc \TMP1, \XMM4 +-.endr +- lea 0xa0(%arg1),%r10 +- mov keysize,%eax +- shr $2,%eax # 128->4, 192->6, 256->8 +- sub $4,%eax # 128->0, 192->2, 256->4 +- jz .Laes_loop_pre_done\@ +- +-.Laes_loop_pre_\@: +- MOVADQ (%r10),\TMP2 +-.irpc index, 1234 +- aesenc \TMP2, %xmm\index +-.endr +- add $16,%r10 +- sub $1,%eax +- jnz .Laes_loop_pre_\@ +- +-.Laes_loop_pre_done\@: +- MOVADQ (%r10), 
\TMP2 +- aesenclast \TMP2, \XMM1 +- aesenclast \TMP2, \XMM2 +- aesenclast \TMP2, \XMM3 +- aesenclast \TMP2, \XMM4 +- movdqu 16*0(%arg4 , %r11 , 1), \TMP1 +- pxor \TMP1, \XMM1 +-.ifc \operation, dec +- movdqu \XMM1, 16*0(%arg3 , %r11 , 1) +- movdqa \TMP1, \XMM1 +-.endif +- movdqu 16*1(%arg4 , %r11 , 1), \TMP1 +- pxor \TMP1, \XMM2 +-.ifc \operation, dec +- movdqu \XMM2, 16*1(%arg3 , %r11 , 1) +- movdqa \TMP1, \XMM2 +-.endif +- movdqu 16*2(%arg4 , %r11 , 1), \TMP1 +- pxor \TMP1, \XMM3 +-.ifc \operation, dec +- movdqu \XMM3, 16*2(%arg3 , %r11 , 1) +- movdqa \TMP1, \XMM3 +-.endif +- movdqu 16*3(%arg4 , %r11 , 1), \TMP1 +- pxor \TMP1, \XMM4 +-.ifc \operation, dec +- movdqu \XMM4, 16*3(%arg3 , %r11 , 1) +- movdqa \TMP1, \XMM4 +-.else +- movdqu \XMM1, 16*0(%arg3 , %r11 , 1) +- movdqu \XMM2, 16*1(%arg3 , %r11 , 1) +- movdqu \XMM3, 16*2(%arg3 , %r11 , 1) +- movdqu \XMM4, 16*3(%arg3 , %r11 , 1) +-.endif +- +- add $64, %r11 +- pshufb %xmm14, \XMM1 # perform a 16 byte swap +- pxor \XMMDst, \XMM1 +-# combine GHASHed value with the corresponding ciphertext +- pshufb %xmm14, \XMM2 # perform a 16 byte swap +- pshufb %xmm14, \XMM3 # perform a 16 byte swap +- pshufb %xmm14, \XMM4 # perform a 16 byte swap +- +-.L_initial_blocks_done\@: +- +-.endm +- +-/* +-* encrypt 4 blocks at a time +-* ghash the 4 previously encrypted ciphertext blocks +-* arg1, %arg3, %arg4 are used as pointers only, not modified +-* %r11 is the data offset value +-*/ +-.macro GHASH_4_ENCRYPT_4_PARALLEL_enc TMP1 TMP2 TMP3 TMP4 TMP5 \ +-TMP6 XMM0 XMM1 XMM2 XMM3 XMM4 XMM5 XMM6 XMM7 XMM8 operation +- +- movdqa \XMM1, \XMM5 +- movdqa \XMM2, \XMM6 +- movdqa \XMM3, \XMM7 +- movdqa \XMM4, \XMM8 +- +- movdqa SHUF_MASK(%rip), %xmm15 +- # multiply TMP5 * HashKey using karatsuba +- +- movdqa \XMM5, \TMP4 +- pshufd $78, \XMM5, \TMP6 +- pxor \XMM5, \TMP6 +- paddd ONE(%rip), \XMM0 # INCR CNT +- movdqu HashKey_4(%arg2), \TMP5 +- pclmulqdq $0x11, \TMP5, \TMP4 # TMP4 = a1*b1 +- movdqa \XMM0, \XMM1 +- paddd ONE(%rip), \XMM0 # INCR 
CNT +- movdqa \XMM0, \XMM2 +- paddd ONE(%rip), \XMM0 # INCR CNT +- movdqa \XMM0, \XMM3 +- paddd ONE(%rip), \XMM0 # INCR CNT +- movdqa \XMM0, \XMM4 +- pshufb %xmm15, \XMM1 # perform a 16 byte swap +- pclmulqdq $0x00, \TMP5, \XMM5 # XMM5 = a0*b0 +- pshufb %xmm15, \XMM2 # perform a 16 byte swap +- pshufb %xmm15, \XMM3 # perform a 16 byte swap +- pshufb %xmm15, \XMM4 # perform a 16 byte swap +- +- pxor (%arg1), \XMM1 +- pxor (%arg1), \XMM2 +- pxor (%arg1), \XMM3 +- pxor (%arg1), \XMM4 +- movdqu HashKey_4_k(%arg2), \TMP5 +- pclmulqdq $0x00, \TMP5, \TMP6 # TMP6 = (a1+a0)*(b1+b0) +- movaps 0x10(%arg1), \TMP1 +- aesenc \TMP1, \XMM1 # Round 1 +- aesenc \TMP1, \XMM2 +- aesenc \TMP1, \XMM3 +- aesenc \TMP1, \XMM4 +- movaps 0x20(%arg1), \TMP1 +- aesenc \TMP1, \XMM1 # Round 2 +- aesenc \TMP1, \XMM2 +- aesenc \TMP1, \XMM3 +- aesenc \TMP1, \XMM4 +- movdqa \XMM6, \TMP1 +- pshufd $78, \XMM6, \TMP2 +- pxor \XMM6, \TMP2 +- movdqu HashKey_3(%arg2), \TMP5 +- pclmulqdq $0x11, \TMP5, \TMP1 # TMP1 = a1 * b1 +- movaps 0x30(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 3 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 +- pclmulqdq $0x00, \TMP5, \XMM6 # XMM6 = a0*b0 +- movaps 0x40(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 4 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 +- movdqu HashKey_3_k(%arg2), \TMP5 +- pclmulqdq $0x00, \TMP5, \TMP2 # TMP2 = (a1+a0)*(b1+b0) +- movaps 0x50(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 5 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 +- pxor \TMP1, \TMP4 +-# accumulate the results in TMP4:XMM5, TMP6 holds the middle part +- pxor \XMM6, \XMM5 +- pxor \TMP2, \TMP6 +- movdqa \XMM7, \TMP1 +- pshufd $78, \XMM7, \TMP2 +- pxor \XMM7, \TMP2 +- movdqu HashKey_2(%arg2), \TMP5 +- +- # Multiply TMP5 * HashKey using karatsuba +- +- pclmulqdq $0x11, \TMP5, \TMP1 # TMP1 = a1*b1 +- movaps 0x60(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 6 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 
+- pclmulqdq $0x00, \TMP5, \XMM7 # XMM7 = a0*b0 +- movaps 0x70(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 7 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 +- movdqu HashKey_2_k(%arg2), \TMP5 +- pclmulqdq $0x00, \TMP5, \TMP2 # TMP2 = (a1+a0)*(b1+b0) +- movaps 0x80(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 8 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 +- pxor \TMP1, \TMP4 +-# accumulate the results in TMP4:XMM5, TMP6 holds the middle part +- pxor \XMM7, \XMM5 +- pxor \TMP2, \TMP6 +- +- # Multiply XMM8 * HashKey +- # XMM8 and TMP5 hold the values for the two operands +- +- movdqa \XMM8, \TMP1 +- pshufd $78, \XMM8, \TMP2 +- pxor \XMM8, \TMP2 +- movdqu HashKey(%arg2), \TMP5 +- pclmulqdq $0x11, \TMP5, \TMP1 # TMP1 = a1*b1 +- movaps 0x90(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 9 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 +- pclmulqdq $0x00, \TMP5, \XMM8 # XMM8 = a0*b0 +- lea 0xa0(%arg1),%r10 +- mov keysize,%eax +- shr $2,%eax # 128->4, 192->6, 256->8 +- sub $4,%eax # 128->0, 192->2, 256->4 +- jz .Laes_loop_par_enc_done\@ +- +-.Laes_loop_par_enc\@: +- MOVADQ (%r10),\TMP3 +-.irpc index, 1234 +- aesenc \TMP3, %xmm\index +-.endr +- add $16,%r10 +- sub $1,%eax +- jnz .Laes_loop_par_enc\@ +- +-.Laes_loop_par_enc_done\@: +- MOVADQ (%r10), \TMP3 +- aesenclast \TMP3, \XMM1 # Round 10 +- aesenclast \TMP3, \XMM2 +- aesenclast \TMP3, \XMM3 +- aesenclast \TMP3, \XMM4 +- movdqu HashKey_k(%arg2), \TMP5 +- pclmulqdq $0x00, \TMP5, \TMP2 # TMP2 = (a1+a0)*(b1+b0) +- movdqu (%arg4,%r11,1), \TMP3 +- pxor \TMP3, \XMM1 # Ciphertext/Plaintext XOR EK +- movdqu 16(%arg4,%r11,1), \TMP3 +- pxor \TMP3, \XMM2 # Ciphertext/Plaintext XOR EK +- movdqu 32(%arg4,%r11,1), \TMP3 +- pxor \TMP3, \XMM3 # Ciphertext/Plaintext XOR EK +- movdqu 48(%arg4,%r11,1), \TMP3 +- pxor \TMP3, \XMM4 # Ciphertext/Plaintext XOR EK +- movdqu \XMM1, (%arg3,%r11,1) # Write to the ciphertext buffer +- movdqu \XMM2, 16(%arg3,%r11,1) # Write to the 
ciphertext buffer +- movdqu \XMM3, 32(%arg3,%r11,1) # Write to the ciphertext buffer +- movdqu \XMM4, 48(%arg3,%r11,1) # Write to the ciphertext buffer +- pshufb %xmm15, \XMM1 # perform a 16 byte swap +- pshufb %xmm15, \XMM2 # perform a 16 byte swap +- pshufb %xmm15, \XMM3 # perform a 16 byte swap +- pshufb %xmm15, \XMM4 # perform a 16 byte swap +- +- pxor \TMP4, \TMP1 +- pxor \XMM8, \XMM5 +- pxor \TMP6, \TMP2 +- pxor \TMP1, \TMP2 +- pxor \XMM5, \TMP2 +- movdqa \TMP2, \TMP3 +- pslldq $8, \TMP3 # left shift TMP3 2 DWs +- psrldq $8, \TMP2 # right shift TMP2 2 DWs +- pxor \TMP3, \XMM5 +- pxor \TMP2, \TMP1 # accumulate the results in TMP1:XMM5 +- +- # first phase of reduction +- +- movdqa \XMM5, \TMP2 +- movdqa \XMM5, \TMP3 +- movdqa \XMM5, \TMP4 +-# move XMM5 into TMP2, TMP3, TMP4 in order to perform shifts independently +- pslld $31, \TMP2 # packed right shift << 31 +- pslld $30, \TMP3 # packed right shift << 30 +- pslld $25, \TMP4 # packed right shift << 25 +- pxor \TMP3, \TMP2 # xor the shifted versions +- pxor \TMP4, \TMP2 +- movdqa \TMP2, \TMP5 +- psrldq $4, \TMP5 # right shift T5 1 DW +- pslldq $12, \TMP2 # left shift T2 3 DWs +- pxor \TMP2, \XMM5 +- +- # second phase of reduction +- +- movdqa \XMM5,\TMP2 # make 3 copies of XMM5 into TMP2, TMP3, TMP4 +- movdqa \XMM5,\TMP3 +- movdqa \XMM5,\TMP4 +- psrld $1, \TMP2 # packed left shift >>1 +- psrld $2, \TMP3 # packed left shift >>2 +- psrld $7, \TMP4 # packed left shift >>7 +- pxor \TMP3,\TMP2 # xor the shifted versions +- pxor \TMP4,\TMP2 +- pxor \TMP5, \TMP2 +- pxor \TMP2, \XMM5 +- pxor \TMP1, \XMM5 # result is in TMP1 +- +- pxor \XMM5, \XMM1 +-.endm +- +-/* +-* decrypt 4 blocks at a time +-* ghash the 4 previously decrypted ciphertext blocks +-* arg1, %arg3, %arg4 are used as pointers only, not modified +-* %r11 is the data offset value +-*/ +-.macro GHASH_4_ENCRYPT_4_PARALLEL_dec TMP1 TMP2 TMP3 TMP4 TMP5 \ +-TMP6 XMM0 XMM1 XMM2 XMM3 XMM4 XMM5 XMM6 XMM7 XMM8 operation +- +- movdqa \XMM1, \XMM5 +- movdqa \XMM2, 
\XMM6 +- movdqa \XMM3, \XMM7 +- movdqa \XMM4, \XMM8 +- +- movdqa SHUF_MASK(%rip), %xmm15 +- # multiply TMP5 * HashKey using karatsuba +- +- movdqa \XMM5, \TMP4 +- pshufd $78, \XMM5, \TMP6 +- pxor \XMM5, \TMP6 +- paddd ONE(%rip), \XMM0 # INCR CNT +- movdqu HashKey_4(%arg2), \TMP5 +- pclmulqdq $0x11, \TMP5, \TMP4 # TMP4 = a1*b1 +- movdqa \XMM0, \XMM1 +- paddd ONE(%rip), \XMM0 # INCR CNT +- movdqa \XMM0, \XMM2 +- paddd ONE(%rip), \XMM0 # INCR CNT +- movdqa \XMM0, \XMM3 +- paddd ONE(%rip), \XMM0 # INCR CNT +- movdqa \XMM0, \XMM4 +- pshufb %xmm15, \XMM1 # perform a 16 byte swap +- pclmulqdq $0x00, \TMP5, \XMM5 # XMM5 = a0*b0 +- pshufb %xmm15, \XMM2 # perform a 16 byte swap +- pshufb %xmm15, \XMM3 # perform a 16 byte swap +- pshufb %xmm15, \XMM4 # perform a 16 byte swap +- +- pxor (%arg1), \XMM1 +- pxor (%arg1), \XMM2 +- pxor (%arg1), \XMM3 +- pxor (%arg1), \XMM4 +- movdqu HashKey_4_k(%arg2), \TMP5 +- pclmulqdq $0x00, \TMP5, \TMP6 # TMP6 = (a1+a0)*(b1+b0) +- movaps 0x10(%arg1), \TMP1 +- aesenc \TMP1, \XMM1 # Round 1 +- aesenc \TMP1, \XMM2 +- aesenc \TMP1, \XMM3 +- aesenc \TMP1, \XMM4 +- movaps 0x20(%arg1), \TMP1 +- aesenc \TMP1, \XMM1 # Round 2 +- aesenc \TMP1, \XMM2 +- aesenc \TMP1, \XMM3 +- aesenc \TMP1, \XMM4 +- movdqa \XMM6, \TMP1 +- pshufd $78, \XMM6, \TMP2 +- pxor \XMM6, \TMP2 +- movdqu HashKey_3(%arg2), \TMP5 +- pclmulqdq $0x11, \TMP5, \TMP1 # TMP1 = a1 * b1 +- movaps 0x30(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 3 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 +- pclmulqdq $0x00, \TMP5, \XMM6 # XMM6 = a0*b0 +- movaps 0x40(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 4 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 +- movdqu HashKey_3_k(%arg2), \TMP5 +- pclmulqdq $0x00, \TMP5, \TMP2 # TMP2 = (a1+a0)*(b1+b0) +- movaps 0x50(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 5 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 +- pxor \TMP1, \TMP4 +-# accumulate the results in TMP4:XMM5, TMP6 holds the 
middle part +- pxor \XMM6, \XMM5 +- pxor \TMP2, \TMP6 +- movdqa \XMM7, \TMP1 +- pshufd $78, \XMM7, \TMP2 +- pxor \XMM7, \TMP2 +- movdqu HashKey_2(%arg2), \TMP5 +- +- # Multiply TMP5 * HashKey using karatsuba +- +- pclmulqdq $0x11, \TMP5, \TMP1 # TMP1 = a1*b1 +- movaps 0x60(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 6 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 +- pclmulqdq $0x00, \TMP5, \XMM7 # XMM7 = a0*b0 +- movaps 0x70(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 7 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 +- movdqu HashKey_2_k(%arg2), \TMP5 +- pclmulqdq $0x00, \TMP5, \TMP2 # TMP2 = (a1+a0)*(b1+b0) +- movaps 0x80(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 8 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 +- pxor \TMP1, \TMP4 +-# accumulate the results in TMP4:XMM5, TMP6 holds the middle part +- pxor \XMM7, \XMM5 +- pxor \TMP2, \TMP6 +- +- # Multiply XMM8 * HashKey +- # XMM8 and TMP5 hold the values for the two operands +- +- movdqa \XMM8, \TMP1 +- pshufd $78, \XMM8, \TMP2 +- pxor \XMM8, \TMP2 +- movdqu HashKey(%arg2), \TMP5 +- pclmulqdq $0x11, \TMP5, \TMP1 # TMP1 = a1*b1 +- movaps 0x90(%arg1), \TMP3 +- aesenc \TMP3, \XMM1 # Round 9 +- aesenc \TMP3, \XMM2 +- aesenc \TMP3, \XMM3 +- aesenc \TMP3, \XMM4 +- pclmulqdq $0x00, \TMP5, \XMM8 # XMM8 = a0*b0 +- lea 0xa0(%arg1),%r10 +- mov keysize,%eax +- shr $2,%eax # 128->4, 192->6, 256->8 +- sub $4,%eax # 128->0, 192->2, 256->4 +- jz .Laes_loop_par_dec_done\@ +- +-.Laes_loop_par_dec\@: +- MOVADQ (%r10),\TMP3 +-.irpc index, 1234 +- aesenc \TMP3, %xmm\index +-.endr +- add $16,%r10 +- sub $1,%eax +- jnz .Laes_loop_par_dec\@ +- +-.Laes_loop_par_dec_done\@: +- MOVADQ (%r10), \TMP3 +- aesenclast \TMP3, \XMM1 # last round +- aesenclast \TMP3, \XMM2 +- aesenclast \TMP3, \XMM3 +- aesenclast \TMP3, \XMM4 +- movdqu HashKey_k(%arg2), \TMP5 +- pclmulqdq $0x00, \TMP5, \TMP2 # TMP2 = (a1+a0)*(b1+b0) +- movdqu (%arg4,%r11,1), \TMP3 +- pxor \TMP3, \XMM1 # 
Ciphertext/Plaintext XOR EK +- movdqu \XMM1, (%arg3,%r11,1) # Write to plaintext buffer +- movdqa \TMP3, \XMM1 +- movdqu 16(%arg4,%r11,1), \TMP3 +- pxor \TMP3, \XMM2 # Ciphertext/Plaintext XOR EK +- movdqu \XMM2, 16(%arg3,%r11,1) # Write to plaintext buffer +- movdqa \TMP3, \XMM2 +- movdqu 32(%arg4,%r11,1), \TMP3 +- pxor \TMP3, \XMM3 # Ciphertext/Plaintext XOR EK +- movdqu \XMM3, 32(%arg3,%r11,1) # Write to plaintext buffer +- movdqa \TMP3, \XMM3 +- movdqu 48(%arg4,%r11,1), \TMP3 +- pxor \TMP3, \XMM4 # Ciphertext/Plaintext XOR EK +- movdqu \XMM4, 48(%arg3,%r11,1) # Write to plaintext buffer +- movdqa \TMP3, \XMM4 +- pshufb %xmm15, \XMM1 # perform a 16 byte swap +- pshufb %xmm15, \XMM2 # perform a 16 byte swap +- pshufb %xmm15, \XMM3 # perform a 16 byte swap +- pshufb %xmm15, \XMM4 # perform a 16 byte swap +- +- pxor \TMP4, \TMP1 +- pxor \XMM8, \XMM5 +- pxor \TMP6, \TMP2 +- pxor \TMP1, \TMP2 +- pxor \XMM5, \TMP2 +- movdqa \TMP2, \TMP3 +- pslldq $8, \TMP3 # left shift TMP3 2 DWs +- psrldq $8, \TMP2 # right shift TMP2 2 DWs +- pxor \TMP3, \XMM5 +- pxor \TMP2, \TMP1 # accumulate the results in TMP1:XMM5 +- +- # first phase of reduction +- +- movdqa \XMM5, \TMP2 +- movdqa \XMM5, \TMP3 +- movdqa \XMM5, \TMP4 +-# move XMM5 into TMP2, TMP3, TMP4 in order to perform shifts independently +- pslld $31, \TMP2 # packed right shift << 31 +- pslld $30, \TMP3 # packed right shift << 30 +- pslld $25, \TMP4 # packed right shift << 25 +- pxor \TMP3, \TMP2 # xor the shifted versions +- pxor \TMP4, \TMP2 +- movdqa \TMP2, \TMP5 +- psrldq $4, \TMP5 # right shift T5 1 DW +- pslldq $12, \TMP2 # left shift T2 3 DWs +- pxor \TMP2, \XMM5 +- +- # second phase of reduction +- +- movdqa \XMM5,\TMP2 # make 3 copies of XMM5 into TMP2, TMP3, TMP4 +- movdqa \XMM5,\TMP3 +- movdqa \XMM5,\TMP4 +- psrld $1, \TMP2 # packed left shift >>1 +- psrld $2, \TMP3 # packed left shift >>2 +- psrld $7, \TMP4 # packed left shift >>7 +- pxor \TMP3,\TMP2 # xor the shifted versions +- pxor \TMP4,\TMP2 +- pxor \TMP5, 
\TMP2 +- pxor \TMP2, \XMM5 +- pxor \TMP1, \XMM5 # result is in TMP1 +- +- pxor \XMM5, \XMM1 +-.endm +- +-/* GHASH the last 4 ciphertext blocks. */ +-.macro GHASH_LAST_4 TMP1 TMP2 TMP3 TMP4 TMP5 TMP6 \ +-TMP7 XMM1 XMM2 XMM3 XMM4 XMMDst +- +- # Multiply TMP6 * HashKey (using Karatsuba) +- +- movdqa \XMM1, \TMP6 +- pshufd $78, \XMM1, \TMP2 +- pxor \XMM1, \TMP2 +- movdqu HashKey_4(%arg2), \TMP5 +- pclmulqdq $0x11, \TMP5, \TMP6 # TMP6 = a1*b1 +- pclmulqdq $0x00, \TMP5, \XMM1 # XMM1 = a0*b0 +- movdqu HashKey_4_k(%arg2), \TMP4 +- pclmulqdq $0x00, \TMP4, \TMP2 # TMP2 = (a1+a0)*(b1+b0) +- movdqa \XMM1, \XMMDst +- movdqa \TMP2, \XMM1 # result in TMP6, XMMDst, XMM1 +- +- # Multiply TMP1 * HashKey (using Karatsuba) +- +- movdqa \XMM2, \TMP1 +- pshufd $78, \XMM2, \TMP2 +- pxor \XMM2, \TMP2 +- movdqu HashKey_3(%arg2), \TMP5 +- pclmulqdq $0x11, \TMP5, \TMP1 # TMP1 = a1*b1 +- pclmulqdq $0x00, \TMP5, \XMM2 # XMM2 = a0*b0 +- movdqu HashKey_3_k(%arg2), \TMP4 +- pclmulqdq $0x00, \TMP4, \TMP2 # TMP2 = (a1+a0)*(b1+b0) +- pxor \TMP1, \TMP6 +- pxor \XMM2, \XMMDst +- pxor \TMP2, \XMM1 +-# results accumulated in TMP6, XMMDst, XMM1 +- +- # Multiply TMP1 * HashKey (using Karatsuba) +- +- movdqa \XMM3, \TMP1 +- pshufd $78, \XMM3, \TMP2 +- pxor \XMM3, \TMP2 +- movdqu HashKey_2(%arg2), \TMP5 +- pclmulqdq $0x11, \TMP5, \TMP1 # TMP1 = a1*b1 +- pclmulqdq $0x00, \TMP5, \XMM3 # XMM3 = a0*b0 +- movdqu HashKey_2_k(%arg2), \TMP4 +- pclmulqdq $0x00, \TMP4, \TMP2 # TMP2 = (a1+a0)*(b1+b0) +- pxor \TMP1, \TMP6 +- pxor \XMM3, \XMMDst +- pxor \TMP2, \XMM1 # results accumulated in TMP6, XMMDst, XMM1 +- +- # Multiply TMP1 * HashKey (using Karatsuba) +- movdqa \XMM4, \TMP1 +- pshufd $78, \XMM4, \TMP2 +- pxor \XMM4, \TMP2 +- movdqu HashKey(%arg2), \TMP5 +- pclmulqdq $0x11, \TMP5, \TMP1 # TMP1 = a1*b1 +- pclmulqdq $0x00, \TMP5, \XMM4 # XMM4 = a0*b0 +- movdqu HashKey_k(%arg2), \TMP4 +- pclmulqdq $0x00, \TMP4, \TMP2 # TMP2 = (a1+a0)*(b1+b0) +- pxor \TMP1, \TMP6 +- pxor \XMM4, \XMMDst +- pxor \XMM1, \TMP2 +- pxor 
\TMP6, \TMP2 +- pxor \XMMDst, \TMP2 +- # middle section of the temp results combined as in karatsuba algorithm +- movdqa \TMP2, \TMP4 +- pslldq $8, \TMP4 # left shift TMP4 2 DWs +- psrldq $8, \TMP2 # right shift TMP2 2 DWs +- pxor \TMP4, \XMMDst +- pxor \TMP2, \TMP6 +-# TMP6:XMMDst holds the result of the accumulated carry-less multiplications +- # first phase of the reduction +- movdqa \XMMDst, \TMP2 +- movdqa \XMMDst, \TMP3 +- movdqa \XMMDst, \TMP4 +-# move XMMDst into TMP2, TMP3, TMP4 in order to perform 3 shifts independently +- pslld $31, \TMP2 # packed right shifting << 31 +- pslld $30, \TMP3 # packed right shifting << 30 +- pslld $25, \TMP4 # packed right shifting << 25 +- pxor \TMP3, \TMP2 # xor the shifted versions +- pxor \TMP4, \TMP2 +- movdqa \TMP2, \TMP7 +- psrldq $4, \TMP7 # right shift TMP7 1 DW +- pslldq $12, \TMP2 # left shift TMP2 3 DWs +- pxor \TMP2, \XMMDst +- +- # second phase of the reduction +- movdqa \XMMDst, \TMP2 +- # make 3 copies of XMMDst for doing 3 shift operations +- movdqa \XMMDst, \TMP3 +- movdqa \XMMDst, \TMP4 +- psrld $1, \TMP2 # packed left shift >> 1 +- psrld $2, \TMP3 # packed left shift >> 2 +- psrld $7, \TMP4 # packed left shift >> 7 +- pxor \TMP3, \TMP2 # xor the shifted versions +- pxor \TMP4, \TMP2 +- pxor \TMP7, \TMP2 +- pxor \TMP2, \XMMDst +- pxor \TMP6, \XMMDst # reduced result is in XMMDst +-.endm +- +- +-/* Encryption of a single block +-* uses eax & r10 +-*/ +- +-.macro ENCRYPT_SINGLE_BLOCK XMM0 TMP1 +- +- pxor (%arg1), \XMM0 +- mov keysize,%eax +- shr $2,%eax # 128->4, 192->6, 256->8 +- add $5,%eax # 128->9, 192->11, 256->13 +- lea 16(%arg1), %r10 # get first expanded key address +- +-_esb_loop_\@: +- MOVADQ (%r10),\TMP1 +- aesenc \TMP1,\XMM0 +- add $16,%r10 +- sub $1,%eax +- jnz _esb_loop_\@ +- +- MOVADQ (%r10),\TMP1 +- aesenclast \TMP1,\XMM0 +-.endm +- +-/***************************************************************************** +-* void aesni_gcm_init(void *aes_ctx, // AES Key schedule. 
Starts on a 16 byte boundary. +-* struct gcm_context_data *data, +-* // context data +-* u8 *iv, // Pre-counter block j0: 4 byte salt (from Security Association) +-* // concatenated with 8 byte Initialisation Vector (from IPSec ESP Payload) +-* // concatenated with 0x00000001. 16-byte aligned pointer. +-* u8 *hash_subkey, // H, the Hash sub key input. Data starts on a 16-byte boundary. +-* const u8 *aad, // Additional Authentication Data (AAD) +-* u64 aad_len) // Length of AAD in bytes. +-*/ +-SYM_FUNC_START(aesni_gcm_init) +- FUNC_SAVE +- GCM_INIT %arg3, %arg4,%arg5, %arg6 +- FUNC_RESTORE +- RET +-SYM_FUNC_END(aesni_gcm_init) +- +-/***************************************************************************** +-* void aesni_gcm_enc_update(void *aes_ctx, // AES Key schedule. Starts on a 16 byte boundary. +-* struct gcm_context_data *data, +-* // context data +-* u8 *out, // Ciphertext output. Encrypt in-place is allowed. +-* const u8 *in, // Plaintext input +-* u64 plaintext_len, // Length of data in bytes for encryption. +-*/ +-SYM_FUNC_START(aesni_gcm_enc_update) +- FUNC_SAVE +- GCM_ENC_DEC enc +- FUNC_RESTORE +- RET +-SYM_FUNC_END(aesni_gcm_enc_update) +- +-/***************************************************************************** +-* void aesni_gcm_dec_update(void *aes_ctx, // AES Key schedule. Starts on a 16 byte boundary. +-* struct gcm_context_data *data, +-* // context data +-* u8 *out, // Ciphertext output. Encrypt in-place is allowed. +-* const u8 *in, // Plaintext input +-* u64 plaintext_len, // Length of data in bytes for encryption. +-*/ +-SYM_FUNC_START(aesni_gcm_dec_update) +- FUNC_SAVE +- GCM_ENC_DEC dec +- FUNC_RESTORE +- RET +-SYM_FUNC_END(aesni_gcm_dec_update) +- +-/***************************************************************************** +-* void aesni_gcm_finalize(void *aes_ctx, // AES Key schedule. Starts on a 16 byte boundary. +-* struct gcm_context_data *data, +-* // context data +-* u8 *auth_tag, // Authenticated Tag output. 
+-* u64 auth_tag_len); // Authenticated Tag Length in bytes. Valid values are 16 (most likely), +-* // 12 or 8. +-*/ +-SYM_FUNC_START(aesni_gcm_finalize) +- FUNC_SAVE +- GCM_COMPLETE %arg3 %arg4 +- FUNC_RESTORE +- RET +-SYM_FUNC_END(aesni_gcm_finalize) +- +-#endif +- + SYM_FUNC_START_LOCAL(_key_expansion_256a) + pshufd $0b11111111, %xmm1, %xmm1 + shufps $0b00010000, %xmm0, %xmm4 +diff --git a/arch/x86/crypto/aesni-intel_avx-x86_64.S b/arch/x86/crypto/aesni-intel_avx-x86_64.S +deleted file mode 100644 +index 8c9749ed0651..000000000000 +--- a/arch/x86/crypto/aesni-intel_avx-x86_64.S ++++ /dev/null +@@ -1,2804 +0,0 @@ +-######################################################################## +-# Copyright (c) 2013, Intel Corporation +-# +-# This software is available to you under a choice of one of two +-# licenses. You may choose to be licensed under the terms of the GNU +-# General Public License (GPL) Version 2, available from the file +-# COPYING in the main directory of this source tree, or the +-# OpenIB.org BSD license below: +-# +-# Redistribution and use in source and binary forms, with or without +-# modification, are permitted provided that the following conditions are +-# met: +-# +-# * Redistributions of source code must retain the above copyright +-# notice, this list of conditions and the following disclaimer. +-# +-# * Redistributions in binary form must reproduce the above copyright +-# notice, this list of conditions and the following disclaimer in the +-# documentation and/or other materials provided with the +-# distribution. +-# +-# * Neither the name of the Intel Corporation nor the names of its +-# contributors may be used to endorse or promote products derived from +-# this software without specific prior written permission. 
+-# +-# +-# THIS SOFTWARE IS PROVIDED BY INTEL CORPORATION ""AS IS"" AND ANY +-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL INTEL CORPORATION OR +-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES# LOSS OF USE, DATA, OR +-# PROFITS# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +-# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +-# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +-# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +-######################################################################## +-## +-## Authors: +-## Erdinc Ozturk +-## Vinodh Gopal +-## James Guilford +-## Tim Chen +-## +-## References: +-## This code was derived and highly optimized from the code described in paper: +-## Vinodh Gopal et. al. Optimized Galois-Counter-Mode Implementation +-## on Intel Architecture Processors. August, 2010 +-## The details of the implementation is explained in: +-## Erdinc Ozturk et. al. Enabling High-Performance Galois-Counter-Mode +-## on Intel Architecture Processors. October, 2012. 
+-## +-## Assumptions: +-## +-## +-## +-## iv: +-## 0 1 2 3 +-## 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-## +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-## | Salt (From the SA) | +-## +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-## | Initialization Vector | +-## | (This is the sequence number from IPSec header) | +-## +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-## | 0x1 | +-## +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-## +-## +-## +-## AAD: +-## AAD padded to 128 bits with 0 +-## for example, assume AAD is a u32 vector +-## +-## if AAD is 8 bytes: +-## AAD[3] = {A0, A1}# +-## padded AAD in xmm register = {A1 A0 0 0} +-## +-## 0 1 2 3 +-## 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-## +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-## | SPI (A1) | +-## +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-## | 32-bit Sequence Number (A0) | +-## +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-## | 0x0 | +-## +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-## +-## AAD Format with 32-bit Sequence Number +-## +-## if AAD is 12 bytes: +-## AAD[3] = {A0, A1, A2}# +-## padded AAD in xmm register = {A2 A1 A0 0} +-## +-## 0 1 2 3 +-## 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-## +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-## | SPI (A2) | +-## +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-## | 64-bit Extended Sequence Number {A1,A0} | +-## | | +-## +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-## | 0x0 | +-## +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-## +-## AAD Format with 64-bit Extended Sequence Number +-## +-## +-## aadLen: +-## from the definition of the spec, aadLen can only be 8 or 12 bytes. 
+-## The code additionally supports aadLen of length 16 bytes. +-## +-## TLen: +-## from the definition of the spec, TLen can only be 8, 12 or 16 bytes. +-## +-## poly = x^128 + x^127 + x^126 + x^121 + 1 +-## throughout the code, one tab and two tab indentations are used. one tab is +-## for GHASH part, two tabs is for AES part. +-## +- +-#include +- +-# constants in mergeable sections, linker can reorder and merge +-.section .rodata.cst16.POLY, "aM", @progbits, 16 +-.align 16 +-POLY: .octa 0xC2000000000000000000000000000001 +- +-.section .rodata.cst16.POLY2, "aM", @progbits, 16 +-.align 16 +-POLY2: .octa 0xC20000000000000000000001C2000000 +- +-.section .rodata.cst16.TWOONE, "aM", @progbits, 16 +-.align 16 +-TWOONE: .octa 0x00000001000000000000000000000001 +- +-.section .rodata.cst16.SHUF_MASK, "aM", @progbits, 16 +-.align 16 +-SHUF_MASK: .octa 0x000102030405060708090A0B0C0D0E0F +- +-.section .rodata.cst16.ONE, "aM", @progbits, 16 +-.align 16 +-ONE: .octa 0x00000000000000000000000000000001 +- +-.section .rodata.cst16.ONEf, "aM", @progbits, 16 +-.align 16 +-ONEf: .octa 0x01000000000000000000000000000000 +- +-# order of these constants should not change. 
+-# more specifically, ALL_F should follow SHIFT_MASK, and zero should follow ALL_F +-.section .rodata, "a", @progbits +-.align 16 +-SHIFT_MASK: .octa 0x0f0e0d0c0b0a09080706050403020100 +-ALL_F: .octa 0xffffffffffffffffffffffffffffffff +- .octa 0x00000000000000000000000000000000 +- +-.text +- +- +-#define AadHash 16*0 +-#define AadLen 16*1 +-#define InLen (16*1)+8 +-#define PBlockEncKey 16*2 +-#define OrigIV 16*3 +-#define CurCount 16*4 +-#define PBlockLen 16*5 +- +-HashKey = 16*6 # store HashKey <<1 mod poly here +-HashKey_2 = 16*7 # store HashKey^2 <<1 mod poly here +-HashKey_3 = 16*8 # store HashKey^3 <<1 mod poly here +-HashKey_4 = 16*9 # store HashKey^4 <<1 mod poly here +-HashKey_5 = 16*10 # store HashKey^5 <<1 mod poly here +-HashKey_6 = 16*11 # store HashKey^6 <<1 mod poly here +-HashKey_7 = 16*12 # store HashKey^7 <<1 mod poly here +-HashKey_8 = 16*13 # store HashKey^8 <<1 mod poly here +-HashKey_k = 16*14 # store XOR of HashKey <<1 mod poly here (for Karatsuba purposes) +-HashKey_2_k = 16*15 # store XOR of HashKey^2 <<1 mod poly here (for Karatsuba purposes) +-HashKey_3_k = 16*16 # store XOR of HashKey^3 <<1 mod poly here (for Karatsuba purposes) +-HashKey_4_k = 16*17 # store XOR of HashKey^4 <<1 mod poly here (for Karatsuba purposes) +-HashKey_5_k = 16*18 # store XOR of HashKey^5 <<1 mod poly here (for Karatsuba purposes) +-HashKey_6_k = 16*19 # store XOR of HashKey^6 <<1 mod poly here (for Karatsuba purposes) +-HashKey_7_k = 16*20 # store XOR of HashKey^7 <<1 mod poly here (for Karatsuba purposes) +-HashKey_8_k = 16*21 # store XOR of HashKey^8 <<1 mod poly here (for Karatsuba purposes) +- +-#define arg1 %rdi +-#define arg2 %rsi +-#define arg3 %rdx +-#define arg4 %rcx +-#define arg5 %r8 +-#define arg6 %r9 +-#define keysize 2*15*16(arg1) +- +-i = 0 +-j = 0 +- +-out_order = 0 +-in_order = 1 +-DEC = 0 +-ENC = 1 +- +-.macro define_reg r n +-reg_\r = %xmm\n +-.endm +- +-.macro setreg +-.altmacro +-define_reg i %i +-define_reg j %j +-.noaltmacro +-.endm +- 
+-TMP1 = 16*0 # Temporary storage for AAD +-TMP2 = 16*1 # Temporary storage for AES State 2 (State 1 is stored in an XMM register) +-TMP3 = 16*2 # Temporary storage for AES State 3 +-TMP4 = 16*3 # Temporary storage for AES State 4 +-TMP5 = 16*4 # Temporary storage for AES State 5 +-TMP6 = 16*5 # Temporary storage for AES State 6 +-TMP7 = 16*6 # Temporary storage for AES State 7 +-TMP8 = 16*7 # Temporary storage for AES State 8 +- +-VARIABLE_OFFSET = 16*8 +- +-################################ +-# Utility Macros +-################################ +- +-.macro FUNC_SAVE +- push %r12 +- push %r13 +- push %r15 +- +- push %rbp +- mov %rsp, %rbp +- +- sub $VARIABLE_OFFSET, %rsp +- and $~63, %rsp # align rsp to 64 bytes +-.endm +- +-.macro FUNC_RESTORE +- mov %rbp, %rsp +- pop %rbp +- +- pop %r15 +- pop %r13 +- pop %r12 +-.endm +- +-# Encryption of a single block +-.macro ENCRYPT_SINGLE_BLOCK REP XMM0 +- vpxor (arg1), \XMM0, \XMM0 +- i = 1 +- setreg +-.rep \REP +- vaesenc 16*i(arg1), \XMM0, \XMM0 +- i = (i+1) +- setreg +-.endr +- vaesenclast 16*i(arg1), \XMM0, \XMM0 +-.endm +- +-# combined for GCM encrypt and decrypt functions +-# clobbering all xmm registers +-# clobbering r10, r11, r12, r13, r15, rax +-.macro GCM_ENC_DEC INITIAL_BLOCKS GHASH_8_ENCRYPT_8_PARALLEL GHASH_LAST_8 GHASH_MUL ENC_DEC REP +- vmovdqu AadHash(arg2), %xmm8 +- vmovdqu HashKey(arg2), %xmm13 # xmm13 = HashKey +- add arg5, InLen(arg2) +- +- # initialize the data pointer offset as zero +- xor %r11d, %r11d +- +- PARTIAL_BLOCK \GHASH_MUL, arg3, arg4, arg5, %r11, %xmm8, \ENC_DEC +- sub %r11, arg5 +- +- mov arg5, %r13 # save the number of bytes of plaintext/ciphertext +- and $-16, %r13 # r13 = r13 - (r13 mod 16) +- +- mov %r13, %r12 +- shr $4, %r12 +- and $7, %r12 +- jz .L_initial_num_blocks_is_0\@ +- +- cmp $7, %r12 +- je .L_initial_num_blocks_is_7\@ +- cmp $6, %r12 +- je .L_initial_num_blocks_is_6\@ +- cmp $5, %r12 +- je .L_initial_num_blocks_is_5\@ +- cmp $4, %r12 +- je .L_initial_num_blocks_is_4\@ +- cmp 
$3, %r12 +- je .L_initial_num_blocks_is_3\@ +- cmp $2, %r12 +- je .L_initial_num_blocks_is_2\@ +- +- jmp .L_initial_num_blocks_is_1\@ +- +-.L_initial_num_blocks_is_7\@: +- \INITIAL_BLOCKS \REP, 7, %xmm12, %xmm13, %xmm14, %xmm15, %xmm11, %xmm9, %xmm1, %xmm2, %xmm3, %xmm4, %xmm5, %xmm6, %xmm7, %xmm8, %xmm10, %xmm0, \ENC_DEC +- sub $16*7, %r13 +- jmp .L_initial_blocks_encrypted\@ +- +-.L_initial_num_blocks_is_6\@: +- \INITIAL_BLOCKS \REP, 6, %xmm12, %xmm13, %xmm14, %xmm15, %xmm11, %xmm9, %xmm1, %xmm2, %xmm3, %xmm4, %xmm5, %xmm6, %xmm7, %xmm8, %xmm10, %xmm0, \ENC_DEC +- sub $16*6, %r13 +- jmp .L_initial_blocks_encrypted\@ +- +-.L_initial_num_blocks_is_5\@: +- \INITIAL_BLOCKS \REP, 5, %xmm12, %xmm13, %xmm14, %xmm15, %xmm11, %xmm9, %xmm1, %xmm2, %xmm3, %xmm4, %xmm5, %xmm6, %xmm7, %xmm8, %xmm10, %xmm0, \ENC_DEC +- sub $16*5, %r13 +- jmp .L_initial_blocks_encrypted\@ +- +-.L_initial_num_blocks_is_4\@: +- \INITIAL_BLOCKS \REP, 4, %xmm12, %xmm13, %xmm14, %xmm15, %xmm11, %xmm9, %xmm1, %xmm2, %xmm3, %xmm4, %xmm5, %xmm6, %xmm7, %xmm8, %xmm10, %xmm0, \ENC_DEC +- sub $16*4, %r13 +- jmp .L_initial_blocks_encrypted\@ +- +-.L_initial_num_blocks_is_3\@: +- \INITIAL_BLOCKS \REP, 3, %xmm12, %xmm13, %xmm14, %xmm15, %xmm11, %xmm9, %xmm1, %xmm2, %xmm3, %xmm4, %xmm5, %xmm6, %xmm7, %xmm8, %xmm10, %xmm0, \ENC_DEC +- sub $16*3, %r13 +- jmp .L_initial_blocks_encrypted\@ +- +-.L_initial_num_blocks_is_2\@: +- \INITIAL_BLOCKS \REP, 2, %xmm12, %xmm13, %xmm14, %xmm15, %xmm11, %xmm9, %xmm1, %xmm2, %xmm3, %xmm4, %xmm5, %xmm6, %xmm7, %xmm8, %xmm10, %xmm0, \ENC_DEC +- sub $16*2, %r13 +- jmp .L_initial_blocks_encrypted\@ +- +-.L_initial_num_blocks_is_1\@: +- \INITIAL_BLOCKS \REP, 1, %xmm12, %xmm13, %xmm14, %xmm15, %xmm11, %xmm9, %xmm1, %xmm2, %xmm3, %xmm4, %xmm5, %xmm6, %xmm7, %xmm8, %xmm10, %xmm0, \ENC_DEC +- sub $16*1, %r13 +- jmp .L_initial_blocks_encrypted\@ +- +-.L_initial_num_blocks_is_0\@: +- \INITIAL_BLOCKS \REP, 0, %xmm12, %xmm13, %xmm14, %xmm15, %xmm11, %xmm9, %xmm1, %xmm2, %xmm3, %xmm4, 
%xmm5, %xmm6, %xmm7, %xmm8, %xmm10, %xmm0, \ENC_DEC +- +- +-.L_initial_blocks_encrypted\@: +- test %r13, %r13 +- je .L_zero_cipher_left\@ +- +- sub $128, %r13 +- je .L_eight_cipher_left\@ +- +- +- +- +- vmovd %xmm9, %r15d +- and $255, %r15d +- vpshufb SHUF_MASK(%rip), %xmm9, %xmm9 +- +- +-.L_encrypt_by_8_new\@: +- cmp $(255-8), %r15d +- jg .L_encrypt_by_8\@ +- +- +- +- add $8, %r15b +- \GHASH_8_ENCRYPT_8_PARALLEL \REP, %xmm0, %xmm10, %xmm11, %xmm12, %xmm13, %xmm14, %xmm9, %xmm1, %xmm2, %xmm3, %xmm4, %xmm5, %xmm6, %xmm7, %xmm8, %xmm15, out_order, \ENC_DEC +- add $128, %r11 +- sub $128, %r13 +- jne .L_encrypt_by_8_new\@ +- +- vpshufb SHUF_MASK(%rip), %xmm9, %xmm9 +- jmp .L_eight_cipher_left\@ +- +-.L_encrypt_by_8\@: +- vpshufb SHUF_MASK(%rip), %xmm9, %xmm9 +- add $8, %r15b +- \GHASH_8_ENCRYPT_8_PARALLEL \REP, %xmm0, %xmm10, %xmm11, %xmm12, %xmm13, %xmm14, %xmm9, %xmm1, %xmm2, %xmm3, %xmm4, %xmm5, %xmm6, %xmm7, %xmm8, %xmm15, in_order, \ENC_DEC +- vpshufb SHUF_MASK(%rip), %xmm9, %xmm9 +- add $128, %r11 +- sub $128, %r13 +- jne .L_encrypt_by_8_new\@ +- +- vpshufb SHUF_MASK(%rip), %xmm9, %xmm9 +- +- +- +- +-.L_eight_cipher_left\@: +- \GHASH_LAST_8 %xmm0, %xmm10, %xmm11, %xmm12, %xmm13, %xmm14, %xmm15, %xmm1, %xmm2, %xmm3, %xmm4, %xmm5, %xmm6, %xmm7, %xmm8 +- +- +-.L_zero_cipher_left\@: +- vmovdqu %xmm14, AadHash(arg2) +- vmovdqu %xmm9, CurCount(arg2) +- +- # check for 0 length +- mov arg5, %r13 +- and $15, %r13 # r13 = (arg5 mod 16) +- +- je .L_multiple_of_16_bytes\@ +- +- # handle the last <16 Byte block separately +- +- mov %r13, PBlockLen(arg2) +- +- vpaddd ONE(%rip), %xmm9, %xmm9 # INCR CNT to get Yn +- vmovdqu %xmm9, CurCount(arg2) +- vpshufb SHUF_MASK(%rip), %xmm9, %xmm9 +- +- ENCRYPT_SINGLE_BLOCK \REP, %xmm9 # E(K, Yn) +- vmovdqu %xmm9, PBlockEncKey(arg2) +- +- cmp $16, arg5 +- jge .L_large_enough_update\@ +- +- lea (arg4,%r11,1), %r10 +- mov %r13, %r12 +- +- READ_PARTIAL_BLOCK %r10 %r12 %xmm1 +- +- lea SHIFT_MASK+16(%rip), %r12 +- sub %r13, %r12 # adjust the 
shuffle mask pointer to be +- # able to shift 16-r13 bytes (r13 is the +- # number of bytes in plaintext mod 16) +- +- jmp .L_final_ghash_mul\@ +- +-.L_large_enough_update\@: +- sub $16, %r11 +- add %r13, %r11 +- +- # receive the last <16 Byte block +- vmovdqu (arg4, %r11, 1), %xmm1 +- +- sub %r13, %r11 +- add $16, %r11 +- +- lea SHIFT_MASK+16(%rip), %r12 +- # adjust the shuffle mask pointer to be able to shift 16-r13 bytes +- # (r13 is the number of bytes in plaintext mod 16) +- sub %r13, %r12 +- # get the appropriate shuffle mask +- vmovdqu (%r12), %xmm2 +- # shift right 16-r13 bytes +- vpshufb %xmm2, %xmm1, %xmm1 +- +-.L_final_ghash_mul\@: +- .if \ENC_DEC == DEC +- vmovdqa %xmm1, %xmm2 +- vpxor %xmm1, %xmm9, %xmm9 # Plaintext XOR E(K, Yn) +- vmovdqu ALL_F-SHIFT_MASK(%r12), %xmm1 # get the appropriate mask to +- # mask out top 16-r13 bytes of xmm9 +- vpand %xmm1, %xmm9, %xmm9 # mask out top 16-r13 bytes of xmm9 +- vpand %xmm1, %xmm2, %xmm2 +- vpshufb SHUF_MASK(%rip), %xmm2, %xmm2 +- vpxor %xmm2, %xmm14, %xmm14 +- +- vmovdqu %xmm14, AadHash(arg2) +- .else +- vpxor %xmm1, %xmm9, %xmm9 # Plaintext XOR E(K, Yn) +- vmovdqu ALL_F-SHIFT_MASK(%r12), %xmm1 # get the appropriate mask to +- # mask out top 16-r13 bytes of xmm9 +- vpand %xmm1, %xmm9, %xmm9 # mask out top 16-r13 bytes of xmm9 +- vpshufb SHUF_MASK(%rip), %xmm9, %xmm9 +- vpxor %xmm9, %xmm14, %xmm14 +- +- vmovdqu %xmm14, AadHash(arg2) +- vpshufb SHUF_MASK(%rip), %xmm9, %xmm9 # shuffle xmm9 back to output as ciphertext +- .endif +- +- +- ############################# +- # output r13 Bytes +- vmovq %xmm9, %rax +- cmp $8, %r13 +- jle .L_less_than_8_bytes_left\@ +- +- mov %rax, (arg3 , %r11) +- add $8, %r11 +- vpsrldq $8, %xmm9, %xmm9 +- vmovq %xmm9, %rax +- sub $8, %r13 +- +-.L_less_than_8_bytes_left\@: +- movb %al, (arg3 , %r11) +- add $1, %r11 +- shr $8, %rax +- sub $1, %r13 +- jne .L_less_than_8_bytes_left\@ +- ############################# +- +-.L_multiple_of_16_bytes\@: +-.endm +- +- +-# GCM_COMPLETE Finishes 
update of tag of last partial block +-# Output: Authorization Tag (AUTH_TAG) +-# Clobbers rax, r10-r12, and xmm0, xmm1, xmm5-xmm15 +-.macro GCM_COMPLETE GHASH_MUL REP AUTH_TAG AUTH_TAG_LEN +- vmovdqu AadHash(arg2), %xmm14 +- vmovdqu HashKey(arg2), %xmm13 +- +- mov PBlockLen(arg2), %r12 +- test %r12, %r12 +- je .L_partial_done\@ +- +- #GHASH computation for the last <16 Byte block +- \GHASH_MUL %xmm14, %xmm13, %xmm0, %xmm10, %xmm11, %xmm5, %xmm6 +- +-.L_partial_done\@: +- mov AadLen(arg2), %r12 # r12 = aadLen (number of bytes) +- shl $3, %r12 # convert into number of bits +- vmovd %r12d, %xmm15 # len(A) in xmm15 +- +- mov InLen(arg2), %r12 +- shl $3, %r12 # len(C) in bits (*128) +- vmovq %r12, %xmm1 +- vpslldq $8, %xmm15, %xmm15 # xmm15 = len(A)|| 0x0000000000000000 +- vpxor %xmm1, %xmm15, %xmm15 # xmm15 = len(A)||len(C) +- +- vpxor %xmm15, %xmm14, %xmm14 +- \GHASH_MUL %xmm14, %xmm13, %xmm0, %xmm10, %xmm11, %xmm5, %xmm6 # final GHASH computation +- vpshufb SHUF_MASK(%rip), %xmm14, %xmm14 # perform a 16Byte swap +- +- vmovdqu OrigIV(arg2), %xmm9 +- +- ENCRYPT_SINGLE_BLOCK \REP, %xmm9 # E(K, Y0) +- +- vpxor %xmm14, %xmm9, %xmm9 +- +- +- +-.L_return_T\@: +- mov \AUTH_TAG, %r10 # r10 = authTag +- mov \AUTH_TAG_LEN, %r11 # r11 = auth_tag_len +- +- cmp $16, %r11 +- je .L_T_16\@ +- +- cmp $8, %r11 +- jl .L_T_4\@ +- +-.L_T_8\@: +- vmovq %xmm9, %rax +- mov %rax, (%r10) +- add $8, %r10 +- sub $8, %r11 +- vpsrldq $8, %xmm9, %xmm9 +- test %r11, %r11 +- je .L_return_T_done\@ +-.L_T_4\@: +- vmovd %xmm9, %eax +- mov %eax, (%r10) +- add $4, %r10 +- sub $4, %r11 +- vpsrldq $4, %xmm9, %xmm9 +- test %r11, %r11 +- je .L_return_T_done\@ +-.L_T_123\@: +- vmovd %xmm9, %eax +- cmp $2, %r11 +- jl .L_T_1\@ +- mov %ax, (%r10) +- cmp $2, %r11 +- je .L_return_T_done\@ +- add $2, %r10 +- sar $16, %eax +-.L_T_1\@: +- mov %al, (%r10) +- jmp .L_return_T_done\@ +- +-.L_T_16\@: +- vmovdqu %xmm9, (%r10) +- +-.L_return_T_done\@: +-.endm +- +-.macro CALC_AAD_HASH GHASH_MUL AAD AADLEN T1 T2 T3 T4 T5 T6 
T7 T8 +- +- mov \AAD, %r10 # r10 = AAD +- mov \AADLEN, %r12 # r12 = aadLen +- +- +- mov %r12, %r11 +- +- vpxor \T8, \T8, \T8 +- vpxor \T7, \T7, \T7 +- cmp $16, %r11 +- jl .L_get_AAD_rest8\@ +-.L_get_AAD_blocks\@: +- vmovdqu (%r10), \T7 +- vpshufb SHUF_MASK(%rip), \T7, \T7 +- vpxor \T7, \T8, \T8 +- \GHASH_MUL \T8, \T2, \T1, \T3, \T4, \T5, \T6 +- add $16, %r10 +- sub $16, %r12 +- sub $16, %r11 +- cmp $16, %r11 +- jge .L_get_AAD_blocks\@ +- vmovdqu \T8, \T7 +- test %r11, %r11 +- je .L_get_AAD_done\@ +- +- vpxor \T7, \T7, \T7 +- +- /* read the last <16B of AAD. since we have at least 4B of +- data right after the AAD (the ICV, and maybe some CT), we can +- read 4B/8B blocks safely, and then get rid of the extra stuff */ +-.L_get_AAD_rest8\@: +- cmp $4, %r11 +- jle .L_get_AAD_rest4\@ +- movq (%r10), \T1 +- add $8, %r10 +- sub $8, %r11 +- vpslldq $8, \T1, \T1 +- vpsrldq $8, \T7, \T7 +- vpxor \T1, \T7, \T7 +- jmp .L_get_AAD_rest8\@ +-.L_get_AAD_rest4\@: +- test %r11, %r11 +- jle .L_get_AAD_rest0\@ +- mov (%r10), %eax +- movq %rax, \T1 +- add $4, %r10 +- sub $4, %r11 +- vpslldq $12, \T1, \T1 +- vpsrldq $4, \T7, \T7 +- vpxor \T1, \T7, \T7 +-.L_get_AAD_rest0\@: +- /* finalize: shift out the extra bytes we read, and align +- left. 
since pslldq can only shift by an immediate, we use +- vpshufb and a pair of shuffle masks */ +- leaq ALL_F(%rip), %r11 +- subq %r12, %r11 +- vmovdqu 16(%r11), \T1 +- andq $~3, %r11 +- vpshufb (%r11), \T7, \T7 +- vpand \T1, \T7, \T7 +-.L_get_AAD_rest_final\@: +- vpshufb SHUF_MASK(%rip), \T7, \T7 +- vpxor \T8, \T7, \T7 +- \GHASH_MUL \T7, \T2, \T1, \T3, \T4, \T5, \T6 +- +-.L_get_AAD_done\@: +- vmovdqu \T7, AadHash(arg2) +-.endm +- +-.macro INIT GHASH_MUL PRECOMPUTE +- mov arg6, %r11 +- mov %r11, AadLen(arg2) # ctx_data.aad_length = aad_length +- xor %r11d, %r11d +- mov %r11, InLen(arg2) # ctx_data.in_length = 0 +- +- mov %r11, PBlockLen(arg2) # ctx_data.partial_block_length = 0 +- mov %r11, PBlockEncKey(arg2) # ctx_data.partial_block_enc_key = 0 +- mov arg3, %rax +- movdqu (%rax), %xmm0 +- movdqu %xmm0, OrigIV(arg2) # ctx_data.orig_IV = iv +- +- vpshufb SHUF_MASK(%rip), %xmm0, %xmm0 +- movdqu %xmm0, CurCount(arg2) # ctx_data.current_counter = iv +- +- vmovdqu (arg4), %xmm6 # xmm6 = HashKey +- +- vpshufb SHUF_MASK(%rip), %xmm6, %xmm6 +- ############### PRECOMPUTATION of HashKey<<1 mod poly from the HashKey +- vmovdqa %xmm6, %xmm2 +- vpsllq $1, %xmm6, %xmm6 +- vpsrlq $63, %xmm2, %xmm2 +- vmovdqa %xmm2, %xmm1 +- vpslldq $8, %xmm2, %xmm2 +- vpsrldq $8, %xmm1, %xmm1 +- vpor %xmm2, %xmm6, %xmm6 +- #reduction +- vpshufd $0b00100100, %xmm1, %xmm2 +- vpcmpeqd TWOONE(%rip), %xmm2, %xmm2 +- vpand POLY(%rip), %xmm2, %xmm2 +- vpxor %xmm2, %xmm6, %xmm6 # xmm6 holds the HashKey<<1 mod poly +- ####################################################################### +- vmovdqu %xmm6, HashKey(arg2) # store HashKey<<1 mod poly +- +- CALC_AAD_HASH \GHASH_MUL, arg5, arg6, %xmm2, %xmm6, %xmm3, %xmm4, %xmm5, %xmm7, %xmm1, %xmm0 +- +- \PRECOMPUTE %xmm6, %xmm0, %xmm1, %xmm2, %xmm3, %xmm4, %xmm5 +-.endm +- +- +-# Reads DLEN bytes starting at DPTR and stores in XMMDst +-# where 0 < DLEN < 16 +-# Clobbers %rax, DLEN +-.macro READ_PARTIAL_BLOCK DPTR DLEN XMMDst +- vpxor \XMMDst, \XMMDst, \XMMDst 
+- +- cmp $8, \DLEN +- jl .L_read_lt8_\@ +- mov (\DPTR), %rax +- vpinsrq $0, %rax, \XMMDst, \XMMDst +- sub $8, \DLEN +- jz .L_done_read_partial_block_\@ +- xor %eax, %eax +-.L_read_next_byte_\@: +- shl $8, %rax +- mov 7(\DPTR, \DLEN, 1), %al +- dec \DLEN +- jnz .L_read_next_byte_\@ +- vpinsrq $1, %rax, \XMMDst, \XMMDst +- jmp .L_done_read_partial_block_\@ +-.L_read_lt8_\@: +- xor %eax, %eax +-.L_read_next_byte_lt8_\@: +- shl $8, %rax +- mov -1(\DPTR, \DLEN, 1), %al +- dec \DLEN +- jnz .L_read_next_byte_lt8_\@ +- vpinsrq $0, %rax, \XMMDst, \XMMDst +-.L_done_read_partial_block_\@: +-.endm +- +-# PARTIAL_BLOCK: Handles encryption/decryption and the tag partial blocks +-# between update calls. +-# Requires the input data be at least 1 byte long due to READ_PARTIAL_BLOCK +-# Outputs encrypted bytes, and updates hash and partial info in gcm_data_context +-# Clobbers rax, r10, r12, r13, xmm0-6, xmm9-13 +-.macro PARTIAL_BLOCK GHASH_MUL CYPH_PLAIN_OUT PLAIN_CYPH_IN PLAIN_CYPH_LEN DATA_OFFSET \ +- AAD_HASH ENC_DEC +- mov PBlockLen(arg2), %r13 +- test %r13, %r13 +- je .L_partial_block_done_\@ # Leave Macro if no partial blocks +- # Read in input data without over reading +- cmp $16, \PLAIN_CYPH_LEN +- jl .L_fewer_than_16_bytes_\@ +- vmovdqu (\PLAIN_CYPH_IN), %xmm1 # If more than 16 bytes, just fill xmm +- jmp .L_data_read_\@ +- +-.L_fewer_than_16_bytes_\@: +- lea (\PLAIN_CYPH_IN, \DATA_OFFSET, 1), %r10 +- mov \PLAIN_CYPH_LEN, %r12 +- READ_PARTIAL_BLOCK %r10 %r12 %xmm1 +- +- mov PBlockLen(arg2), %r13 +- +-.L_data_read_\@: # Finished reading in data +- +- vmovdqu PBlockEncKey(arg2), %xmm9 +- vmovdqu HashKey(arg2), %xmm13 +- +- lea SHIFT_MASK(%rip), %r12 +- +- # adjust the shuffle mask pointer to be able to shift r13 bytes +- # r16-r13 is the number of bytes in plaintext mod 16) +- add %r13, %r12 +- vmovdqu (%r12), %xmm2 # get the appropriate shuffle mask +- vpshufb %xmm2, %xmm9, %xmm9 # shift right r13 bytes +- +-.if \ENC_DEC == DEC +- vmovdqa %xmm1, %xmm3 +- pxor %xmm1, %xmm9 
# Ciphertext XOR E(K, Yn) +- +- mov \PLAIN_CYPH_LEN, %r10 +- add %r13, %r10 +- # Set r10 to be the amount of data left in CYPH_PLAIN_IN after filling +- sub $16, %r10 +- # Determine if partial block is not being filled and +- # shift mask accordingly +- jge .L_no_extra_mask_1_\@ +- sub %r10, %r12 +-.L_no_extra_mask_1_\@: +- +- vmovdqu ALL_F-SHIFT_MASK(%r12), %xmm1 +- # get the appropriate mask to mask out bottom r13 bytes of xmm9 +- vpand %xmm1, %xmm9, %xmm9 # mask out bottom r13 bytes of xmm9 +- +- vpand %xmm1, %xmm3, %xmm3 +- vmovdqa SHUF_MASK(%rip), %xmm10 +- vpshufb %xmm10, %xmm3, %xmm3 +- vpshufb %xmm2, %xmm3, %xmm3 +- vpxor %xmm3, \AAD_HASH, \AAD_HASH +- +- test %r10, %r10 +- jl .L_partial_incomplete_1_\@ +- +- # GHASH computation for the last <16 Byte block +- \GHASH_MUL \AAD_HASH, %xmm13, %xmm0, %xmm10, %xmm11, %xmm5, %xmm6 +- xor %eax,%eax +- +- mov %rax, PBlockLen(arg2) +- jmp .L_dec_done_\@ +-.L_partial_incomplete_1_\@: +- add \PLAIN_CYPH_LEN, PBlockLen(arg2) +-.L_dec_done_\@: +- vmovdqu \AAD_HASH, AadHash(arg2) +-.else +- vpxor %xmm1, %xmm9, %xmm9 # Plaintext XOR E(K, Yn) +- +- mov \PLAIN_CYPH_LEN, %r10 +- add %r13, %r10 +- # Set r10 to be the amount of data left in CYPH_PLAIN_IN after filling +- sub $16, %r10 +- # Determine if partial block is not being filled and +- # shift mask accordingly +- jge .L_no_extra_mask_2_\@ +- sub %r10, %r12 +-.L_no_extra_mask_2_\@: +- +- vmovdqu ALL_F-SHIFT_MASK(%r12), %xmm1 +- # get the appropriate mask to mask out bottom r13 bytes of xmm9 +- vpand %xmm1, %xmm9, %xmm9 +- +- vmovdqa SHUF_MASK(%rip), %xmm1 +- vpshufb %xmm1, %xmm9, %xmm9 +- vpshufb %xmm2, %xmm9, %xmm9 +- vpxor %xmm9, \AAD_HASH, \AAD_HASH +- +- test %r10, %r10 +- jl .L_partial_incomplete_2_\@ +- +- # GHASH computation for the last <16 Byte block +- \GHASH_MUL \AAD_HASH, %xmm13, %xmm0, %xmm10, %xmm11, %xmm5, %xmm6 +- xor %eax,%eax +- +- mov %rax, PBlockLen(arg2) +- jmp .L_encode_done_\@ +-.L_partial_incomplete_2_\@: +- add \PLAIN_CYPH_LEN, PBlockLen(arg2) 
+-.L_encode_done_\@: +- vmovdqu \AAD_HASH, AadHash(arg2) +- +- vmovdqa SHUF_MASK(%rip), %xmm10 +- # shuffle xmm9 back to output as ciphertext +- vpshufb %xmm10, %xmm9, %xmm9 +- vpshufb %xmm2, %xmm9, %xmm9 +-.endif +- # output encrypted Bytes +- test %r10, %r10 +- jl .L_partial_fill_\@ +- mov %r13, %r12 +- mov $16, %r13 +- # Set r13 to be the number of bytes to write out +- sub %r12, %r13 +- jmp .L_count_set_\@ +-.L_partial_fill_\@: +- mov \PLAIN_CYPH_LEN, %r13 +-.L_count_set_\@: +- vmovdqa %xmm9, %xmm0 +- vmovq %xmm0, %rax +- cmp $8, %r13 +- jle .L_less_than_8_bytes_left_\@ +- +- mov %rax, (\CYPH_PLAIN_OUT, \DATA_OFFSET, 1) +- add $8, \DATA_OFFSET +- psrldq $8, %xmm0 +- vmovq %xmm0, %rax +- sub $8, %r13 +-.L_less_than_8_bytes_left_\@: +- movb %al, (\CYPH_PLAIN_OUT, \DATA_OFFSET, 1) +- add $1, \DATA_OFFSET +- shr $8, %rax +- sub $1, %r13 +- jne .L_less_than_8_bytes_left_\@ +-.L_partial_block_done_\@: +-.endm # PARTIAL_BLOCK +- +-############################################################################### +-# GHASH_MUL MACRO to implement: Data*HashKey mod (128,127,126,121,0) +-# Input: A and B (128-bits each, bit-reflected) +-# Output: C = A*B*x mod poly, (i.e. >>1 ) +-# To compute GH = GH*HashKey mod poly, give HK = HashKey<<1 mod poly as input +-# GH = GH * HK * x mod poly which is equivalent to GH*HashKey mod poly. 
+-############################################################################### +-.macro GHASH_MUL_AVX GH HK T1 T2 T3 T4 T5 +- +- vpshufd $0b01001110, \GH, \T2 +- vpshufd $0b01001110, \HK, \T3 +- vpxor \GH , \T2, \T2 # T2 = (a1+a0) +- vpxor \HK , \T3, \T3 # T3 = (b1+b0) +- +- vpclmulqdq $0x11, \HK, \GH, \T1 # T1 = a1*b1 +- vpclmulqdq $0x00, \HK, \GH, \GH # GH = a0*b0 +- vpclmulqdq $0x00, \T3, \T2, \T2 # T2 = (a1+a0)*(b1+b0) +- vpxor \GH, \T2,\T2 +- vpxor \T1, \T2,\T2 # T2 = a0*b1+a1*b0 +- +- vpslldq $8, \T2,\T3 # shift-L T3 2 DWs +- vpsrldq $8, \T2,\T2 # shift-R T2 2 DWs +- vpxor \T3, \GH, \GH +- vpxor \T2, \T1, \T1 # = GH x HK +- +- #first phase of the reduction +- vpslld $31, \GH, \T2 # packed right shifting << 31 +- vpslld $30, \GH, \T3 # packed right shifting shift << 30 +- vpslld $25, \GH, \T4 # packed right shifting shift << 25 +- +- vpxor \T3, \T2, \T2 # xor the shifted versions +- vpxor \T4, \T2, \T2 +- +- vpsrldq $4, \T2, \T5 # shift-R T5 1 DW +- +- vpslldq $12, \T2, \T2 # shift-L T2 3 DWs +- vpxor \T2, \GH, \GH # first phase of the reduction complete +- +- #second phase of the reduction +- +- vpsrld $1,\GH, \T2 # packed left shifting >> 1 +- vpsrld $2,\GH, \T3 # packed left shifting >> 2 +- vpsrld $7,\GH, \T4 # packed left shifting >> 7 +- vpxor \T3, \T2, \T2 # xor the shifted versions +- vpxor \T4, \T2, \T2 +- +- vpxor \T5, \T2, \T2 +- vpxor \T2, \GH, \GH +- vpxor \T1, \GH, \GH # the result is in GH +- +- +-.endm +- +-.macro PRECOMPUTE_AVX HK T1 T2 T3 T4 T5 T6 +- +- # Haskey_i_k holds XORed values of the low and high parts of the Haskey_i +- vmovdqa \HK, \T5 +- +- vpshufd $0b01001110, \T5, \T1 +- vpxor \T5, \T1, \T1 +- vmovdqu \T1, HashKey_k(arg2) +- +- GHASH_MUL_AVX \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = HashKey^2<<1 mod poly +- vmovdqu \T5, HashKey_2(arg2) # [HashKey_2] = HashKey^2<<1 mod poly +- vpshufd $0b01001110, \T5, \T1 +- vpxor \T5, \T1, \T1 +- vmovdqu \T1, HashKey_2_k(arg2) +- +- GHASH_MUL_AVX \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = 
HashKey^3<<1 mod poly +- vmovdqu \T5, HashKey_3(arg2) +- vpshufd $0b01001110, \T5, \T1 +- vpxor \T5, \T1, \T1 +- vmovdqu \T1, HashKey_3_k(arg2) +- +- GHASH_MUL_AVX \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = HashKey^4<<1 mod poly +- vmovdqu \T5, HashKey_4(arg2) +- vpshufd $0b01001110, \T5, \T1 +- vpxor \T5, \T1, \T1 +- vmovdqu \T1, HashKey_4_k(arg2) +- +- GHASH_MUL_AVX \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = HashKey^5<<1 mod poly +- vmovdqu \T5, HashKey_5(arg2) +- vpshufd $0b01001110, \T5, \T1 +- vpxor \T5, \T1, \T1 +- vmovdqu \T1, HashKey_5_k(arg2) +- +- GHASH_MUL_AVX \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = HashKey^6<<1 mod poly +- vmovdqu \T5, HashKey_6(arg2) +- vpshufd $0b01001110, \T5, \T1 +- vpxor \T5, \T1, \T1 +- vmovdqu \T1, HashKey_6_k(arg2) +- +- GHASH_MUL_AVX \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = HashKey^7<<1 mod poly +- vmovdqu \T5, HashKey_7(arg2) +- vpshufd $0b01001110, \T5, \T1 +- vpxor \T5, \T1, \T1 +- vmovdqu \T1, HashKey_7_k(arg2) +- +- GHASH_MUL_AVX \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = HashKey^8<<1 mod poly +- vmovdqu \T5, HashKey_8(arg2) +- vpshufd $0b01001110, \T5, \T1 +- vpxor \T5, \T1, \T1 +- vmovdqu \T1, HashKey_8_k(arg2) +- +-.endm +- +-## if a = number of total plaintext bytes +-## b = floor(a/16) +-## num_initial_blocks = b mod 4# +-## encrypt the initial num_initial_blocks blocks and apply ghash on the ciphertext +-## r10, r11, r12, rax are clobbered +-## arg1, arg2, arg3, arg4 are used as pointers only, not modified +- +-.macro INITIAL_BLOCKS_AVX REP num_initial_blocks T1 T2 T3 T4 T5 CTR XMM1 XMM2 XMM3 XMM4 XMM5 XMM6 XMM7 XMM8 T6 T_key ENC_DEC +- i = (8-\num_initial_blocks) +- setreg +- vmovdqu AadHash(arg2), reg_i +- +- # start AES for num_initial_blocks blocks +- vmovdqu CurCount(arg2), \CTR +- +- i = (9-\num_initial_blocks) +- setreg +-.rep \num_initial_blocks +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, reg_i +- vpshufb SHUF_MASK(%rip), reg_i, reg_i # perform a 16Byte swap +- i = (i+1) +- setreg +-.endr +- +- 
vmovdqa (arg1), \T_key +- i = (9-\num_initial_blocks) +- setreg +-.rep \num_initial_blocks +- vpxor \T_key, reg_i, reg_i +- i = (i+1) +- setreg +-.endr +- +- j = 1 +- setreg +-.rep \REP +- vmovdqa 16*j(arg1), \T_key +- i = (9-\num_initial_blocks) +- setreg +-.rep \num_initial_blocks +- vaesenc \T_key, reg_i, reg_i +- i = (i+1) +- setreg +-.endr +- +- j = (j+1) +- setreg +-.endr +- +- vmovdqa 16*j(arg1), \T_key +- i = (9-\num_initial_blocks) +- setreg +-.rep \num_initial_blocks +- vaesenclast \T_key, reg_i, reg_i +- i = (i+1) +- setreg +-.endr +- +- i = (9-\num_initial_blocks) +- setreg +-.rep \num_initial_blocks +- vmovdqu (arg4, %r11), \T1 +- vpxor \T1, reg_i, reg_i +- vmovdqu reg_i, (arg3 , %r11) # write back ciphertext for num_initial_blocks blocks +- add $16, %r11 +-.if \ENC_DEC == DEC +- vmovdqa \T1, reg_i +-.endif +- vpshufb SHUF_MASK(%rip), reg_i, reg_i # prepare ciphertext for GHASH computations +- i = (i+1) +- setreg +-.endr +- +- +- i = (8-\num_initial_blocks) +- j = (9-\num_initial_blocks) +- setreg +- +-.rep \num_initial_blocks +- vpxor reg_i, reg_j, reg_j +- GHASH_MUL_AVX reg_j, \T2, \T1, \T3, \T4, \T5, \T6 # apply GHASH on num_initial_blocks blocks +- i = (i+1) +- j = (j+1) +- setreg +-.endr +- # XMM8 has the combined result here +- +- vmovdqa \XMM8, TMP1(%rsp) +- vmovdqa \XMM8, \T3 +- +- cmp $128, %r13 +- jl .L_initial_blocks_done\@ # no need for precomputed constants +- +-############################################################################### +-# Haskey_i_k holds XORed values of the low and high parts of the Haskey_i +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM1 +- vpshufb SHUF_MASK(%rip), \XMM1, \XMM1 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM2 +- vpshufb SHUF_MASK(%rip), \XMM2, \XMM2 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM3 +- vpshufb SHUF_MASK(%rip), \XMM3, \XMM3 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # 
INCR Y0 +- vmovdqa \CTR, \XMM4 +- vpshufb SHUF_MASK(%rip), \XMM4, \XMM4 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM5 +- vpshufb SHUF_MASK(%rip), \XMM5, \XMM5 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM6 +- vpshufb SHUF_MASK(%rip), \XMM6, \XMM6 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM7 +- vpshufb SHUF_MASK(%rip), \XMM7, \XMM7 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM8 +- vpshufb SHUF_MASK(%rip), \XMM8, \XMM8 # perform a 16Byte swap +- +- vmovdqa (arg1), \T_key +- vpxor \T_key, \XMM1, \XMM1 +- vpxor \T_key, \XMM2, \XMM2 +- vpxor \T_key, \XMM3, \XMM3 +- vpxor \T_key, \XMM4, \XMM4 +- vpxor \T_key, \XMM5, \XMM5 +- vpxor \T_key, \XMM6, \XMM6 +- vpxor \T_key, \XMM7, \XMM7 +- vpxor \T_key, \XMM8, \XMM8 +- +- i = 1 +- setreg +-.rep \REP # do REP rounds +- vmovdqa 16*i(arg1), \T_key +- vaesenc \T_key, \XMM1, \XMM1 +- vaesenc \T_key, \XMM2, \XMM2 +- vaesenc \T_key, \XMM3, \XMM3 +- vaesenc \T_key, \XMM4, \XMM4 +- vaesenc \T_key, \XMM5, \XMM5 +- vaesenc \T_key, \XMM6, \XMM6 +- vaesenc \T_key, \XMM7, \XMM7 +- vaesenc \T_key, \XMM8, \XMM8 +- i = (i+1) +- setreg +-.endr +- +- vmovdqa 16*i(arg1), \T_key +- vaesenclast \T_key, \XMM1, \XMM1 +- vaesenclast \T_key, \XMM2, \XMM2 +- vaesenclast \T_key, \XMM3, \XMM3 +- vaesenclast \T_key, \XMM4, \XMM4 +- vaesenclast \T_key, \XMM5, \XMM5 +- vaesenclast \T_key, \XMM6, \XMM6 +- vaesenclast \T_key, \XMM7, \XMM7 +- vaesenclast \T_key, \XMM8, \XMM8 +- +- vmovdqu (arg4, %r11), \T1 +- vpxor \T1, \XMM1, \XMM1 +- vmovdqu \XMM1, (arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM1 +- .endif +- +- vmovdqu 16*1(arg4, %r11), \T1 +- vpxor \T1, \XMM2, \XMM2 +- vmovdqu \XMM2, 16*1(arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM2 +- .endif +- +- vmovdqu 16*2(arg4, %r11), \T1 +- vpxor \T1, \XMM3, \XMM3 +- vmovdqu \XMM3, 16*2(arg3 , %r11) +- .if \ENC_DEC 
== DEC +- vmovdqa \T1, \XMM3 +- .endif +- +- vmovdqu 16*3(arg4, %r11), \T1 +- vpxor \T1, \XMM4, \XMM4 +- vmovdqu \XMM4, 16*3(arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM4 +- .endif +- +- vmovdqu 16*4(arg4, %r11), \T1 +- vpxor \T1, \XMM5, \XMM5 +- vmovdqu \XMM5, 16*4(arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM5 +- .endif +- +- vmovdqu 16*5(arg4, %r11), \T1 +- vpxor \T1, \XMM6, \XMM6 +- vmovdqu \XMM6, 16*5(arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM6 +- .endif +- +- vmovdqu 16*6(arg4, %r11), \T1 +- vpxor \T1, \XMM7, \XMM7 +- vmovdqu \XMM7, 16*6(arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM7 +- .endif +- +- vmovdqu 16*7(arg4, %r11), \T1 +- vpxor \T1, \XMM8, \XMM8 +- vmovdqu \XMM8, 16*7(arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM8 +- .endif +- +- add $128, %r11 +- +- vpshufb SHUF_MASK(%rip), \XMM1, \XMM1 # perform a 16Byte swap +- vpxor TMP1(%rsp), \XMM1, \XMM1 # combine GHASHed value with the corresponding ciphertext +- vpshufb SHUF_MASK(%rip), \XMM2, \XMM2 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM3, \XMM3 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM4, \XMM4 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM5, \XMM5 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM6, \XMM6 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM7, \XMM7 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM8, \XMM8 # perform a 16Byte swap +- +-############################################################################### +- +-.L_initial_blocks_done\@: +- +-.endm +- +-# encrypt 8 blocks at a time +-# ghash the 8 previously encrypted ciphertext blocks +-# arg1, arg2, arg3, arg4 are used as pointers only, not modified +-# r11 is the data offset value +-.macro GHASH_8_ENCRYPT_8_PARALLEL_AVX REP T1 T2 T3 T4 T5 T6 CTR XMM1 XMM2 XMM3 XMM4 XMM5 XMM6 XMM7 XMM8 T7 loop_idx ENC_DEC +- +- vmovdqa \XMM1, \T2 +- vmovdqa \XMM2, TMP2(%rsp) +- vmovdqa \XMM3, TMP3(%rsp) +- vmovdqa \XMM4, TMP4(%rsp) 
+- vmovdqa \XMM5, TMP5(%rsp) +- vmovdqa \XMM6, TMP6(%rsp) +- vmovdqa \XMM7, TMP7(%rsp) +- vmovdqa \XMM8, TMP8(%rsp) +- +-.if \loop_idx == in_order +- vpaddd ONE(%rip), \CTR, \XMM1 # INCR CNT +- vpaddd ONE(%rip), \XMM1, \XMM2 +- vpaddd ONE(%rip), \XMM2, \XMM3 +- vpaddd ONE(%rip), \XMM3, \XMM4 +- vpaddd ONE(%rip), \XMM4, \XMM5 +- vpaddd ONE(%rip), \XMM5, \XMM6 +- vpaddd ONE(%rip), \XMM6, \XMM7 +- vpaddd ONE(%rip), \XMM7, \XMM8 +- vmovdqa \XMM8, \CTR +- +- vpshufb SHUF_MASK(%rip), \XMM1, \XMM1 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM2, \XMM2 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM3, \XMM3 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM4, \XMM4 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM5, \XMM5 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM6, \XMM6 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM7, \XMM7 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM8, \XMM8 # perform a 16Byte swap +-.else +- vpaddd ONEf(%rip), \CTR, \XMM1 # INCR CNT +- vpaddd ONEf(%rip), \XMM1, \XMM2 +- vpaddd ONEf(%rip), \XMM2, \XMM3 +- vpaddd ONEf(%rip), \XMM3, \XMM4 +- vpaddd ONEf(%rip), \XMM4, \XMM5 +- vpaddd ONEf(%rip), \XMM5, \XMM6 +- vpaddd ONEf(%rip), \XMM6, \XMM7 +- vpaddd ONEf(%rip), \XMM7, \XMM8 +- vmovdqa \XMM8, \CTR +-.endif +- +- +- ####################################################################### +- +- vmovdqu (arg1), \T1 +- vpxor \T1, \XMM1, \XMM1 +- vpxor \T1, \XMM2, \XMM2 +- vpxor \T1, \XMM3, \XMM3 +- vpxor \T1, \XMM4, \XMM4 +- vpxor \T1, \XMM5, \XMM5 +- vpxor \T1, \XMM6, \XMM6 +- vpxor \T1, \XMM7, \XMM7 +- vpxor \T1, \XMM8, \XMM8 +- +- ####################################################################### +- +- +- +- +- +- vmovdqu 16*1(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- vmovdqu 
16*2(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- +- ####################################################################### +- +- vmovdqu HashKey_8(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T2, \T4 # T4 = a1*b1 +- vpclmulqdq $0x00, \T5, \T2, \T7 # T7 = a0*b0 +- +- vpshufd $0b01001110, \T2, \T6 +- vpxor \T2, \T6, \T6 +- +- vmovdqu HashKey_8_k(arg2), \T5 +- vpclmulqdq $0x00, \T5, \T6, \T6 +- +- vmovdqu 16*3(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- vmovdqa TMP2(%rsp), \T1 +- vmovdqu HashKey_7(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor \T3, \T4, \T4 +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpshufd $0b01001110, \T1, \T3 +- vpxor \T1, \T3, \T3 +- vmovdqu HashKey_7_k(arg2), \T5 +- vpclmulqdq $0x10, \T5, \T3, \T3 +- vpxor \T3, \T6, \T6 +- +- vmovdqu 16*4(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- ####################################################################### +- +- vmovdqa TMP3(%rsp), \T1 +- vmovdqu HashKey_6(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor \T3, \T4, \T4 +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpshufd $0b01001110, \T1, \T3 +- vpxor \T1, \T3, \T3 +- vmovdqu HashKey_6_k(arg2), \T5 +- vpclmulqdq $0x10, \T5, \T3, \T3 +- vpxor \T3, \T6, \T6 +- +- vmovdqu 16*5(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- 
vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- vmovdqa TMP4(%rsp), \T1 +- vmovdqu HashKey_5(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor \T3, \T4, \T4 +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpshufd $0b01001110, \T1, \T3 +- vpxor \T1, \T3, \T3 +- vmovdqu HashKey_5_k(arg2), \T5 +- vpclmulqdq $0x10, \T5, \T3, \T3 +- vpxor \T3, \T6, \T6 +- +- vmovdqu 16*6(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- +- vmovdqa TMP5(%rsp), \T1 +- vmovdqu HashKey_4(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor \T3, \T4, \T4 +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpshufd $0b01001110, \T1, \T3 +- vpxor \T1, \T3, \T3 +- vmovdqu HashKey_4_k(arg2), \T5 +- vpclmulqdq $0x10, \T5, \T3, \T3 +- vpxor \T3, \T6, \T6 +- +- vmovdqu 16*7(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- vmovdqa TMP6(%rsp), \T1 +- vmovdqu HashKey_3(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor \T3, \T4, \T4 +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpshufd $0b01001110, \T1, \T3 +- vpxor \T1, \T3, \T3 +- vmovdqu HashKey_3_k(arg2), \T5 +- vpclmulqdq $0x10, \T5, \T3, \T3 +- vpxor \T3, \T6, \T6 +- +- +- vmovdqu 16*8(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- vmovdqa TMP7(%rsp), \T1 +- vmovdqu HashKey_2(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor 
\T3, \T4, \T4 +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpshufd $0b01001110, \T1, \T3 +- vpxor \T1, \T3, \T3 +- vmovdqu HashKey_2_k(arg2), \T5 +- vpclmulqdq $0x10, \T5, \T3, \T3 +- vpxor \T3, \T6, \T6 +- +- ####################################################################### +- +- vmovdqu 16*9(arg1), \T5 +- vaesenc \T5, \XMM1, \XMM1 +- vaesenc \T5, \XMM2, \XMM2 +- vaesenc \T5, \XMM3, \XMM3 +- vaesenc \T5, \XMM4, \XMM4 +- vaesenc \T5, \XMM5, \XMM5 +- vaesenc \T5, \XMM6, \XMM6 +- vaesenc \T5, \XMM7, \XMM7 +- vaesenc \T5, \XMM8, \XMM8 +- +- vmovdqa TMP8(%rsp), \T1 +- vmovdqu HashKey(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor \T3, \T4, \T4 +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpshufd $0b01001110, \T1, \T3 +- vpxor \T1, \T3, \T3 +- vmovdqu HashKey_k(arg2), \T5 +- vpclmulqdq $0x10, \T5, \T3, \T3 +- vpxor \T3, \T6, \T6 +- +- vpxor \T4, \T6, \T6 +- vpxor \T7, \T6, \T6 +- +- vmovdqu 16*10(arg1), \T5 +- +- i = 11 +- setreg +-.rep (\REP-9) +- +- vaesenc \T5, \XMM1, \XMM1 +- vaesenc \T5, \XMM2, \XMM2 +- vaesenc \T5, \XMM3, \XMM3 +- vaesenc \T5, \XMM4, \XMM4 +- vaesenc \T5, \XMM5, \XMM5 +- vaesenc \T5, \XMM6, \XMM6 +- vaesenc \T5, \XMM7, \XMM7 +- vaesenc \T5, \XMM8, \XMM8 +- +- vmovdqu 16*i(arg1), \T5 +- i = i + 1 +- setreg +-.endr +- +- i = 0 +- j = 1 +- setreg +-.rep 8 +- vpxor 16*i(arg4, %r11), \T5, \T2 +- .if \ENC_DEC == ENC +- vaesenclast \T2, reg_j, reg_j +- .else +- vaesenclast \T2, reg_j, \T3 +- vmovdqu 16*i(arg4, %r11), reg_j +- vmovdqu \T3, 16*i(arg3, %r11) +- .endif +- i = (i+1) +- j = (j+1) +- setreg +-.endr +- ####################################################################### +- +- +- vpslldq $8, \T6, \T3 # shift-L T3 2 DWs +- vpsrldq $8, \T6, \T6 # shift-R T2 2 DWs +- vpxor \T3, \T7, \T7 +- vpxor \T4, \T6, \T6 # accumulate the results in T6:T7 +- +- +- +- ####################################################################### +- #first phase of the reduction +- 
####################################################################### +- vpslld $31, \T7, \T2 # packed right shifting << 31 +- vpslld $30, \T7, \T3 # packed right shifting shift << 30 +- vpslld $25, \T7, \T4 # packed right shifting shift << 25 +- +- vpxor \T3, \T2, \T2 # xor the shifted versions +- vpxor \T4, \T2, \T2 +- +- vpsrldq $4, \T2, \T1 # shift-R T1 1 DW +- +- vpslldq $12, \T2, \T2 # shift-L T2 3 DWs +- vpxor \T2, \T7, \T7 # first phase of the reduction complete +- ####################################################################### +- .if \ENC_DEC == ENC +- vmovdqu \XMM1, 16*0(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM2, 16*1(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM3, 16*2(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM4, 16*3(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM5, 16*4(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM6, 16*5(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM7, 16*6(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM8, 16*7(arg3,%r11) # Write to the Ciphertext buffer +- .endif +- +- ####################################################################### +- #second phase of the reduction +- vpsrld $1, \T7, \T2 # packed left shifting >> 1 +- vpsrld $2, \T7, \T3 # packed left shifting >> 2 +- vpsrld $7, \T7, \T4 # packed left shifting >> 7 +- vpxor \T3, \T2, \T2 # xor the shifted versions +- vpxor \T4, \T2, \T2 +- +- vpxor \T1, \T2, \T2 +- vpxor \T2, \T7, \T7 +- vpxor \T7, \T6, \T6 # the result is in T6 +- ####################################################################### +- +- vpshufb SHUF_MASK(%rip), \XMM1, \XMM1 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM2, \XMM2 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM3, \XMM3 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM4, \XMM4 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM5, \XMM5 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM6, \XMM6 # 
perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM7, \XMM7 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM8, \XMM8 # perform a 16Byte swap +- +- +- vpxor \T6, \XMM1, \XMM1 +- +- +- +-.endm +- +- +-# GHASH the last 4 ciphertext blocks. +-.macro GHASH_LAST_8_AVX T1 T2 T3 T4 T5 T6 T7 XMM1 XMM2 XMM3 XMM4 XMM5 XMM6 XMM7 XMM8 +- +- ## Karatsuba Method +- +- +- vpshufd $0b01001110, \XMM1, \T2 +- vpxor \XMM1, \T2, \T2 +- vmovdqu HashKey_8(arg2), \T5 +- vpclmulqdq $0x11, \T5, \XMM1, \T6 +- vpclmulqdq $0x00, \T5, \XMM1, \T7 +- +- vmovdqu HashKey_8_k(arg2), \T3 +- vpclmulqdq $0x00, \T3, \T2, \XMM1 +- +- ###################### +- +- vpshufd $0b01001110, \XMM2, \T2 +- vpxor \XMM2, \T2, \T2 +- vmovdqu HashKey_7(arg2), \T5 +- vpclmulqdq $0x11, \T5, \XMM2, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM2, \T4 +- vpxor \T4, \T7, \T7 +- +- vmovdqu HashKey_7_k(arg2), \T3 +- vpclmulqdq $0x00, \T3, \T2, \T2 +- vpxor \T2, \XMM1, \XMM1 +- +- ###################### +- +- vpshufd $0b01001110, \XMM3, \T2 +- vpxor \XMM3, \T2, \T2 +- vmovdqu HashKey_6(arg2), \T5 +- vpclmulqdq $0x11, \T5, \XMM3, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM3, \T4 +- vpxor \T4, \T7, \T7 +- +- vmovdqu HashKey_6_k(arg2), \T3 +- vpclmulqdq $0x00, \T3, \T2, \T2 +- vpxor \T2, \XMM1, \XMM1 +- +- ###################### +- +- vpshufd $0b01001110, \XMM4, \T2 +- vpxor \XMM4, \T2, \T2 +- vmovdqu HashKey_5(arg2), \T5 +- vpclmulqdq $0x11, \T5, \XMM4, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM4, \T4 +- vpxor \T4, \T7, \T7 +- +- vmovdqu HashKey_5_k(arg2), \T3 +- vpclmulqdq $0x00, \T3, \T2, \T2 +- vpxor \T2, \XMM1, \XMM1 +- +- ###################### +- +- vpshufd $0b01001110, \XMM5, \T2 +- vpxor \XMM5, \T2, \T2 +- vmovdqu HashKey_4(arg2), \T5 +- vpclmulqdq $0x11, \T5, \XMM5, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM5, \T4 +- vpxor \T4, \T7, \T7 +- +- vmovdqu HashKey_4_k(arg2), \T3 +- vpclmulqdq $0x00, \T3, \T2, \T2 +- vpxor \T2, \XMM1, \XMM1 +- +- 
###################### +- +- vpshufd $0b01001110, \XMM6, \T2 +- vpxor \XMM6, \T2, \T2 +- vmovdqu HashKey_3(arg2), \T5 +- vpclmulqdq $0x11, \T5, \XMM6, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM6, \T4 +- vpxor \T4, \T7, \T7 +- +- vmovdqu HashKey_3_k(arg2), \T3 +- vpclmulqdq $0x00, \T3, \T2, \T2 +- vpxor \T2, \XMM1, \XMM1 +- +- ###################### +- +- vpshufd $0b01001110, \XMM7, \T2 +- vpxor \XMM7, \T2, \T2 +- vmovdqu HashKey_2(arg2), \T5 +- vpclmulqdq $0x11, \T5, \XMM7, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM7, \T4 +- vpxor \T4, \T7, \T7 +- +- vmovdqu HashKey_2_k(arg2), \T3 +- vpclmulqdq $0x00, \T3, \T2, \T2 +- vpxor \T2, \XMM1, \XMM1 +- +- ###################### +- +- vpshufd $0b01001110, \XMM8, \T2 +- vpxor \XMM8, \T2, \T2 +- vmovdqu HashKey(arg2), \T5 +- vpclmulqdq $0x11, \T5, \XMM8, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM8, \T4 +- vpxor \T4, \T7, \T7 +- +- vmovdqu HashKey_k(arg2), \T3 +- vpclmulqdq $0x00, \T3, \T2, \T2 +- +- vpxor \T2, \XMM1, \XMM1 +- vpxor \T6, \XMM1, \XMM1 +- vpxor \T7, \XMM1, \T2 +- +- +- +- +- vpslldq $8, \T2, \T4 +- vpsrldq $8, \T2, \T2 +- +- vpxor \T4, \T7, \T7 +- vpxor \T2, \T6, \T6 # holds the result of +- # the accumulated carry-less multiplications +- +- ####################################################################### +- #first phase of the reduction +- vpslld $31, \T7, \T2 # packed right shifting << 31 +- vpslld $30, \T7, \T3 # packed right shifting shift << 30 +- vpslld $25, \T7, \T4 # packed right shifting shift << 25 +- +- vpxor \T3, \T2, \T2 # xor the shifted versions +- vpxor \T4, \T2, \T2 +- +- vpsrldq $4, \T2, \T1 # shift-R T1 1 DW +- +- vpslldq $12, \T2, \T2 # shift-L T2 3 DWs +- vpxor \T2, \T7, \T7 # first phase of the reduction complete +- ####################################################################### +- +- +- #second phase of the reduction +- vpsrld $1, \T7, \T2 # packed left shifting >> 1 +- vpsrld $2, \T7, \T3 # packed left shifting >> 2 +- 
vpsrld $7, \T7, \T4 # packed left shifting >> 7 +- vpxor \T3, \T2, \T2 # xor the shifted versions +- vpxor \T4, \T2, \T2 +- +- vpxor \T1, \T2, \T2 +- vpxor \T2, \T7, \T7 +- vpxor \T7, \T6, \T6 # the result is in T6 +- +-.endm +- +-############################################################# +-#void aesni_gcm_precomp_avx_gen2 +-# (gcm_data *my_ctx_data, +-# gcm_context_data *data, +-# u8 *hash_subkey# /* H, the Hash sub key input. Data starts on a 16-byte boundary. */ +-# u8 *iv, /* Pre-counter block j0: 4 byte salt +-# (from Security Association) concatenated with 8 byte +-# Initialisation Vector (from IPSec ESP Payload) +-# concatenated with 0x00000001. 16-byte aligned pointer. */ +-# const u8 *aad, /* Additional Authentication Data (AAD)*/ +-# u64 aad_len) /* Length of AAD in bytes. With RFC4106 this is going to be 8 or 12 Bytes */ +-############################################################# +-SYM_FUNC_START(aesni_gcm_init_avx_gen2) +- FUNC_SAVE +- INIT GHASH_MUL_AVX, PRECOMPUTE_AVX +- FUNC_RESTORE +- RET +-SYM_FUNC_END(aesni_gcm_init_avx_gen2) +- +-############################################################################### +-#void aesni_gcm_enc_update_avx_gen2( +-# gcm_data *my_ctx_data, /* aligned to 16 Bytes */ +-# gcm_context_data *data, +-# u8 *out, /* Ciphertext output. Encrypt in-place is allowed. */ +-# const u8 *in, /* Plaintext input */ +-# u64 plaintext_len) /* Length of data in Bytes for encryption. 
*/ +-############################################################################### +-SYM_FUNC_START(aesni_gcm_enc_update_avx_gen2) +- FUNC_SAVE +- mov keysize, %eax +- cmp $32, %eax +- je key_256_enc_update +- cmp $16, %eax +- je key_128_enc_update +- # must be 192 +- GCM_ENC_DEC INITIAL_BLOCKS_AVX, GHASH_8_ENCRYPT_8_PARALLEL_AVX, GHASH_LAST_8_AVX, GHASH_MUL_AVX, ENC, 11 +- FUNC_RESTORE +- RET +-key_128_enc_update: +- GCM_ENC_DEC INITIAL_BLOCKS_AVX, GHASH_8_ENCRYPT_8_PARALLEL_AVX, GHASH_LAST_8_AVX, GHASH_MUL_AVX, ENC, 9 +- FUNC_RESTORE +- RET +-key_256_enc_update: +- GCM_ENC_DEC INITIAL_BLOCKS_AVX, GHASH_8_ENCRYPT_8_PARALLEL_AVX, GHASH_LAST_8_AVX, GHASH_MUL_AVX, ENC, 13 +- FUNC_RESTORE +- RET +-SYM_FUNC_END(aesni_gcm_enc_update_avx_gen2) +- +-############################################################################### +-#void aesni_gcm_dec_update_avx_gen2( +-# gcm_data *my_ctx_data, /* aligned to 16 Bytes */ +-# gcm_context_data *data, +-# u8 *out, /* Plaintext output. Decrypt in-place is allowed. */ +-# const u8 *in, /* Ciphertext input */ +-# u64 plaintext_len) /* Length of data in Bytes for encryption. 
*/ +-############################################################################### +-SYM_FUNC_START(aesni_gcm_dec_update_avx_gen2) +- FUNC_SAVE +- mov keysize,%eax +- cmp $32, %eax +- je key_256_dec_update +- cmp $16, %eax +- je key_128_dec_update +- # must be 192 +- GCM_ENC_DEC INITIAL_BLOCKS_AVX, GHASH_8_ENCRYPT_8_PARALLEL_AVX, GHASH_LAST_8_AVX, GHASH_MUL_AVX, DEC, 11 +- FUNC_RESTORE +- RET +-key_128_dec_update: +- GCM_ENC_DEC INITIAL_BLOCKS_AVX, GHASH_8_ENCRYPT_8_PARALLEL_AVX, GHASH_LAST_8_AVX, GHASH_MUL_AVX, DEC, 9 +- FUNC_RESTORE +- RET +-key_256_dec_update: +- GCM_ENC_DEC INITIAL_BLOCKS_AVX, GHASH_8_ENCRYPT_8_PARALLEL_AVX, GHASH_LAST_8_AVX, GHASH_MUL_AVX, DEC, 13 +- FUNC_RESTORE +- RET +-SYM_FUNC_END(aesni_gcm_dec_update_avx_gen2) +- +-############################################################################### +-#void aesni_gcm_finalize_avx_gen2( +-# gcm_data *my_ctx_data, /* aligned to 16 Bytes */ +-# gcm_context_data *data, +-# u8 *auth_tag, /* Authenticated Tag output. */ +-# u64 auth_tag_len)# /* Authenticated Tag Length in bytes. +-# Valid values are 16 (most likely), 12 or 8. */ +-############################################################################### +-SYM_FUNC_START(aesni_gcm_finalize_avx_gen2) +- FUNC_SAVE +- mov keysize,%eax +- cmp $32, %eax +- je key_256_finalize +- cmp $16, %eax +- je key_128_finalize +- # must be 192 +- GCM_COMPLETE GHASH_MUL_AVX, 11, arg3, arg4 +- FUNC_RESTORE +- RET +-key_128_finalize: +- GCM_COMPLETE GHASH_MUL_AVX, 9, arg3, arg4 +- FUNC_RESTORE +- RET +-key_256_finalize: +- GCM_COMPLETE GHASH_MUL_AVX, 13, arg3, arg4 +- FUNC_RESTORE +- RET +-SYM_FUNC_END(aesni_gcm_finalize_avx_gen2) +- +-############################################################################### +-# GHASH_MUL MACRO to implement: Data*HashKey mod (128,127,126,121,0) +-# Input: A and B (128-bits each, bit-reflected) +-# Output: C = A*B*x mod poly, (i.e. 
>>1 ) +-# To compute GH = GH*HashKey mod poly, give HK = HashKey<<1 mod poly as input +-# GH = GH * HK * x mod poly which is equivalent to GH*HashKey mod poly. +-############################################################################### +-.macro GHASH_MUL_AVX2 GH HK T1 T2 T3 T4 T5 +- +- vpclmulqdq $0x11,\HK,\GH,\T1 # T1 = a1*b1 +- vpclmulqdq $0x00,\HK,\GH,\T2 # T2 = a0*b0 +- vpclmulqdq $0x01,\HK,\GH,\T3 # T3 = a1*b0 +- vpclmulqdq $0x10,\HK,\GH,\GH # GH = a0*b1 +- vpxor \T3, \GH, \GH +- +- +- vpsrldq $8 , \GH, \T3 # shift-R GH 2 DWs +- vpslldq $8 , \GH, \GH # shift-L GH 2 DWs +- +- vpxor \T3, \T1, \T1 +- vpxor \T2, \GH, \GH +- +- ####################################################################### +- #first phase of the reduction +- vmovdqa POLY2(%rip), \T3 +- +- vpclmulqdq $0x01, \GH, \T3, \T2 +- vpslldq $8, \T2, \T2 # shift-L T2 2 DWs +- +- vpxor \T2, \GH, \GH # first phase of the reduction complete +- ####################################################################### +- #second phase of the reduction +- vpclmulqdq $0x00, \GH, \T3, \T2 +- vpsrldq $4, \T2, \T2 # shift-R T2 1 DW (Shift-R only 1-DW to obtain 2-DWs shift-R) +- +- vpclmulqdq $0x10, \GH, \T3, \GH +- vpslldq $4, \GH, \GH # shift-L GH 1 DW (Shift-L 1-DW to obtain result with no shifts) +- +- vpxor \T2, \GH, \GH # second phase of the reduction complete +- ####################################################################### +- vpxor \T1, \GH, \GH # the result is in GH +- +- +-.endm +- +-.macro PRECOMPUTE_AVX2 HK T1 T2 T3 T4 T5 T6 +- +- # Haskey_i_k holds XORed values of the low and high parts of the Haskey_i +- vmovdqa \HK, \T5 +- GHASH_MUL_AVX2 \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = HashKey^2<<1 mod poly +- vmovdqu \T5, HashKey_2(arg2) # [HashKey_2] = HashKey^2<<1 mod poly +- +- GHASH_MUL_AVX2 \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = HashKey^3<<1 mod poly +- vmovdqu \T5, HashKey_3(arg2) +- +- GHASH_MUL_AVX2 \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = HashKey^4<<1 mod poly +- vmovdqu \T5, 
HashKey_4(arg2) +- +- GHASH_MUL_AVX2 \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = HashKey^5<<1 mod poly +- vmovdqu \T5, HashKey_5(arg2) +- +- GHASH_MUL_AVX2 \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = HashKey^6<<1 mod poly +- vmovdqu \T5, HashKey_6(arg2) +- +- GHASH_MUL_AVX2 \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = HashKey^7<<1 mod poly +- vmovdqu \T5, HashKey_7(arg2) +- +- GHASH_MUL_AVX2 \T5, \HK, \T1, \T3, \T4, \T6, \T2 # T5 = HashKey^8<<1 mod poly +- vmovdqu \T5, HashKey_8(arg2) +- +-.endm +- +-## if a = number of total plaintext bytes +-## b = floor(a/16) +-## num_initial_blocks = b mod 4# +-## encrypt the initial num_initial_blocks blocks and apply ghash on the ciphertext +-## r10, r11, r12, rax are clobbered +-## arg1, arg2, arg3, arg4 are used as pointers only, not modified +- +-.macro INITIAL_BLOCKS_AVX2 REP num_initial_blocks T1 T2 T3 T4 T5 CTR XMM1 XMM2 XMM3 XMM4 XMM5 XMM6 XMM7 XMM8 T6 T_key ENC_DEC VER +- i = (8-\num_initial_blocks) +- setreg +- vmovdqu AadHash(arg2), reg_i +- +- # start AES for num_initial_blocks blocks +- vmovdqu CurCount(arg2), \CTR +- +- i = (9-\num_initial_blocks) +- setreg +-.rep \num_initial_blocks +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, reg_i +- vpshufb SHUF_MASK(%rip), reg_i, reg_i # perform a 16Byte swap +- i = (i+1) +- setreg +-.endr +- +- vmovdqa (arg1), \T_key +- i = (9-\num_initial_blocks) +- setreg +-.rep \num_initial_blocks +- vpxor \T_key, reg_i, reg_i +- i = (i+1) +- setreg +-.endr +- +- j = 1 +- setreg +-.rep \REP +- vmovdqa 16*j(arg1), \T_key +- i = (9-\num_initial_blocks) +- setreg +-.rep \num_initial_blocks +- vaesenc \T_key, reg_i, reg_i +- i = (i+1) +- setreg +-.endr +- +- j = (j+1) +- setreg +-.endr +- +- +- vmovdqa 16*j(arg1), \T_key +- i = (9-\num_initial_blocks) +- setreg +-.rep \num_initial_blocks +- vaesenclast \T_key, reg_i, reg_i +- i = (i+1) +- setreg +-.endr +- +- i = (9-\num_initial_blocks) +- setreg +-.rep \num_initial_blocks +- vmovdqu (arg4, %r11), \T1 +- vpxor \T1, reg_i, reg_i +- 
vmovdqu reg_i, (arg3 , %r11) # write back ciphertext for +- # num_initial_blocks blocks +- add $16, %r11 +-.if \ENC_DEC == DEC +- vmovdqa \T1, reg_i +-.endif +- vpshufb SHUF_MASK(%rip), reg_i, reg_i # prepare ciphertext for GHASH computations +- i = (i+1) +- setreg +-.endr +- +- +- i = (8-\num_initial_blocks) +- j = (9-\num_initial_blocks) +- setreg +- +-.rep \num_initial_blocks +- vpxor reg_i, reg_j, reg_j +- GHASH_MUL_AVX2 reg_j, \T2, \T1, \T3, \T4, \T5, \T6 # apply GHASH on num_initial_blocks blocks +- i = (i+1) +- j = (j+1) +- setreg +-.endr +- # XMM8 has the combined result here +- +- vmovdqa \XMM8, TMP1(%rsp) +- vmovdqa \XMM8, \T3 +- +- cmp $128, %r13 +- jl .L_initial_blocks_done\@ # no need for precomputed constants +- +-############################################################################### +-# Haskey_i_k holds XORed values of the low and high parts of the Haskey_i +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM1 +- vpshufb SHUF_MASK(%rip), \XMM1, \XMM1 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM2 +- vpshufb SHUF_MASK(%rip), \XMM2, \XMM2 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM3 +- vpshufb SHUF_MASK(%rip), \XMM3, \XMM3 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM4 +- vpshufb SHUF_MASK(%rip), \XMM4, \XMM4 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM5 +- vpshufb SHUF_MASK(%rip), \XMM5, \XMM5 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM6 +- vpshufb SHUF_MASK(%rip), \XMM6, \XMM6 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM7 +- vpshufb SHUF_MASK(%rip), \XMM7, \XMM7 # perform a 16Byte swap +- +- vpaddd ONE(%rip), \CTR, \CTR # INCR Y0 +- vmovdqa \CTR, \XMM8 +- vpshufb SHUF_MASK(%rip), \XMM8, \XMM8 # perform a 16Byte swap +- +- vmovdqa (arg1), \T_key +- vpxor 
\T_key, \XMM1, \XMM1 +- vpxor \T_key, \XMM2, \XMM2 +- vpxor \T_key, \XMM3, \XMM3 +- vpxor \T_key, \XMM4, \XMM4 +- vpxor \T_key, \XMM5, \XMM5 +- vpxor \T_key, \XMM6, \XMM6 +- vpxor \T_key, \XMM7, \XMM7 +- vpxor \T_key, \XMM8, \XMM8 +- +- i = 1 +- setreg +-.rep \REP # do REP rounds +- vmovdqa 16*i(arg1), \T_key +- vaesenc \T_key, \XMM1, \XMM1 +- vaesenc \T_key, \XMM2, \XMM2 +- vaesenc \T_key, \XMM3, \XMM3 +- vaesenc \T_key, \XMM4, \XMM4 +- vaesenc \T_key, \XMM5, \XMM5 +- vaesenc \T_key, \XMM6, \XMM6 +- vaesenc \T_key, \XMM7, \XMM7 +- vaesenc \T_key, \XMM8, \XMM8 +- i = (i+1) +- setreg +-.endr +- +- +- vmovdqa 16*i(arg1), \T_key +- vaesenclast \T_key, \XMM1, \XMM1 +- vaesenclast \T_key, \XMM2, \XMM2 +- vaesenclast \T_key, \XMM3, \XMM3 +- vaesenclast \T_key, \XMM4, \XMM4 +- vaesenclast \T_key, \XMM5, \XMM5 +- vaesenclast \T_key, \XMM6, \XMM6 +- vaesenclast \T_key, \XMM7, \XMM7 +- vaesenclast \T_key, \XMM8, \XMM8 +- +- vmovdqu (arg4, %r11), \T1 +- vpxor \T1, \XMM1, \XMM1 +- vmovdqu \XMM1, (arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM1 +- .endif +- +- vmovdqu 16*1(arg4, %r11), \T1 +- vpxor \T1, \XMM2, \XMM2 +- vmovdqu \XMM2, 16*1(arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM2 +- .endif +- +- vmovdqu 16*2(arg4, %r11), \T1 +- vpxor \T1, \XMM3, \XMM3 +- vmovdqu \XMM3, 16*2(arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM3 +- .endif +- +- vmovdqu 16*3(arg4, %r11), \T1 +- vpxor \T1, \XMM4, \XMM4 +- vmovdqu \XMM4, 16*3(arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM4 +- .endif +- +- vmovdqu 16*4(arg4, %r11), \T1 +- vpxor \T1, \XMM5, \XMM5 +- vmovdqu \XMM5, 16*4(arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM5 +- .endif +- +- vmovdqu 16*5(arg4, %r11), \T1 +- vpxor \T1, \XMM6, \XMM6 +- vmovdqu \XMM6, 16*5(arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM6 +- .endif +- +- vmovdqu 16*6(arg4, %r11), \T1 +- vpxor \T1, \XMM7, \XMM7 +- vmovdqu \XMM7, 16*6(arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM7 +- .endif +- +- 
vmovdqu 16*7(arg4, %r11), \T1 +- vpxor \T1, \XMM8, \XMM8 +- vmovdqu \XMM8, 16*7(arg3 , %r11) +- .if \ENC_DEC == DEC +- vmovdqa \T1, \XMM8 +- .endif +- +- add $128, %r11 +- +- vpshufb SHUF_MASK(%rip), \XMM1, \XMM1 # perform a 16Byte swap +- vpxor TMP1(%rsp), \XMM1, \XMM1 # combine GHASHed value with +- # the corresponding ciphertext +- vpshufb SHUF_MASK(%rip), \XMM2, \XMM2 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM3, \XMM3 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM4, \XMM4 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM5, \XMM5 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM6, \XMM6 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM7, \XMM7 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM8, \XMM8 # perform a 16Byte swap +- +-############################################################################### +- +-.L_initial_blocks_done\@: +- +- +-.endm +- +- +- +-# encrypt 8 blocks at a time +-# ghash the 8 previously encrypted ciphertext blocks +-# arg1, arg2, arg3, arg4 are used as pointers only, not modified +-# r11 is the data offset value +-.macro GHASH_8_ENCRYPT_8_PARALLEL_AVX2 REP T1 T2 T3 T4 T5 T6 CTR XMM1 XMM2 XMM3 XMM4 XMM5 XMM6 XMM7 XMM8 T7 loop_idx ENC_DEC +- +- vmovdqa \XMM1, \T2 +- vmovdqa \XMM2, TMP2(%rsp) +- vmovdqa \XMM3, TMP3(%rsp) +- vmovdqa \XMM4, TMP4(%rsp) +- vmovdqa \XMM5, TMP5(%rsp) +- vmovdqa \XMM6, TMP6(%rsp) +- vmovdqa \XMM7, TMP7(%rsp) +- vmovdqa \XMM8, TMP8(%rsp) +- +-.if \loop_idx == in_order +- vpaddd ONE(%rip), \CTR, \XMM1 # INCR CNT +- vpaddd ONE(%rip), \XMM1, \XMM2 +- vpaddd ONE(%rip), \XMM2, \XMM3 +- vpaddd ONE(%rip), \XMM3, \XMM4 +- vpaddd ONE(%rip), \XMM4, \XMM5 +- vpaddd ONE(%rip), \XMM5, \XMM6 +- vpaddd ONE(%rip), \XMM6, \XMM7 +- vpaddd ONE(%rip), \XMM7, \XMM8 +- vmovdqa \XMM8, \CTR +- +- vpshufb SHUF_MASK(%rip), \XMM1, \XMM1 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM2, \XMM2 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM3, \XMM3 # perform a 16Byte 
swap +- vpshufb SHUF_MASK(%rip), \XMM4, \XMM4 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM5, \XMM5 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM6, \XMM6 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM7, \XMM7 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM8, \XMM8 # perform a 16Byte swap +-.else +- vpaddd ONEf(%rip), \CTR, \XMM1 # INCR CNT +- vpaddd ONEf(%rip), \XMM1, \XMM2 +- vpaddd ONEf(%rip), \XMM2, \XMM3 +- vpaddd ONEf(%rip), \XMM3, \XMM4 +- vpaddd ONEf(%rip), \XMM4, \XMM5 +- vpaddd ONEf(%rip), \XMM5, \XMM6 +- vpaddd ONEf(%rip), \XMM6, \XMM7 +- vpaddd ONEf(%rip), \XMM7, \XMM8 +- vmovdqa \XMM8, \CTR +-.endif +- +- +- ####################################################################### +- +- vmovdqu (arg1), \T1 +- vpxor \T1, \XMM1, \XMM1 +- vpxor \T1, \XMM2, \XMM2 +- vpxor \T1, \XMM3, \XMM3 +- vpxor \T1, \XMM4, \XMM4 +- vpxor \T1, \XMM5, \XMM5 +- vpxor \T1, \XMM6, \XMM6 +- vpxor \T1, \XMM7, \XMM7 +- vpxor \T1, \XMM8, \XMM8 +- +- ####################################################################### +- +- +- +- +- +- vmovdqu 16*1(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- vmovdqu 16*2(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- +- ####################################################################### +- +- vmovdqu HashKey_8(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T2, \T4 # T4 = a1*b1 +- vpclmulqdq $0x00, \T5, \T2, \T7 # T7 = a0*b0 +- vpclmulqdq $0x01, \T5, \T2, \T6 # T6 = a1*b0 +- vpclmulqdq $0x10, \T5, \T2, \T5 # T5 = a0*b1 +- vpxor \T5, \T6, \T6 +- +- vmovdqu 16*3(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc 
\T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- vmovdqa TMP2(%rsp), \T1 +- vmovdqu HashKey_7(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor \T3, \T4, \T4 +- +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpclmulqdq $0x01, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- vpclmulqdq $0x10, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- vmovdqu 16*4(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- ####################################################################### +- +- vmovdqa TMP3(%rsp), \T1 +- vmovdqu HashKey_6(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor \T3, \T4, \T4 +- +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpclmulqdq $0x01, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- vpclmulqdq $0x10, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- vmovdqu 16*5(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- vmovdqa TMP4(%rsp), \T1 +- vmovdqu HashKey_5(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor \T3, \T4, \T4 +- +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpclmulqdq $0x01, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- vpclmulqdq $0x10, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- vmovdqu 16*6(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- +- vmovdqa TMP5(%rsp), \T1 
+- vmovdqu HashKey_4(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor \T3, \T4, \T4 +- +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpclmulqdq $0x01, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- vpclmulqdq $0x10, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- vmovdqu 16*7(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- vmovdqa TMP6(%rsp), \T1 +- vmovdqu HashKey_3(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor \T3, \T4, \T4 +- +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpclmulqdq $0x01, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- vpclmulqdq $0x10, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- vmovdqu 16*8(arg1), \T1 +- vaesenc \T1, \XMM1, \XMM1 +- vaesenc \T1, \XMM2, \XMM2 +- vaesenc \T1, \XMM3, \XMM3 +- vaesenc \T1, \XMM4, \XMM4 +- vaesenc \T1, \XMM5, \XMM5 +- vaesenc \T1, \XMM6, \XMM6 +- vaesenc \T1, \XMM7, \XMM7 +- vaesenc \T1, \XMM8, \XMM8 +- +- vmovdqa TMP7(%rsp), \T1 +- vmovdqu HashKey_2(arg2), \T5 +- vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor \T3, \T4, \T4 +- +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpclmulqdq $0x01, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- vpclmulqdq $0x10, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- +- ####################################################################### +- +- vmovdqu 16*9(arg1), \T5 +- vaesenc \T5, \XMM1, \XMM1 +- vaesenc \T5, \XMM2, \XMM2 +- vaesenc \T5, \XMM3, \XMM3 +- vaesenc \T5, \XMM4, \XMM4 +- vaesenc \T5, \XMM5, \XMM5 +- vaesenc \T5, \XMM6, \XMM6 +- vaesenc \T5, \XMM7, \XMM7 +- vaesenc \T5, \XMM8, \XMM8 +- +- vmovdqa TMP8(%rsp), \T1 +- vmovdqu HashKey(arg2), \T5 +- +- vpclmulqdq $0x00, \T5, \T1, \T3 +- vpxor \T3, \T7, \T7 +- +- vpclmulqdq $0x01, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- vpclmulqdq $0x10, \T5, \T1, \T3 +- vpxor \T3, \T6, \T6 +- +- 
vpclmulqdq $0x11, \T5, \T1, \T3 +- vpxor \T3, \T4, \T1 +- +- +- vmovdqu 16*10(arg1), \T5 +- +- i = 11 +- setreg +-.rep (\REP-9) +- vaesenc \T5, \XMM1, \XMM1 +- vaesenc \T5, \XMM2, \XMM2 +- vaesenc \T5, \XMM3, \XMM3 +- vaesenc \T5, \XMM4, \XMM4 +- vaesenc \T5, \XMM5, \XMM5 +- vaesenc \T5, \XMM6, \XMM6 +- vaesenc \T5, \XMM7, \XMM7 +- vaesenc \T5, \XMM8, \XMM8 +- +- vmovdqu 16*i(arg1), \T5 +- i = i + 1 +- setreg +-.endr +- +- i = 0 +- j = 1 +- setreg +-.rep 8 +- vpxor 16*i(arg4, %r11), \T5, \T2 +- .if \ENC_DEC == ENC +- vaesenclast \T2, reg_j, reg_j +- .else +- vaesenclast \T2, reg_j, \T3 +- vmovdqu 16*i(arg4, %r11), reg_j +- vmovdqu \T3, 16*i(arg3, %r11) +- .endif +- i = (i+1) +- j = (j+1) +- setreg +-.endr +- ####################################################################### +- +- +- vpslldq $8, \T6, \T3 # shift-L T3 2 DWs +- vpsrldq $8, \T6, \T6 # shift-R T2 2 DWs +- vpxor \T3, \T7, \T7 +- vpxor \T6, \T1, \T1 # accumulate the results in T1:T7 +- +- +- +- ####################################################################### +- #first phase of the reduction +- vmovdqa POLY2(%rip), \T3 +- +- vpclmulqdq $0x01, \T7, \T3, \T2 +- vpslldq $8, \T2, \T2 # shift-L xmm2 2 DWs +- +- vpxor \T2, \T7, \T7 # first phase of the reduction complete +- ####################################################################### +- .if \ENC_DEC == ENC +- vmovdqu \XMM1, 16*0(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM2, 16*1(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM3, 16*2(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM4, 16*3(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM5, 16*4(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM6, 16*5(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM7, 16*6(arg3,%r11) # Write to the Ciphertext buffer +- vmovdqu \XMM8, 16*7(arg3,%r11) # Write to the Ciphertext buffer +- .endif +- +- ####################################################################### +- #second phase of 
the reduction +- vpclmulqdq $0x00, \T7, \T3, \T2 +- vpsrldq $4, \T2, \T2 # shift-R xmm2 1 DW (Shift-R only 1-DW to obtain 2-DWs shift-R) +- +- vpclmulqdq $0x10, \T7, \T3, \T4 +- vpslldq $4, \T4, \T4 # shift-L xmm0 1 DW (Shift-L 1-DW to obtain result with no shifts) +- +- vpxor \T2, \T4, \T4 # second phase of the reduction complete +- ####################################################################### +- vpxor \T4, \T1, \T1 # the result is in T1 +- +- vpshufb SHUF_MASK(%rip), \XMM1, \XMM1 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM2, \XMM2 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM3, \XMM3 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM4, \XMM4 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM5, \XMM5 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM6, \XMM6 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM7, \XMM7 # perform a 16Byte swap +- vpshufb SHUF_MASK(%rip), \XMM8, \XMM8 # perform a 16Byte swap +- +- +- vpxor \T1, \XMM1, \XMM1 +- +- +- +-.endm +- +- +-# GHASH the last 4 ciphertext blocks. 
+-.macro GHASH_LAST_8_AVX2 T1 T2 T3 T4 T5 T6 T7 XMM1 XMM2 XMM3 XMM4 XMM5 XMM6 XMM7 XMM8 +- +- ## Karatsuba Method +- +- vmovdqu HashKey_8(arg2), \T5 +- +- vpshufd $0b01001110, \XMM1, \T2 +- vpshufd $0b01001110, \T5, \T3 +- vpxor \XMM1, \T2, \T2 +- vpxor \T5, \T3, \T3 +- +- vpclmulqdq $0x11, \T5, \XMM1, \T6 +- vpclmulqdq $0x00, \T5, \XMM1, \T7 +- +- vpclmulqdq $0x00, \T3, \T2, \XMM1 +- +- ###################### +- +- vmovdqu HashKey_7(arg2), \T5 +- vpshufd $0b01001110, \XMM2, \T2 +- vpshufd $0b01001110, \T5, \T3 +- vpxor \XMM2, \T2, \T2 +- vpxor \T5, \T3, \T3 +- +- vpclmulqdq $0x11, \T5, \XMM2, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM2, \T4 +- vpxor \T4, \T7, \T7 +- +- vpclmulqdq $0x00, \T3, \T2, \T2 +- +- vpxor \T2, \XMM1, \XMM1 +- +- ###################### +- +- vmovdqu HashKey_6(arg2), \T5 +- vpshufd $0b01001110, \XMM3, \T2 +- vpshufd $0b01001110, \T5, \T3 +- vpxor \XMM3, \T2, \T2 +- vpxor \T5, \T3, \T3 +- +- vpclmulqdq $0x11, \T5, \XMM3, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM3, \T4 +- vpxor \T4, \T7, \T7 +- +- vpclmulqdq $0x00, \T3, \T2, \T2 +- +- vpxor \T2, \XMM1, \XMM1 +- +- ###################### +- +- vmovdqu HashKey_5(arg2), \T5 +- vpshufd $0b01001110, \XMM4, \T2 +- vpshufd $0b01001110, \T5, \T3 +- vpxor \XMM4, \T2, \T2 +- vpxor \T5, \T3, \T3 +- +- vpclmulqdq $0x11, \T5, \XMM4, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM4, \T4 +- vpxor \T4, \T7, \T7 +- +- vpclmulqdq $0x00, \T3, \T2, \T2 +- +- vpxor \T2, \XMM1, \XMM1 +- +- ###################### +- +- vmovdqu HashKey_4(arg2), \T5 +- vpshufd $0b01001110, \XMM5, \T2 +- vpshufd $0b01001110, \T5, \T3 +- vpxor \XMM5, \T2, \T2 +- vpxor \T5, \T3, \T3 +- +- vpclmulqdq $0x11, \T5, \XMM5, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM5, \T4 +- vpxor \T4, \T7, \T7 +- +- vpclmulqdq $0x00, \T3, \T2, \T2 +- +- vpxor \T2, \XMM1, \XMM1 +- +- ###################### +- +- vmovdqu HashKey_3(arg2), \T5 +- vpshufd $0b01001110, \XMM6, \T2 +- vpshufd 
$0b01001110, \T5, \T3 +- vpxor \XMM6, \T2, \T2 +- vpxor \T5, \T3, \T3 +- +- vpclmulqdq $0x11, \T5, \XMM6, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM6, \T4 +- vpxor \T4, \T7, \T7 +- +- vpclmulqdq $0x00, \T3, \T2, \T2 +- +- vpxor \T2, \XMM1, \XMM1 +- +- ###################### +- +- vmovdqu HashKey_2(arg2), \T5 +- vpshufd $0b01001110, \XMM7, \T2 +- vpshufd $0b01001110, \T5, \T3 +- vpxor \XMM7, \T2, \T2 +- vpxor \T5, \T3, \T3 +- +- vpclmulqdq $0x11, \T5, \XMM7, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM7, \T4 +- vpxor \T4, \T7, \T7 +- +- vpclmulqdq $0x00, \T3, \T2, \T2 +- +- vpxor \T2, \XMM1, \XMM1 +- +- ###################### +- +- vmovdqu HashKey(arg2), \T5 +- vpshufd $0b01001110, \XMM8, \T2 +- vpshufd $0b01001110, \T5, \T3 +- vpxor \XMM8, \T2, \T2 +- vpxor \T5, \T3, \T3 +- +- vpclmulqdq $0x11, \T5, \XMM8, \T4 +- vpxor \T4, \T6, \T6 +- +- vpclmulqdq $0x00, \T5, \XMM8, \T4 +- vpxor \T4, \T7, \T7 +- +- vpclmulqdq $0x00, \T3, \T2, \T2 +- +- vpxor \T2, \XMM1, \XMM1 +- vpxor \T6, \XMM1, \XMM1 +- vpxor \T7, \XMM1, \T2 +- +- +- +- +- vpslldq $8, \T2, \T4 +- vpsrldq $8, \T2, \T2 +- +- vpxor \T4, \T7, \T7 +- vpxor \T2, \T6, \T6 # holds the result of the +- # accumulated carry-less multiplications +- +- ####################################################################### +- #first phase of the reduction +- vmovdqa POLY2(%rip), \T3 +- +- vpclmulqdq $0x01, \T7, \T3, \T2 +- vpslldq $8, \T2, \T2 # shift-L xmm2 2 DWs +- +- vpxor \T2, \T7, \T7 # first phase of the reduction complete +- ####################################################################### +- +- +- #second phase of the reduction +- vpclmulqdq $0x00, \T7, \T3, \T2 +- vpsrldq $4, \T2, \T2 # shift-R T2 1 DW (Shift-R only 1-DW to obtain 2-DWs shift-R) +- +- vpclmulqdq $0x10, \T7, \T3, \T4 +- vpslldq $4, \T4, \T4 # shift-L T4 1 DW (Shift-L 1-DW to obtain result with no shifts) +- +- vpxor \T2, \T4, \T4 # second phase of the reduction complete +- 
####################################################################### +- vpxor \T4, \T6, \T6 # the result is in T6 +-.endm +- +- +- +-############################################################# +-#void aesni_gcm_init_avx_gen4 +-# (gcm_data *my_ctx_data, +-# gcm_context_data *data, +-# u8 *iv, /* Pre-counter block j0: 4 byte salt +-# (from Security Association) concatenated with 8 byte +-# Initialisation Vector (from IPSec ESP Payload) +-# concatenated with 0x00000001. 16-byte aligned pointer. */ +-# u8 *hash_subkey# /* H, the Hash sub key input. Data starts on a 16-byte boundary. */ +-# const u8 *aad, /* Additional Authentication Data (AAD)*/ +-# u64 aad_len) /* Length of AAD in bytes. With RFC4106 this is going to be 8 or 12 Bytes */ +-############################################################# +-SYM_FUNC_START(aesni_gcm_init_avx_gen4) +- FUNC_SAVE +- INIT GHASH_MUL_AVX2, PRECOMPUTE_AVX2 +- FUNC_RESTORE +- RET +-SYM_FUNC_END(aesni_gcm_init_avx_gen4) +- +-############################################################################### +-#void aesni_gcm_enc_avx_gen4( +-# gcm_data *my_ctx_data, /* aligned to 16 Bytes */ +-# gcm_context_data *data, +-# u8 *out, /* Ciphertext output. Encrypt in-place is allowed. */ +-# const u8 *in, /* Plaintext input */ +-# u64 plaintext_len) /* Length of data in Bytes for encryption. 
*/ +-############################################################################### +-SYM_FUNC_START(aesni_gcm_enc_update_avx_gen4) +- FUNC_SAVE +- mov keysize,%eax +- cmp $32, %eax +- je key_256_enc_update4 +- cmp $16, %eax +- je key_128_enc_update4 +- # must be 192 +- GCM_ENC_DEC INITIAL_BLOCKS_AVX2, GHASH_8_ENCRYPT_8_PARALLEL_AVX2, GHASH_LAST_8_AVX2, GHASH_MUL_AVX2, ENC, 11 +- FUNC_RESTORE +- RET +-key_128_enc_update4: +- GCM_ENC_DEC INITIAL_BLOCKS_AVX2, GHASH_8_ENCRYPT_8_PARALLEL_AVX2, GHASH_LAST_8_AVX2, GHASH_MUL_AVX2, ENC, 9 +- FUNC_RESTORE +- RET +-key_256_enc_update4: +- GCM_ENC_DEC INITIAL_BLOCKS_AVX2, GHASH_8_ENCRYPT_8_PARALLEL_AVX2, GHASH_LAST_8_AVX2, GHASH_MUL_AVX2, ENC, 13 +- FUNC_RESTORE +- RET +-SYM_FUNC_END(aesni_gcm_enc_update_avx_gen4) +- +-############################################################################### +-#void aesni_gcm_dec_update_avx_gen4( +-# gcm_data *my_ctx_data, /* aligned to 16 Bytes */ +-# gcm_context_data *data, +-# u8 *out, /* Plaintext output. Decrypt in-place is allowed. */ +-# const u8 *in, /* Ciphertext input */ +-# u64 plaintext_len) /* Length of data in Bytes for encryption. 
*/ +-############################################################################### +-SYM_FUNC_START(aesni_gcm_dec_update_avx_gen4) +- FUNC_SAVE +- mov keysize,%eax +- cmp $32, %eax +- je key_256_dec_update4 +- cmp $16, %eax +- je key_128_dec_update4 +- # must be 192 +- GCM_ENC_DEC INITIAL_BLOCKS_AVX2, GHASH_8_ENCRYPT_8_PARALLEL_AVX2, GHASH_LAST_8_AVX2, GHASH_MUL_AVX2, DEC, 11 +- FUNC_RESTORE +- RET +-key_128_dec_update4: +- GCM_ENC_DEC INITIAL_BLOCKS_AVX2, GHASH_8_ENCRYPT_8_PARALLEL_AVX2, GHASH_LAST_8_AVX2, GHASH_MUL_AVX2, DEC, 9 +- FUNC_RESTORE +- RET +-key_256_dec_update4: +- GCM_ENC_DEC INITIAL_BLOCKS_AVX2, GHASH_8_ENCRYPT_8_PARALLEL_AVX2, GHASH_LAST_8_AVX2, GHASH_MUL_AVX2, DEC, 13 +- FUNC_RESTORE +- RET +-SYM_FUNC_END(aesni_gcm_dec_update_avx_gen4) +- +-############################################################################### +-#void aesni_gcm_finalize_avx_gen4( +-# gcm_data *my_ctx_data, /* aligned to 16 Bytes */ +-# gcm_context_data *data, +-# u8 *auth_tag, /* Authenticated Tag output. */ +-# u64 auth_tag_len)# /* Authenticated Tag Length in bytes. +-# Valid values are 16 (most likely), 12 or 8. 
*/ +-############################################################################### +-SYM_FUNC_START(aesni_gcm_finalize_avx_gen4) +- FUNC_SAVE +- mov keysize,%eax +- cmp $32, %eax +- je key_256_finalize4 +- cmp $16, %eax +- je key_128_finalize4 +- # must be 192 +- GCM_COMPLETE GHASH_MUL_AVX2, 11, arg3, arg4 +- FUNC_RESTORE +- RET +-key_128_finalize4: +- GCM_COMPLETE GHASH_MUL_AVX2, 9, arg3, arg4 +- FUNC_RESTORE +- RET +-key_256_finalize4: +- GCM_COMPLETE GHASH_MUL_AVX2, 13, arg3, arg4 +- FUNC_RESTORE +- RET +-SYM_FUNC_END(aesni_gcm_finalize_avx_gen4) +diff --git a/arch/x86/crypto/aesni-intel_glue.c b/arch/x86/crypto/aesni-intel_glue.c +index ef031655b2d3..cd37de5ec404 100644 +--- a/arch/x86/crypto/aesni-intel_glue.c ++++ b/arch/x86/crypto/aesni-intel_glue.c +@@ -1,7 +1,7 @@ + // SPDX-License-Identifier: GPL-2.0-or-later + /* +- * Support for Intel AES-NI instructions. This file contains glue +- * code, the real AES implementation is in intel-aes_asm.S. ++ * Support for AES-NI and VAES instructions. This file contains glue code. ++ * The real AES implementations are in aesni-intel_asm.S and other .S files. + * + * Copyright (C) 2008, Intel Corp. + * Author: Huang Ying +@@ -13,6 +13,8 @@ + * Tadeusz Struk (tadeusz.struk@intel.com) + * Aidan O'Mahony (aidan.o.mahony@intel.com) + * Copyright (c) 2010, Intel Corporation. ++ * ++ * Copyright 2024 Google LLC + */ + + #include +@@ -44,41 +46,11 @@ + #define CRYPTO_AES_CTX_SIZE (sizeof(struct crypto_aes_ctx) + AESNI_ALIGN_EXTRA) + #define XTS_AES_CTX_SIZE (sizeof(struct aesni_xts_ctx) + AESNI_ALIGN_EXTRA) + +-/* This data is stored at the end of the crypto_tfm struct. +- * It's a type of per "session" data storage location. +- * This needs to be 16 byte aligned. 
+- */ +-struct aesni_rfc4106_gcm_ctx { +- u8 hash_subkey[16] AESNI_ALIGN_ATTR; +- struct crypto_aes_ctx aes_key_expanded AESNI_ALIGN_ATTR; +- u8 nonce[4]; +-}; +- +-struct generic_gcmaes_ctx { +- u8 hash_subkey[16] AESNI_ALIGN_ATTR; +- struct crypto_aes_ctx aes_key_expanded AESNI_ALIGN_ATTR; +-}; +- + struct aesni_xts_ctx { + struct crypto_aes_ctx tweak_ctx AESNI_ALIGN_ATTR; + struct crypto_aes_ctx crypt_ctx AESNI_ALIGN_ATTR; + }; + +-#define GCM_BLOCK_LEN 16 +- +-struct gcm_context_data { +- /* init, update and finalize context data */ +- u8 aad_hash[GCM_BLOCK_LEN]; +- u64 aad_length; +- u64 in_length; +- u8 partial_block_enc_key[GCM_BLOCK_LEN]; +- u8 orig_IV[GCM_BLOCK_LEN]; +- u8 current_counter[GCM_BLOCK_LEN]; +- u64 partial_block_len; +- u64 unused; +- u8 hash_keys[GCM_BLOCK_LEN * 16]; +-}; +- + static inline void *aes_align_addr(void *addr) + { + if (crypto_tfm_ctx_alignment() >= AESNI_ALIGN) +@@ -103,9 +75,6 @@ asmlinkage void aesni_cts_cbc_enc(struct crypto_aes_ctx *ctx, u8 *out, + asmlinkage void aesni_cts_cbc_dec(struct crypto_aes_ctx *ctx, u8 *out, + const u8 *in, unsigned int len, u8 *iv); + +-#define AVX_GEN2_OPTSIZE 640 +-#define AVX_GEN4_OPTSIZE 4096 +- + asmlinkage void aesni_xts_enc(const struct crypto_aes_ctx *ctx, u8 *out, + const u8 *in, unsigned int len, u8 *iv); + +@@ -118,23 +87,6 @@ asmlinkage void aesni_ctr_enc(struct crypto_aes_ctx *ctx, u8 *out, + const u8 *in, unsigned int len, u8 *iv); + DEFINE_STATIC_CALL(aesni_ctr_enc_tfm, aesni_ctr_enc); + +-/* Scatter / Gather routines, with args similar to above */ +-asmlinkage void aesni_gcm_init(void *ctx, +- struct gcm_context_data *gdata, +- u8 *iv, +- u8 *hash_subkey, const u8 *aad, +- unsigned long aad_len); +-asmlinkage void aesni_gcm_enc_update(void *ctx, +- struct gcm_context_data *gdata, u8 *out, +- const u8 *in, unsigned long plaintext_len); +-asmlinkage void aesni_gcm_dec_update(void *ctx, +- struct gcm_context_data *gdata, u8 *out, +- const u8 *in, +- unsigned long ciphertext_len); 
+-asmlinkage void aesni_gcm_finalize(void *ctx, +- struct gcm_context_data *gdata, +- u8 *auth_tag, unsigned long auth_tag_len); +- + asmlinkage void aes_ctr_enc_128_avx_by8(const u8 *in, u8 *iv, + void *keys, u8 *out, unsigned int num_bytes); + asmlinkage void aes_ctr_enc_192_avx_by8(const u8 *in, u8 *iv, +@@ -154,67 +106,6 @@ asmlinkage void aes_xctr_enc_192_avx_by8(const u8 *in, const u8 *iv, + asmlinkage void aes_xctr_enc_256_avx_by8(const u8 *in, const u8 *iv, + const void *keys, u8 *out, unsigned int num_bytes, + unsigned int byte_ctr); +- +-/* +- * asmlinkage void aesni_gcm_init_avx_gen2() +- * gcm_data *my_ctx_data, context data +- * u8 *hash_subkey, the Hash sub key input. Data starts on a 16-byte boundary. +- */ +-asmlinkage void aesni_gcm_init_avx_gen2(void *my_ctx_data, +- struct gcm_context_data *gdata, +- u8 *iv, +- u8 *hash_subkey, +- const u8 *aad, +- unsigned long aad_len); +- +-asmlinkage void aesni_gcm_enc_update_avx_gen2(void *ctx, +- struct gcm_context_data *gdata, u8 *out, +- const u8 *in, unsigned long plaintext_len); +-asmlinkage void aesni_gcm_dec_update_avx_gen2(void *ctx, +- struct gcm_context_data *gdata, u8 *out, +- const u8 *in, +- unsigned long ciphertext_len); +-asmlinkage void aesni_gcm_finalize_avx_gen2(void *ctx, +- struct gcm_context_data *gdata, +- u8 *auth_tag, unsigned long auth_tag_len); +- +-/* +- * asmlinkage void aesni_gcm_init_avx_gen4() +- * gcm_data *my_ctx_data, context data +- * u8 *hash_subkey, the Hash sub key input. Data starts on a 16-byte boundary. 
+- */ +-asmlinkage void aesni_gcm_init_avx_gen4(void *my_ctx_data, +- struct gcm_context_data *gdata, +- u8 *iv, +- u8 *hash_subkey, +- const u8 *aad, +- unsigned long aad_len); +- +-asmlinkage void aesni_gcm_enc_update_avx_gen4(void *ctx, +- struct gcm_context_data *gdata, u8 *out, +- const u8 *in, unsigned long plaintext_len); +-asmlinkage void aesni_gcm_dec_update_avx_gen4(void *ctx, +- struct gcm_context_data *gdata, u8 *out, +- const u8 *in, +- unsigned long ciphertext_len); +-asmlinkage void aesni_gcm_finalize_avx_gen4(void *ctx, +- struct gcm_context_data *gdata, +- u8 *auth_tag, unsigned long auth_tag_len); +- +-static __ro_after_init DEFINE_STATIC_KEY_FALSE(gcm_use_avx); +-static __ro_after_init DEFINE_STATIC_KEY_FALSE(gcm_use_avx2); +- +-static inline struct +-aesni_rfc4106_gcm_ctx *aesni_rfc4106_gcm_ctx_get(struct crypto_aead *tfm) +-{ +- return aes_align_addr(crypto_aead_ctx(tfm)); +-} +- +-static inline struct +-generic_gcmaes_ctx *generic_gcmaes_ctx_get(struct crypto_aead *tfm) +-{ +- return aes_align_addr(crypto_aead_ctx(tfm)); +-} + #endif + + static inline struct crypto_aes_ctx *aes_ctx(void *raw_ctx) +@@ -588,280 +479,6 @@ static int xctr_crypt(struct skcipher_request *req) + } + return err; + } +- +-static int aes_gcm_derive_hash_subkey(const struct crypto_aes_ctx *aes_key, +- u8 hash_subkey[AES_BLOCK_SIZE]) +-{ +- static const u8 zeroes[AES_BLOCK_SIZE]; +- +- aes_encrypt(aes_key, hash_subkey, zeroes); +- return 0; +-} +- +-static int common_rfc4106_set_key(struct crypto_aead *aead, const u8 *key, +- unsigned int key_len) +-{ +- struct aesni_rfc4106_gcm_ctx *ctx = aesni_rfc4106_gcm_ctx_get(aead); +- +- if (key_len < 4) +- return -EINVAL; +- +- /*Account for 4 byte nonce at the end.*/ +- key_len -= 4; +- +- memcpy(ctx->nonce, key + key_len, sizeof(ctx->nonce)); +- +- return aes_set_key_common(&ctx->aes_key_expanded, key, key_len) ?: +- aes_gcm_derive_hash_subkey(&ctx->aes_key_expanded, +- ctx->hash_subkey); +-} +- +-/* This is the Integrity Check 
Value (aka the authentication tag) length and can +- * be 8, 12 or 16 bytes long. */ +-static int common_rfc4106_set_authsize(struct crypto_aead *aead, +- unsigned int authsize) +-{ +- switch (authsize) { +- case 8: +- case 12: +- case 16: +- break; +- default: +- return -EINVAL; +- } +- +- return 0; +-} +- +-static int generic_gcmaes_set_authsize(struct crypto_aead *tfm, +- unsigned int authsize) +-{ +- switch (authsize) { +- case 4: +- case 8: +- case 12: +- case 13: +- case 14: +- case 15: +- case 16: +- break; +- default: +- return -EINVAL; +- } +- +- return 0; +-} +- +-static int gcmaes_crypt_by_sg(bool enc, struct aead_request *req, +- unsigned int assoclen, u8 *hash_subkey, +- u8 *iv, void *aes_ctx, u8 *auth_tag, +- unsigned long auth_tag_len) +-{ +- u8 databuf[sizeof(struct gcm_context_data) + (AESNI_ALIGN - 8)] __aligned(8); +- struct gcm_context_data *data = PTR_ALIGN((void *)databuf, AESNI_ALIGN); +- unsigned long left = req->cryptlen; +- struct scatter_walk assoc_sg_walk; +- struct skcipher_walk walk; +- bool do_avx, do_avx2; +- u8 *assocmem = NULL; +- u8 *assoc; +- int err; +- +- if (!enc) +- left -= auth_tag_len; +- +- do_avx = (left >= AVX_GEN2_OPTSIZE); +- do_avx2 = (left >= AVX_GEN4_OPTSIZE); +- +- /* Linearize assoc, if not already linear */ +- if (req->src->length >= assoclen && req->src->length) { +- scatterwalk_start(&assoc_sg_walk, req->src); +- assoc = scatterwalk_map(&assoc_sg_walk); +- } else { +- gfp_t flags = (req->base.flags & CRYPTO_TFM_REQ_MAY_SLEEP) ? 
+- GFP_KERNEL : GFP_ATOMIC; +- +- /* assoc can be any length, so must be on heap */ +- assocmem = kmalloc(assoclen, flags); +- if (unlikely(!assocmem)) +- return -ENOMEM; +- assoc = assocmem; +- +- scatterwalk_map_and_copy(assoc, req->src, 0, assoclen, 0); +- } +- +- kernel_fpu_begin(); +- if (static_branch_likely(&gcm_use_avx2) && do_avx2) +- aesni_gcm_init_avx_gen4(aes_ctx, data, iv, hash_subkey, assoc, +- assoclen); +- else if (static_branch_likely(&gcm_use_avx) && do_avx) +- aesni_gcm_init_avx_gen2(aes_ctx, data, iv, hash_subkey, assoc, +- assoclen); +- else +- aesni_gcm_init(aes_ctx, data, iv, hash_subkey, assoc, assoclen); +- kernel_fpu_end(); +- +- if (!assocmem) +- scatterwalk_unmap(assoc); +- else +- kfree(assocmem); +- +- err = enc ? skcipher_walk_aead_encrypt(&walk, req, false) +- : skcipher_walk_aead_decrypt(&walk, req, false); +- +- while (walk.nbytes > 0) { +- kernel_fpu_begin(); +- if (static_branch_likely(&gcm_use_avx2) && do_avx2) { +- if (enc) +- aesni_gcm_enc_update_avx_gen4(aes_ctx, data, +- walk.dst.virt.addr, +- walk.src.virt.addr, +- walk.nbytes); +- else +- aesni_gcm_dec_update_avx_gen4(aes_ctx, data, +- walk.dst.virt.addr, +- walk.src.virt.addr, +- walk.nbytes); +- } else if (static_branch_likely(&gcm_use_avx) && do_avx) { +- if (enc) +- aesni_gcm_enc_update_avx_gen2(aes_ctx, data, +- walk.dst.virt.addr, +- walk.src.virt.addr, +- walk.nbytes); +- else +- aesni_gcm_dec_update_avx_gen2(aes_ctx, data, +- walk.dst.virt.addr, +- walk.src.virt.addr, +- walk.nbytes); +- } else if (enc) { +- aesni_gcm_enc_update(aes_ctx, data, walk.dst.virt.addr, +- walk.src.virt.addr, walk.nbytes); +- } else { +- aesni_gcm_dec_update(aes_ctx, data, walk.dst.virt.addr, +- walk.src.virt.addr, walk.nbytes); +- } +- kernel_fpu_end(); +- +- err = skcipher_walk_done(&walk, 0); +- } +- +- if (err) +- return err; +- +- kernel_fpu_begin(); +- if (static_branch_likely(&gcm_use_avx2) && do_avx2) +- aesni_gcm_finalize_avx_gen4(aes_ctx, data, auth_tag, +- auth_tag_len); +- 
else if (static_branch_likely(&gcm_use_avx) && do_avx) +- aesni_gcm_finalize_avx_gen2(aes_ctx, data, auth_tag, +- auth_tag_len); +- else +- aesni_gcm_finalize(aes_ctx, data, auth_tag, auth_tag_len); +- kernel_fpu_end(); +- +- return 0; +-} +- +-static int gcmaes_encrypt(struct aead_request *req, unsigned int assoclen, +- u8 *hash_subkey, u8 *iv, void *aes_ctx) +-{ +- struct crypto_aead *tfm = crypto_aead_reqtfm(req); +- unsigned long auth_tag_len = crypto_aead_authsize(tfm); +- u8 auth_tag[16]; +- int err; +- +- err = gcmaes_crypt_by_sg(true, req, assoclen, hash_subkey, iv, aes_ctx, +- auth_tag, auth_tag_len); +- if (err) +- return err; +- +- scatterwalk_map_and_copy(auth_tag, req->dst, +- req->assoclen + req->cryptlen, +- auth_tag_len, 1); +- return 0; +-} +- +-static int gcmaes_decrypt(struct aead_request *req, unsigned int assoclen, +- u8 *hash_subkey, u8 *iv, void *aes_ctx) +-{ +- struct crypto_aead *tfm = crypto_aead_reqtfm(req); +- unsigned long auth_tag_len = crypto_aead_authsize(tfm); +- u8 auth_tag_msg[16]; +- u8 auth_tag[16]; +- int err; +- +- err = gcmaes_crypt_by_sg(false, req, assoclen, hash_subkey, iv, aes_ctx, +- auth_tag, auth_tag_len); +- if (err) +- return err; +- +- /* Copy out original auth_tag */ +- scatterwalk_map_and_copy(auth_tag_msg, req->src, +- req->assoclen + req->cryptlen - auth_tag_len, +- auth_tag_len, 0); +- +- /* Compare generated tag with passed in tag. 
*/ +- if (crypto_memneq(auth_tag_msg, auth_tag, auth_tag_len)) { +- memzero_explicit(auth_tag, sizeof(auth_tag)); +- return -EBADMSG; +- } +- return 0; +-} +- +-static int helper_rfc4106_encrypt(struct aead_request *req) +-{ +- struct crypto_aead *tfm = crypto_aead_reqtfm(req); +- struct aesni_rfc4106_gcm_ctx *ctx = aesni_rfc4106_gcm_ctx_get(tfm); +- void *aes_ctx = &(ctx->aes_key_expanded); +- u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8); +- u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN); +- unsigned int i; +- __be32 counter = cpu_to_be32(1); +- +- /* Assuming we are supporting rfc4106 64-bit extended */ +- /* sequence numbers We need to have the AAD length equal */ +- /* to 16 or 20 bytes */ +- if (unlikely(req->assoclen != 16 && req->assoclen != 20)) +- return -EINVAL; +- +- /* IV below built */ +- for (i = 0; i < 4; i++) +- *(iv+i) = ctx->nonce[i]; +- for (i = 0; i < 8; i++) +- *(iv+4+i) = req->iv[i]; +- *((__be32 *)(iv+12)) = counter; +- +- return gcmaes_encrypt(req, req->assoclen - 8, ctx->hash_subkey, iv, +- aes_ctx); +-} +- +-static int helper_rfc4106_decrypt(struct aead_request *req) +-{ +- __be32 counter = cpu_to_be32(1); +- struct crypto_aead *tfm = crypto_aead_reqtfm(req); +- struct aesni_rfc4106_gcm_ctx *ctx = aesni_rfc4106_gcm_ctx_get(tfm); +- void *aes_ctx = &(ctx->aes_key_expanded); +- u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8); +- u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN); +- unsigned int i; +- +- if (unlikely(req->assoclen != 16 && req->assoclen != 20)) +- return -EINVAL; +- +- /* Assuming we are supporting rfc4106 64-bit extended */ +- /* sequence numbers We need to have the AAD length */ +- /* equal to 16 or 20 bytes */ +- +- /* IV below built */ +- for (i = 0; i < 4; i++) +- *(iv+i) = ctx->nonce[i]; +- for (i = 0; i < 8; i++) +- *(iv+4+i) = req->iv[i]; +- *((__be32 *)(iv+12)) = counter; +- +- return gcmaes_decrypt(req, req->assoclen - 8, ctx->hash_subkey, iv, +- aes_ctx); +-} + #endif + + static int xts_setkey_aesni(struct crypto_skcipher 
*tfm, const u8 *key, +@@ -1216,11 +833,717 @@ DEFINE_XTS_ALG(vaes_avx10_256, "xts-aes-vaes-avx10_256", 700); + DEFINE_XTS_ALG(vaes_avx10_512, "xts-aes-vaes-avx10_512", 800); + #endif + ++/* The common part of the x86_64 AES-GCM key struct */ ++struct aes_gcm_key { ++ /* Expanded AES key and the AES key length in bytes */ ++ struct crypto_aes_ctx aes_key; ++ ++ /* RFC4106 nonce (used only by the rfc4106 algorithms) */ ++ u32 rfc4106_nonce; ++}; ++ ++/* Key struct used by the AES-NI implementations of AES-GCM */ ++struct aes_gcm_key_aesni { ++ /* ++ * Common part of the key. The assembly code requires 16-byte alignment ++ * for the round keys; we get this by them being located at the start of ++ * the struct and the whole struct being 16-byte aligned. ++ */ ++ struct aes_gcm_key base; ++ ++ /* ++ * Powers of the hash key H^8 through H^1. These are 128-bit values. ++ * They all have an extra factor of x^-1 and are byte-reversed. 16-byte ++ * alignment is required by the assembly code. ++ */ ++ u64 h_powers[8][2] __aligned(16); ++ ++ /* ++ * h_powers_xored[i] contains the two 64-bit halves of h_powers[i] XOR'd ++ * together. It's used for Karatsuba multiplication. 16-byte alignment ++ * is required by the assembly code. ++ */ ++ u64 h_powers_xored[8] __aligned(16); ++ ++ /* ++ * H^1 times x^64 (and also the usual extra factor of x^-1). 16-byte ++ * alignment is required by the assembly code. ++ */ ++ u64 h_times_x64[2] __aligned(16); ++}; ++#define AES_GCM_KEY_AESNI(key) \ ++ container_of((key), struct aes_gcm_key_aesni, base) ++#define AES_GCM_KEY_AESNI_SIZE \ ++ (sizeof(struct aes_gcm_key_aesni) + (15 & ~(CRYPTO_MINALIGN - 1))) ++ ++/* Key struct used by the VAES + AVX10 implementations of AES-GCM */ ++struct aes_gcm_key_avx10 { ++ /* ++ * Common part of the key. The assembly code prefers 16-byte alignment ++ * for the round keys; we get this by them being located at the start of ++ * the struct and the whole struct being 64-byte aligned. 
++ */ ++ struct aes_gcm_key base; ++ ++ /* ++ * Powers of the hash key H^16 through H^1. These are 128-bit values. ++ * They all have an extra factor of x^-1 and are byte-reversed. This ++ * array is aligned to a 64-byte boundary to make it naturally aligned ++ * for 512-bit loads, which can improve performance. (The assembly code ++ * doesn't *need* the alignment; this is just an optimization.) ++ */ ++ u64 h_powers[16][2] __aligned(64); ++ ++ /* Three padding blocks required by the assembly code */ ++ u64 padding[3][2]; ++}; ++#define AES_GCM_KEY_AVX10(key) \ ++ container_of((key), struct aes_gcm_key_avx10, base) ++#define AES_GCM_KEY_AVX10_SIZE \ ++ (sizeof(struct aes_gcm_key_avx10) + (63 & ~(CRYPTO_MINALIGN - 1))) ++ ++/* ++ * These flags are passed to the AES-GCM helper functions to specify the ++ * specific version of AES-GCM (RFC4106 or not), whether it's encryption or ++ * decryption, and which assembly functions should be called. Assembly ++ * functions are selected using flags instead of function pointers to avoid ++ * indirect calls (which are very expensive on x86) regardless of inlining. ++ */ ++#define FLAG_RFC4106 BIT(0) ++#define FLAG_ENC BIT(1) ++#define FLAG_AVX BIT(2) ++#if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ) ++# define FLAG_AVX10_256 BIT(3) ++# define FLAG_AVX10_512 BIT(4) ++#else ++ /* ++ * This should cause all calls to the AVX10 assembly functions to be ++ * optimized out, avoiding the need to ifdef each call individually. 
++ */ ++# define FLAG_AVX10_256 0 ++# define FLAG_AVX10_512 0 ++#endif ++ ++static inline struct aes_gcm_key * ++aes_gcm_key_get(struct crypto_aead *tfm, int flags) ++{ ++ if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512)) ++ return PTR_ALIGN(crypto_aead_ctx(tfm), 64); ++ else ++ return PTR_ALIGN(crypto_aead_ctx(tfm), 16); ++} ++ ++asmlinkage void ++aes_gcm_precompute_aesni(struct aes_gcm_key_aesni *key); ++asmlinkage void ++aes_gcm_precompute_aesni_avx(struct aes_gcm_key_aesni *key); ++asmlinkage void ++aes_gcm_precompute_vaes_avx10_256(struct aes_gcm_key_avx10 *key); ++asmlinkage void ++aes_gcm_precompute_vaes_avx10_512(struct aes_gcm_key_avx10 *key); ++ ++static void aes_gcm_precompute(struct aes_gcm_key *key, int flags) ++{ ++ /* ++ * To make things a bit easier on the assembly side, the AVX10 ++ * implementations use the same key format. Therefore, a single ++ * function using 256-bit vectors would suffice here. However, it's ++ * straightforward to provide a 512-bit one because of how the assembly ++ * code is structured, and it works nicely because the total size of the ++ * key powers is a multiple of 512 bits. So we take advantage of that. ++ * ++ * A similar situation applies to the AES-NI implementations. 
++ */ ++ if (flags & FLAG_AVX10_512) ++ aes_gcm_precompute_vaes_avx10_512(AES_GCM_KEY_AVX10(key)); ++ else if (flags & FLAG_AVX10_256) ++ aes_gcm_precompute_vaes_avx10_256(AES_GCM_KEY_AVX10(key)); ++ else if (flags & FLAG_AVX) ++ aes_gcm_precompute_aesni_avx(AES_GCM_KEY_AESNI(key)); ++ else ++ aes_gcm_precompute_aesni(AES_GCM_KEY_AESNI(key)); ++} ++ ++asmlinkage void ++aes_gcm_aad_update_aesni(const struct aes_gcm_key_aesni *key, ++ u8 ghash_acc[16], const u8 *aad, int aadlen); ++asmlinkage void ++aes_gcm_aad_update_aesni_avx(const struct aes_gcm_key_aesni *key, ++ u8 ghash_acc[16], const u8 *aad, int aadlen); ++asmlinkage void ++aes_gcm_aad_update_vaes_avx10(const struct aes_gcm_key_avx10 *key, ++ u8 ghash_acc[16], const u8 *aad, int aadlen); ++ ++static void aes_gcm_aad_update(const struct aes_gcm_key *key, u8 ghash_acc[16], ++ const u8 *aad, int aadlen, int flags) ++{ ++ if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512)) ++ aes_gcm_aad_update_vaes_avx10(AES_GCM_KEY_AVX10(key), ghash_acc, ++ aad, aadlen); ++ else if (flags & FLAG_AVX) ++ aes_gcm_aad_update_aesni_avx(AES_GCM_KEY_AESNI(key), ghash_acc, ++ aad, aadlen); ++ else ++ aes_gcm_aad_update_aesni(AES_GCM_KEY_AESNI(key), ghash_acc, ++ aad, aadlen); ++} ++ ++asmlinkage void ++aes_gcm_enc_update_aesni(const struct aes_gcm_key_aesni *key, ++ const u32 le_ctr[4], u8 ghash_acc[16], ++ const u8 *src, u8 *dst, int datalen); ++asmlinkage void ++aes_gcm_enc_update_aesni_avx(const struct aes_gcm_key_aesni *key, ++ const u32 le_ctr[4], u8 ghash_acc[16], ++ const u8 *src, u8 *dst, int datalen); ++asmlinkage void ++aes_gcm_enc_update_vaes_avx10_256(const struct aes_gcm_key_avx10 *key, ++ const u32 le_ctr[4], u8 ghash_acc[16], ++ const u8 *src, u8 *dst, int datalen); ++asmlinkage void ++aes_gcm_enc_update_vaes_avx10_512(const struct aes_gcm_key_avx10 *key, ++ const u32 le_ctr[4], u8 ghash_acc[16], ++ const u8 *src, u8 *dst, int datalen); ++ ++asmlinkage void ++aes_gcm_dec_update_aesni(const struct aes_gcm_key_aesni *key, ++ 
const u32 le_ctr[4], u8 ghash_acc[16], ++ const u8 *src, u8 *dst, int datalen); ++asmlinkage void ++aes_gcm_dec_update_aesni_avx(const struct aes_gcm_key_aesni *key, ++ const u32 le_ctr[4], u8 ghash_acc[16], ++ const u8 *src, u8 *dst, int datalen); ++asmlinkage void ++aes_gcm_dec_update_vaes_avx10_256(const struct aes_gcm_key_avx10 *key, ++ const u32 le_ctr[4], u8 ghash_acc[16], ++ const u8 *src, u8 *dst, int datalen); ++asmlinkage void ++aes_gcm_dec_update_vaes_avx10_512(const struct aes_gcm_key_avx10 *key, ++ const u32 le_ctr[4], u8 ghash_acc[16], ++ const u8 *src, u8 *dst, int datalen); ++ ++/* __always_inline to optimize out the branches based on @flags */ ++static __always_inline void ++aes_gcm_update(const struct aes_gcm_key *key, ++ const u32 le_ctr[4], u8 ghash_acc[16], ++ const u8 *src, u8 *dst, int datalen, int flags) ++{ ++ if (flags & FLAG_ENC) { ++ if (flags & FLAG_AVX10_512) ++ aes_gcm_enc_update_vaes_avx10_512(AES_GCM_KEY_AVX10(key), ++ le_ctr, ghash_acc, ++ src, dst, datalen); ++ else if (flags & FLAG_AVX10_256) ++ aes_gcm_enc_update_vaes_avx10_256(AES_GCM_KEY_AVX10(key), ++ le_ctr, ghash_acc, ++ src, dst, datalen); ++ else if (flags & FLAG_AVX) ++ aes_gcm_enc_update_aesni_avx(AES_GCM_KEY_AESNI(key), ++ le_ctr, ghash_acc, ++ src, dst, datalen); ++ else ++ aes_gcm_enc_update_aesni(AES_GCM_KEY_AESNI(key), le_ctr, ++ ghash_acc, src, dst, datalen); ++ } else { ++ if (flags & FLAG_AVX10_512) ++ aes_gcm_dec_update_vaes_avx10_512(AES_GCM_KEY_AVX10(key), ++ le_ctr, ghash_acc, ++ src, dst, datalen); ++ else if (flags & FLAG_AVX10_256) ++ aes_gcm_dec_update_vaes_avx10_256(AES_GCM_KEY_AVX10(key), ++ le_ctr, ghash_acc, ++ src, dst, datalen); ++ else if (flags & FLAG_AVX) ++ aes_gcm_dec_update_aesni_avx(AES_GCM_KEY_AESNI(key), ++ le_ctr, ghash_acc, ++ src, dst, datalen); ++ else ++ aes_gcm_dec_update_aesni(AES_GCM_KEY_AESNI(key), ++ le_ctr, ghash_acc, ++ src, dst, datalen); ++ } ++} ++ ++asmlinkage void ++aes_gcm_enc_final_aesni(const struct aes_gcm_key_aesni 
*key, ++ const u32 le_ctr[4], u8 ghash_acc[16], ++ u64 total_aadlen, u64 total_datalen); ++asmlinkage void ++aes_gcm_enc_final_aesni_avx(const struct aes_gcm_key_aesni *key, ++ const u32 le_ctr[4], u8 ghash_acc[16], ++ u64 total_aadlen, u64 total_datalen); ++asmlinkage void ++aes_gcm_enc_final_vaes_avx10(const struct aes_gcm_key_avx10 *key, ++ const u32 le_ctr[4], u8 ghash_acc[16], ++ u64 total_aadlen, u64 total_datalen); ++ ++/* __always_inline to optimize out the branches based on @flags */ ++static __always_inline void ++aes_gcm_enc_final(const struct aes_gcm_key *key, ++ const u32 le_ctr[4], u8 ghash_acc[16], ++ u64 total_aadlen, u64 total_datalen, int flags) ++{ ++ if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512)) ++ aes_gcm_enc_final_vaes_avx10(AES_GCM_KEY_AVX10(key), ++ le_ctr, ghash_acc, ++ total_aadlen, total_datalen); ++ else if (flags & FLAG_AVX) ++ aes_gcm_enc_final_aesni_avx(AES_GCM_KEY_AESNI(key), ++ le_ctr, ghash_acc, ++ total_aadlen, total_datalen); ++ else ++ aes_gcm_enc_final_aesni(AES_GCM_KEY_AESNI(key), ++ le_ctr, ghash_acc, ++ total_aadlen, total_datalen); ++} ++ ++asmlinkage bool __must_check ++aes_gcm_dec_final_aesni(const struct aes_gcm_key_aesni *key, ++ const u32 le_ctr[4], const u8 ghash_acc[16], ++ u64 total_aadlen, u64 total_datalen, ++ const u8 tag[16], int taglen); ++asmlinkage bool __must_check ++aes_gcm_dec_final_aesni_avx(const struct aes_gcm_key_aesni *key, ++ const u32 le_ctr[4], const u8 ghash_acc[16], ++ u64 total_aadlen, u64 total_datalen, ++ const u8 tag[16], int taglen); ++asmlinkage bool __must_check ++aes_gcm_dec_final_vaes_avx10(const struct aes_gcm_key_avx10 *key, ++ const u32 le_ctr[4], const u8 ghash_acc[16], ++ u64 total_aadlen, u64 total_datalen, ++ const u8 tag[16], int taglen); ++ ++/* __always_inline to optimize out the branches based on @flags */ ++static __always_inline bool __must_check ++aes_gcm_dec_final(const struct aes_gcm_key *key, const u32 le_ctr[4], ++ u8 ghash_acc[16], u64 total_aadlen, u64 total_datalen, 
++ u8 tag[16], int taglen, int flags) ++{ ++ if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512)) ++ return aes_gcm_dec_final_vaes_avx10(AES_GCM_KEY_AVX10(key), ++ le_ctr, ghash_acc, ++ total_aadlen, total_datalen, ++ tag, taglen); ++ else if (flags & FLAG_AVX) ++ return aes_gcm_dec_final_aesni_avx(AES_GCM_KEY_AESNI(key), ++ le_ctr, ghash_acc, ++ total_aadlen, total_datalen, ++ tag, taglen); ++ else ++ return aes_gcm_dec_final_aesni(AES_GCM_KEY_AESNI(key), ++ le_ctr, ghash_acc, ++ total_aadlen, total_datalen, ++ tag, taglen); ++} ++ ++/* ++ * This is the Integrity Check Value (aka the authentication tag) length and can ++ * be 8, 12 or 16 bytes long. ++ */ ++static int common_rfc4106_set_authsize(struct crypto_aead *aead, ++ unsigned int authsize) ++{ ++ switch (authsize) { ++ case 8: ++ case 12: ++ case 16: ++ break; ++ default: ++ return -EINVAL; ++ } ++ ++ return 0; ++} ++ ++static int generic_gcmaes_set_authsize(struct crypto_aead *tfm, ++ unsigned int authsize) ++{ ++ switch (authsize) { ++ case 4: ++ case 8: ++ case 12: ++ case 13: ++ case 14: ++ case 15: ++ case 16: ++ break; ++ default: ++ return -EINVAL; ++ } ++ ++ return 0; ++} ++ ++/* ++ * This is the setkey function for the x86_64 implementations of AES-GCM. It ++ * saves the RFC4106 nonce if applicable, expands the AES key, and precomputes ++ * powers of the hash key. ++ * ++ * To comply with the crypto_aead API, this has to be usable in no-SIMD context. ++ * For that reason, this function includes a portable C implementation of the ++ * needed logic. However, the portable C implementation is very slow, taking ++ * about the same time as encrypting 37 KB of data. To be ready for users that ++ * may set a key even somewhat frequently, we therefore also include a SIMD ++ * assembly implementation, expanding the AES key using AES-NI and precomputing ++ * the hash key powers using PCLMULQDQ or VPCLMULQDQ. 
++ */ ++static int gcm_setkey(struct crypto_aead *tfm, const u8 *raw_key, ++ unsigned int keylen, int flags) ++{ ++ struct aes_gcm_key *key = aes_gcm_key_get(tfm, flags); ++ int err; ++ ++ if (flags & FLAG_RFC4106) { ++ if (keylen < 4) ++ return -EINVAL; ++ keylen -= 4; ++ key->rfc4106_nonce = get_unaligned_be32(raw_key + keylen); ++ } ++ ++ /* The assembly code assumes the following offsets. */ ++ BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, base.aes_key.key_enc) != 0); ++ BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, base.aes_key.key_length) != 480); ++ BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_powers) != 496); ++ BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_powers_xored) != 624); ++ BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_times_x64) != 688); ++ BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, base.aes_key.key_enc) != 0); ++ BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, base.aes_key.key_length) != 480); ++ BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, h_powers) != 512); ++ BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, padding) != 768); ++ ++ if (likely(crypto_simd_usable())) { ++ err = aes_check_keylen(keylen); ++ if (err) ++ return err; ++ kernel_fpu_begin(); ++ aesni_set_key(&key->aes_key, raw_key, keylen); ++ aes_gcm_precompute(key, flags); ++ kernel_fpu_end(); ++ } else { ++ static const u8 x_to_the_minus1[16] __aligned(__alignof__(be128)) = { ++ [0] = 0xc2, [15] = 1 ++ }; ++ static const u8 x_to_the_63[16] __aligned(__alignof__(be128)) = { ++ [7] = 1, ++ }; ++ be128 h1 = {}; ++ be128 h; ++ int i; ++ ++ err = aes_expandkey(&key->aes_key, raw_key, keylen); ++ if (err) ++ return err; ++ ++ /* Encrypt the all-zeroes block to get the hash key H^1 */ ++ aes_encrypt(&key->aes_key, (u8 *)&h1, (u8 *)&h1); ++ ++ /* Compute H^1 * x^-1 */ ++ h = h1; ++ gf128mul_lle(&h, (const be128 *)x_to_the_minus1); ++ ++ /* Compute the needed key powers */ ++ if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512)) { ++ struct aes_gcm_key_avx10 *k = 
AES_GCM_KEY_AVX10(key); ++ ++ for (i = ARRAY_SIZE(k->h_powers) - 1; i >= 0; i--) { ++ k->h_powers[i][0] = be64_to_cpu(h.b); ++ k->h_powers[i][1] = be64_to_cpu(h.a); ++ gf128mul_lle(&h, &h1); ++ } ++ memset(k->padding, 0, sizeof(k->padding)); ++ } else { ++ struct aes_gcm_key_aesni *k = AES_GCM_KEY_AESNI(key); ++ ++ for (i = ARRAY_SIZE(k->h_powers) - 1; i >= 0; i--) { ++ k->h_powers[i][0] = be64_to_cpu(h.b); ++ k->h_powers[i][1] = be64_to_cpu(h.a); ++ k->h_powers_xored[i] = k->h_powers[i][0] ^ ++ k->h_powers[i][1]; ++ gf128mul_lle(&h, &h1); ++ } ++ gf128mul_lle(&h1, (const be128 *)x_to_the_63); ++ k->h_times_x64[0] = be64_to_cpu(h1.b); ++ k->h_times_x64[1] = be64_to_cpu(h1.a); ++ } ++ } ++ return 0; ++} ++ ++/* ++ * Initialize @ghash_acc, then pass all @assoclen bytes of associated data ++ * (a.k.a. additional authenticated data) from @sg_src through the GHASH update ++ * assembly function. kernel_fpu_begin() must have already been called. ++ */ ++static void gcm_process_assoc(const struct aes_gcm_key *key, u8 ghash_acc[16], ++ struct scatterlist *sg_src, unsigned int assoclen, ++ int flags) ++{ ++ struct scatter_walk walk; ++ /* ++ * The assembly function requires that the length of any non-last ++ * segment of associated data be a multiple of 16 bytes, so this ++ * function does the buffering needed to achieve that. 
++ */ ++ unsigned int pos = 0; ++ u8 buf[16]; ++ ++ memset(ghash_acc, 0, 16); ++ scatterwalk_start(&walk, sg_src); ++ ++ while (assoclen) { ++ unsigned int len_this_page = scatterwalk_clamp(&walk, assoclen); ++ void *mapped = scatterwalk_map(&walk); ++ const void *src = mapped; ++ unsigned int len; ++ ++ assoclen -= len_this_page; ++ scatterwalk_advance(&walk, len_this_page); ++ if (unlikely(pos)) { ++ len = min(len_this_page, 16 - pos); ++ memcpy(&buf[pos], src, len); ++ pos += len; ++ src += len; ++ len_this_page -= len; ++ if (pos < 16) ++ goto next; ++ aes_gcm_aad_update(key, ghash_acc, buf, 16, flags); ++ pos = 0; ++ } ++ len = len_this_page; ++ if (unlikely(assoclen)) /* Not the last segment yet? */ ++ len = round_down(len, 16); ++ aes_gcm_aad_update(key, ghash_acc, src, len, flags); ++ src += len; ++ len_this_page -= len; ++ if (unlikely(len_this_page)) { ++ memcpy(buf, src, len_this_page); ++ pos = len_this_page; ++ } ++next: ++ scatterwalk_unmap(mapped); ++ scatterwalk_pagedone(&walk, 0, assoclen); ++ if (need_resched()) { ++ kernel_fpu_end(); ++ kernel_fpu_begin(); ++ } ++ } ++ if (unlikely(pos)) ++ aes_gcm_aad_update(key, ghash_acc, buf, pos, flags); ++} ++ ++ ++/* __always_inline to optimize out the branches based on @flags */ ++static __always_inline int ++gcm_crypt(struct aead_request *req, int flags) ++{ ++ struct crypto_aead *tfm = crypto_aead_reqtfm(req); ++ const struct aes_gcm_key *key = aes_gcm_key_get(tfm, flags); ++ unsigned int assoclen = req->assoclen; ++ struct skcipher_walk walk; ++ unsigned int nbytes; ++ u8 ghash_acc[16]; /* GHASH accumulator */ ++ u32 le_ctr[4]; /* Counter in little-endian format */ ++ int taglen; ++ int err; ++ ++ /* Initialize the counter and determine the associated data length. 
*/ ++ le_ctr[0] = 2; ++ if (flags & FLAG_RFC4106) { ++ if (unlikely(assoclen != 16 && assoclen != 20)) ++ return -EINVAL; ++ assoclen -= 8; ++ le_ctr[1] = get_unaligned_be32(req->iv + 4); ++ le_ctr[2] = get_unaligned_be32(req->iv + 0); ++ le_ctr[3] = key->rfc4106_nonce; /* already byte-swapped */ ++ } else { ++ le_ctr[1] = get_unaligned_be32(req->iv + 8); ++ le_ctr[2] = get_unaligned_be32(req->iv + 4); ++ le_ctr[3] = get_unaligned_be32(req->iv + 0); ++ } ++ ++ /* Begin walking through the plaintext or ciphertext. */ ++ if (flags & FLAG_ENC) ++ err = skcipher_walk_aead_encrypt(&walk, req, false); ++ else ++ err = skcipher_walk_aead_decrypt(&walk, req, false); ++ ++ /* ++ * Since the AES-GCM assembly code requires that at least three assembly ++ * functions be called to process any message (this is needed to support ++ * incremental updates cleanly), to reduce overhead we try to do all ++ * three calls in the same kernel FPU section if possible. We close the ++ * section and start a new one if there are multiple data segments or if ++ * rescheduling is needed while processing the associated data. ++ */ ++ kernel_fpu_begin(); ++ ++ /* Pass the associated data through GHASH. */ ++ gcm_process_assoc(key, ghash_acc, req->src, assoclen, flags); ++ ++ /* En/decrypt the data and pass the ciphertext through GHASH. */ ++ while ((nbytes = walk.nbytes) != 0) { ++ if (unlikely(nbytes < walk.total)) { ++ /* ++ * Non-last segment. In this case, the assembly ++ * function requires that the length be a multiple of 16 ++ * (AES_BLOCK_SIZE) bytes. The needed buffering of up ++ * to 16 bytes is handled by the skcipher_walk. Here we ++ * just need to round down to a multiple of 16. 
++ */ ++ nbytes = round_down(nbytes, AES_BLOCK_SIZE); ++ aes_gcm_update(key, le_ctr, ghash_acc, ++ walk.src.virt.addr, walk.dst.virt.addr, ++ nbytes, flags); ++ le_ctr[0] += nbytes / AES_BLOCK_SIZE; ++ kernel_fpu_end(); ++ err = skcipher_walk_done(&walk, walk.nbytes - nbytes); ++ kernel_fpu_begin(); ++ } else { ++ /* Last segment: process all remaining data. */ ++ aes_gcm_update(key, le_ctr, ghash_acc, ++ walk.src.virt.addr, walk.dst.virt.addr, ++ nbytes, flags); ++ err = skcipher_walk_done(&walk, 0); ++ /* ++ * The low word of the counter isn't used by the ++ * finalize, so there's no need to increment it here. ++ */ ++ } ++ } ++ if (err) ++ goto out; ++ ++ /* Finalize */ ++ taglen = crypto_aead_authsize(tfm); ++ if (flags & FLAG_ENC) { ++ /* Finish computing the auth tag. */ ++ aes_gcm_enc_final(key, le_ctr, ghash_acc, assoclen, ++ req->cryptlen, flags); ++ ++ /* Store the computed auth tag in the dst scatterlist. */ ++ scatterwalk_map_and_copy(ghash_acc, req->dst, req->assoclen + ++ req->cryptlen, taglen, 1); ++ } else { ++ unsigned int datalen = req->cryptlen - taglen; ++ u8 tag[16]; ++ ++ /* Get the transmitted auth tag from the src scatterlist. */ ++ scatterwalk_map_and_copy(tag, req->src, req->assoclen + datalen, ++ taglen, 0); ++ /* ++ * Finish computing the auth tag and compare it to the ++ * transmitted one. The assembly function does the actual tag ++ * comparison. Here, just check the boolean result. 
++ */ ++ if (!aes_gcm_dec_final(key, le_ctr, ghash_acc, assoclen, ++ datalen, tag, taglen, flags)) ++ err = -EBADMSG; ++ } ++out: ++ kernel_fpu_end(); ++ return err; ++} ++ ++#define DEFINE_GCM_ALGS(suffix, flags, generic_driver_name, rfc_driver_name, \ ++ ctxsize, priority) \ ++ \ ++static int gcm_setkey_##suffix(struct crypto_aead *tfm, const u8 *raw_key, \ ++ unsigned int keylen) \ ++{ \ ++ return gcm_setkey(tfm, raw_key, keylen, (flags)); \ ++} \ ++ \ ++static int gcm_encrypt_##suffix(struct aead_request *req) \ ++{ \ ++ return gcm_crypt(req, (flags) | FLAG_ENC); \ ++} \ ++ \ ++static int gcm_decrypt_##suffix(struct aead_request *req) \ ++{ \ ++ return gcm_crypt(req, (flags)); \ ++} \ ++ \ ++static int rfc4106_setkey_##suffix(struct crypto_aead *tfm, const u8 *raw_key, \ ++ unsigned int keylen) \ ++{ \ ++ return gcm_setkey(tfm, raw_key, keylen, (flags) | FLAG_RFC4106); \ ++} \ ++ \ ++static int rfc4106_encrypt_##suffix(struct aead_request *req) \ ++{ \ ++ return gcm_crypt(req, (flags) | FLAG_RFC4106 | FLAG_ENC); \ ++} \ ++ \ ++static int rfc4106_decrypt_##suffix(struct aead_request *req) \ ++{ \ ++ return gcm_crypt(req, (flags) | FLAG_RFC4106); \ ++} \ ++ \ ++static struct aead_alg aes_gcm_algs_##suffix[] = { { \ ++ .setkey = gcm_setkey_##suffix, \ ++ .setauthsize = generic_gcmaes_set_authsize, \ ++ .encrypt = gcm_encrypt_##suffix, \ ++ .decrypt = gcm_decrypt_##suffix, \ ++ .ivsize = GCM_AES_IV_SIZE, \ ++ .chunksize = AES_BLOCK_SIZE, \ ++ .maxauthsize = 16, \ ++ .base = { \ ++ .cra_name = "__gcm(aes)", \ ++ .cra_driver_name = "__" generic_driver_name, \ ++ .cra_priority = (priority), \ ++ .cra_flags = CRYPTO_ALG_INTERNAL, \ ++ .cra_blocksize = 1, \ ++ .cra_ctxsize = (ctxsize), \ ++ .cra_module = THIS_MODULE, \ ++ }, \ ++}, { \ ++ .setkey = rfc4106_setkey_##suffix, \ ++ .setauthsize = common_rfc4106_set_authsize, \ ++ .encrypt = rfc4106_encrypt_##suffix, \ ++ .decrypt = rfc4106_decrypt_##suffix, \ ++ .ivsize = GCM_RFC4106_IV_SIZE, \ ++ .chunksize = 
AES_BLOCK_SIZE, \ ++ .maxauthsize = 16, \ ++ .base = { \ ++ .cra_name = "__rfc4106(gcm(aes))", \ ++ .cra_driver_name = "__" rfc_driver_name, \ ++ .cra_priority = (priority), \ ++ .cra_flags = CRYPTO_ALG_INTERNAL, \ ++ .cra_blocksize = 1, \ ++ .cra_ctxsize = (ctxsize), \ ++ .cra_module = THIS_MODULE, \ ++ }, \ ++} }; \ ++ \ ++static struct simd_aead_alg *aes_gcm_simdalgs_##suffix[2] \ ++ ++/* aes_gcm_algs_aesni */ ++DEFINE_GCM_ALGS(aesni, /* no flags */ 0, ++ "generic-gcm-aesni", "rfc4106-gcm-aesni", ++ AES_GCM_KEY_AESNI_SIZE, 400); ++ ++/* aes_gcm_algs_aesni_avx */ ++DEFINE_GCM_ALGS(aesni_avx, FLAG_AVX, ++ "generic-gcm-aesni-avx", "rfc4106-gcm-aesni-avx", ++ AES_GCM_KEY_AESNI_SIZE, 500); ++ ++#if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ) ++/* aes_gcm_algs_vaes_avx10_256 */ ++DEFINE_GCM_ALGS(vaes_avx10_256, FLAG_AVX10_256, ++ "generic-gcm-vaes-avx10_256", "rfc4106-gcm-vaes-avx10_256", ++ AES_GCM_KEY_AVX10_SIZE, 700); ++ ++/* aes_gcm_algs_vaes_avx10_512 */ ++DEFINE_GCM_ALGS(vaes_avx10_512, FLAG_AVX10_512, ++ "generic-gcm-vaes-avx10_512", "rfc4106-gcm-vaes-avx10_512", ++ AES_GCM_KEY_AVX10_SIZE, 800); ++#endif /* CONFIG_AS_VAES && CONFIG_AS_VPCLMULQDQ */ ++ + /* + * This is a list of CPU models that are known to suffer from downclocking when +- * zmm registers (512-bit vectors) are used. On these CPUs, the AES-XTS +- * implementation with zmm registers won't be used by default. An +- * implementation with ymm registers (256-bit vectors) will be used instead. ++ * zmm registers (512-bit vectors) are used. On these CPUs, the AES mode ++ * implementations with zmm registers won't be used by default. Implementations ++ * with ymm registers (256-bit vectors) will be used by default instead. 
+ */ + static const struct x86_cpu_id zmm_exclusion_list[] = { + X86_MATCH_VFM(INTEL_SKYLAKE_X, 0), +@@ -1236,7 +1559,7 @@ static const struct x86_cpu_id zmm_exclusion_list[] = { + {}, + }; + +-static int __init register_xts_algs(void) ++static int __init register_avx_algs(void) + { + int err; + +@@ -1246,6 +1569,11 @@ static int __init register_xts_algs(void) + &aes_xts_simdalg_aesni_avx); + if (err) + return err; ++ err = simd_register_aeads_compat(aes_gcm_algs_aesni_avx, ++ ARRAY_SIZE(aes_gcm_algs_aesni_avx), ++ aes_gcm_simdalgs_aesni_avx); ++ if (err) ++ return err; + #if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ) + if (!boot_cpu_has(X86_FEATURE_AVX2) || + !boot_cpu_has(X86_FEATURE_VAES) || +@@ -1269,23 +1597,42 @@ static int __init register_xts_algs(void) + &aes_xts_simdalg_vaes_avx10_256); + if (err) + return err; ++ err = simd_register_aeads_compat(aes_gcm_algs_vaes_avx10_256, ++ ARRAY_SIZE(aes_gcm_algs_vaes_avx10_256), ++ aes_gcm_simdalgs_vaes_avx10_256); ++ if (err) ++ return err; ++ ++ if (x86_match_cpu(zmm_exclusion_list)) { ++ int i; + +- if (x86_match_cpu(zmm_exclusion_list)) + aes_xts_alg_vaes_avx10_512.base.cra_priority = 1; ++ for (i = 0; i < ARRAY_SIZE(aes_gcm_algs_vaes_avx10_512); i++) ++ aes_gcm_algs_vaes_avx10_512[i].base.cra_priority = 1; ++ } + + err = simd_register_skciphers_compat(&aes_xts_alg_vaes_avx10_512, 1, + &aes_xts_simdalg_vaes_avx10_512); + if (err) + return err; ++ err = simd_register_aeads_compat(aes_gcm_algs_vaes_avx10_512, ++ ARRAY_SIZE(aes_gcm_algs_vaes_avx10_512), ++ aes_gcm_simdalgs_vaes_avx10_512); ++ if (err) ++ return err; + #endif /* CONFIG_AS_VAES && CONFIG_AS_VPCLMULQDQ */ + return 0; + } + +-static void unregister_xts_algs(void) ++static void unregister_avx_algs(void) + { + if (aes_xts_simdalg_aesni_avx) + simd_unregister_skciphers(&aes_xts_alg_aesni_avx, 1, + &aes_xts_simdalg_aesni_avx); ++ if (aes_gcm_simdalgs_aesni_avx[0]) ++ simd_unregister_aeads(aes_gcm_algs_aesni_avx, ++ 
ARRAY_SIZE(aes_gcm_algs_aesni_avx), ++ aes_gcm_simdalgs_aesni_avx); + #if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ) + if (aes_xts_simdalg_vaes_avx2) + simd_unregister_skciphers(&aes_xts_alg_vaes_avx2, 1, +@@ -1293,106 +1640,33 @@ static void unregister_xts_algs(void) + if (aes_xts_simdalg_vaes_avx10_256) + simd_unregister_skciphers(&aes_xts_alg_vaes_avx10_256, 1, + &aes_xts_simdalg_vaes_avx10_256); ++ if (aes_gcm_simdalgs_vaes_avx10_256[0]) ++ simd_unregister_aeads(aes_gcm_algs_vaes_avx10_256, ++ ARRAY_SIZE(aes_gcm_algs_vaes_avx10_256), ++ aes_gcm_simdalgs_vaes_avx10_256); + if (aes_xts_simdalg_vaes_avx10_512) + simd_unregister_skciphers(&aes_xts_alg_vaes_avx10_512, 1, + &aes_xts_simdalg_vaes_avx10_512); ++ if (aes_gcm_simdalgs_vaes_avx10_512[0]) ++ simd_unregister_aeads(aes_gcm_algs_vaes_avx10_512, ++ ARRAY_SIZE(aes_gcm_algs_vaes_avx10_512), ++ aes_gcm_simdalgs_vaes_avx10_512); + #endif + } + #else /* CONFIG_X86_64 */ +-static int __init register_xts_algs(void) ++static struct aead_alg aes_gcm_algs_aesni[0]; ++static struct simd_aead_alg *aes_gcm_simdalgs_aesni[0]; ++ ++static int __init register_avx_algs(void) + { + return 0; + } + +-static void unregister_xts_algs(void) ++static void unregister_avx_algs(void) + { + } + #endif /* !CONFIG_X86_64 */ + +-#ifdef CONFIG_X86_64 +-static int generic_gcmaes_set_key(struct crypto_aead *aead, const u8 *key, +- unsigned int key_len) +-{ +- struct generic_gcmaes_ctx *ctx = generic_gcmaes_ctx_get(aead); +- +- return aes_set_key_common(&ctx->aes_key_expanded, key, key_len) ?: +- aes_gcm_derive_hash_subkey(&ctx->aes_key_expanded, +- ctx->hash_subkey); +-} +- +-static int generic_gcmaes_encrypt(struct aead_request *req) +-{ +- struct crypto_aead *tfm = crypto_aead_reqtfm(req); +- struct generic_gcmaes_ctx *ctx = generic_gcmaes_ctx_get(tfm); +- void *aes_ctx = &(ctx->aes_key_expanded); +- u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8); +- u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN); +- __be32 counter = 
cpu_to_be32(1); +- +- memcpy(iv, req->iv, 12); +- *((__be32 *)(iv+12)) = counter; +- +- return gcmaes_encrypt(req, req->assoclen, ctx->hash_subkey, iv, +- aes_ctx); +-} +- +-static int generic_gcmaes_decrypt(struct aead_request *req) +-{ +- __be32 counter = cpu_to_be32(1); +- struct crypto_aead *tfm = crypto_aead_reqtfm(req); +- struct generic_gcmaes_ctx *ctx = generic_gcmaes_ctx_get(tfm); +- void *aes_ctx = &(ctx->aes_key_expanded); +- u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8); +- u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN); +- +- memcpy(iv, req->iv, 12); +- *((__be32 *)(iv+12)) = counter; +- +- return gcmaes_decrypt(req, req->assoclen, ctx->hash_subkey, iv, +- aes_ctx); +-} +- +-static struct aead_alg aesni_aeads[] = { { +- .setkey = common_rfc4106_set_key, +- .setauthsize = common_rfc4106_set_authsize, +- .encrypt = helper_rfc4106_encrypt, +- .decrypt = helper_rfc4106_decrypt, +- .ivsize = GCM_RFC4106_IV_SIZE, +- .maxauthsize = 16, +- .base = { +- .cra_name = "__rfc4106(gcm(aes))", +- .cra_driver_name = "__rfc4106-gcm-aesni", +- .cra_priority = 400, +- .cra_flags = CRYPTO_ALG_INTERNAL, +- .cra_blocksize = 1, +- .cra_ctxsize = sizeof(struct aesni_rfc4106_gcm_ctx), +- .cra_alignmask = 0, +- .cra_module = THIS_MODULE, +- }, +-}, { +- .setkey = generic_gcmaes_set_key, +- .setauthsize = generic_gcmaes_set_authsize, +- .encrypt = generic_gcmaes_encrypt, +- .decrypt = generic_gcmaes_decrypt, +- .ivsize = GCM_AES_IV_SIZE, +- .maxauthsize = 16, +- .base = { +- .cra_name = "__gcm(aes)", +- .cra_driver_name = "__generic-gcm-aesni", +- .cra_priority = 400, +- .cra_flags = CRYPTO_ALG_INTERNAL, +- .cra_blocksize = 1, +- .cra_ctxsize = sizeof(struct generic_gcmaes_ctx), +- .cra_alignmask = 0, +- .cra_module = THIS_MODULE, +- }, +-} }; +-#else +-static struct aead_alg aesni_aeads[0]; +-#endif +- +-static struct simd_aead_alg *aesni_simd_aeads[ARRAY_SIZE(aesni_aeads)]; +- + static const struct x86_cpu_id aesni_cpu_id[] = { + X86_MATCH_FEATURE(X86_FEATURE_AES, NULL), + {} +@@ 
-1406,17 +1680,6 @@ static int __init aesni_init(void) + if (!x86_match_cpu(aesni_cpu_id)) + return -ENODEV; + #ifdef CONFIG_X86_64 +- if (boot_cpu_has(X86_FEATURE_AVX2)) { +- pr_info("AVX2 version of gcm_enc/dec engaged.\n"); +- static_branch_enable(&gcm_use_avx); +- static_branch_enable(&gcm_use_avx2); +- } else +- if (boot_cpu_has(X86_FEATURE_AVX)) { +- pr_info("AVX version of gcm_enc/dec engaged.\n"); +- static_branch_enable(&gcm_use_avx); +- } else { +- pr_info("SSE version of gcm_enc/dec engaged.\n"); +- } + if (boot_cpu_has(X86_FEATURE_AVX)) { + /* optimize performance of ctr mode encryption transform */ + static_call_update(aesni_ctr_enc_tfm, aesni_ctr_enc_avx_tfm); +@@ -1434,8 +1697,9 @@ static int __init aesni_init(void) + if (err) + goto unregister_cipher; + +- err = simd_register_aeads_compat(aesni_aeads, ARRAY_SIZE(aesni_aeads), +- aesni_simd_aeads); ++ err = simd_register_aeads_compat(aes_gcm_algs_aesni, ++ ARRAY_SIZE(aes_gcm_algs_aesni), ++ aes_gcm_simdalgs_aesni); + if (err) + goto unregister_skciphers; + +@@ -1447,22 +1711,22 @@ static int __init aesni_init(void) + goto unregister_aeads; + #endif /* CONFIG_X86_64 */ + +- err = register_xts_algs(); ++ err = register_avx_algs(); + if (err) +- goto unregister_xts; ++ goto unregister_avx; + + return 0; + +-unregister_xts: +- unregister_xts_algs(); ++unregister_avx: ++ unregister_avx_algs(); + #ifdef CONFIG_X86_64 + if (aesni_simd_xctr) + simd_unregister_skciphers(&aesni_xctr, 1, &aesni_simd_xctr); + unregister_aeads: + #endif /* CONFIG_X86_64 */ +- simd_unregister_aeads(aesni_aeads, ARRAY_SIZE(aesni_aeads), +- aesni_simd_aeads); +- ++ simd_unregister_aeads(aes_gcm_algs_aesni, ++ ARRAY_SIZE(aes_gcm_algs_aesni), ++ aes_gcm_simdalgs_aesni); + unregister_skciphers: + simd_unregister_skciphers(aesni_skciphers, ARRAY_SIZE(aesni_skciphers), + aesni_simd_skciphers); +@@ -1473,8 +1737,9 @@ static int __init aesni_init(void) + + static void __exit aesni_exit(void) + { +- simd_unregister_aeads(aesni_aeads, 
ARRAY_SIZE(aesni_aeads), +- aesni_simd_aeads); ++ simd_unregister_aeads(aes_gcm_algs_aesni, ++ ARRAY_SIZE(aes_gcm_algs_aesni), ++ aes_gcm_simdalgs_aesni); + simd_unregister_skciphers(aesni_skciphers, ARRAY_SIZE(aesni_skciphers), + aesni_simd_skciphers); + crypto_unregister_alg(&aesni_cipher_alg); +@@ -1482,7 +1747,7 @@ static void __exit aesni_exit(void) + if (boot_cpu_has(X86_FEATURE_AVX)) + simd_unregister_skciphers(&aesni_xctr, 1, &aesni_simd_xctr); + #endif /* CONFIG_X86_64 */ +- unregister_xts_algs(); ++ unregister_avx_algs(); + } + + late_initcall(aesni_init); +-- +2.46.0 + +From 1de4c03fd12b0064af9db641bac0080f222cca7e Mon Sep 17 00:00:00 2001 +From: Peter Jung +Date: Wed, 14 Aug 2024 16:19:23 +0200 +Subject: [PATCH 06/12] fixes + +Signed-off-by: Peter Jung +--- + arch/Kconfig | 4 ++-- + drivers/gpu/drm/nouveau/nouveau_bo.c | 2 +- + drivers/gpu/drm/nouveau/nouveau_chan.c | 2 +- + drivers/gpu/drm/nouveau/nouveau_dmem.c | 2 +- + drivers/gpu/drm/nouveau/nouveau_fence.c | 30 +------------------------ + drivers/gpu/drm/nouveau/nouveau_fence.h | 2 +- + drivers/gpu/drm/nouveau/nouveau_gem.c | 2 +- + fs/btrfs/extent_map.c | 22 +++++------------- + fs/btrfs/super.c | 10 +++++++++ + 9 files changed, 24 insertions(+), 52 deletions(-) + +diff --git a/arch/Kconfig b/arch/Kconfig +index 975dd22a2dbd..de69b8f5b5be 100644 +--- a/arch/Kconfig ++++ b/arch/Kconfig +@@ -1050,7 +1050,7 @@ config ARCH_MMAP_RND_BITS + int "Number of bits to use for ASLR of mmap base address" if EXPERT + range ARCH_MMAP_RND_BITS_MIN ARCH_MMAP_RND_BITS_MAX + default ARCH_MMAP_RND_BITS_DEFAULT if ARCH_MMAP_RND_BITS_DEFAULT +- default ARCH_MMAP_RND_BITS_MIN ++ default ARCH_MMAP_RND_BITS_MAX + depends on HAVE_ARCH_MMAP_RND_BITS + help + This value can be used to select the number of bits to use to +@@ -1084,7 +1084,7 @@ config ARCH_MMAP_RND_COMPAT_BITS + int "Number of bits to use for ASLR of mmap base address for compatible applications" if EXPERT + range ARCH_MMAP_RND_COMPAT_BITS_MIN 
ARCH_MMAP_RND_COMPAT_BITS_MAX + default ARCH_MMAP_RND_COMPAT_BITS_DEFAULT if ARCH_MMAP_RND_COMPAT_BITS_DEFAULT +- default ARCH_MMAP_RND_COMPAT_BITS_MIN ++ default ARCH_MMAP_RND_COMPAT_BITS_MAX + depends on HAVE_ARCH_MMAP_RND_COMPAT_BITS + help + This value can be used to select the number of bits to use to +diff --git a/drivers/gpu/drm/nouveau/nouveau_bo.c b/drivers/gpu/drm/nouveau/nouveau_bo.c +index 70fb003a6666..0712d0b15170 100644 +--- a/drivers/gpu/drm/nouveau/nouveau_bo.c ++++ b/drivers/gpu/drm/nouveau/nouveau_bo.c +@@ -898,7 +898,7 @@ nouveau_bo_move_m2mf(struct ttm_buffer_object *bo, int evict, + * Without this the operation can timeout and we'll fallback to a + * software copy, which might take several minutes to finish. + */ +- nouveau_fence_wait(fence, false, false); ++ nouveau_fence_wait(fence, false); + ret = ttm_bo_move_accel_cleanup(bo, &fence->base, evict, false, + new_reg); + nouveau_fence_unref(&fence); +diff --git a/drivers/gpu/drm/nouveau/nouveau_chan.c b/drivers/gpu/drm/nouveau/nouveau_chan.c +index 7c97b2886807..66fca95c10c7 100644 +--- a/drivers/gpu/drm/nouveau/nouveau_chan.c ++++ b/drivers/gpu/drm/nouveau/nouveau_chan.c +@@ -72,7 +72,7 @@ nouveau_channel_idle(struct nouveau_channel *chan) + + ret = nouveau_fence_new(&fence, chan); + if (!ret) { +- ret = nouveau_fence_wait(fence, false, false); ++ ret = nouveau_fence_wait(fence, false); + nouveau_fence_unref(&fence); + } + +diff --git a/drivers/gpu/drm/nouveau/nouveau_dmem.c b/drivers/gpu/drm/nouveau/nouveau_dmem.c +index 6fb65b01d778..6719353e2e13 100644 +--- a/drivers/gpu/drm/nouveau/nouveau_dmem.c ++++ b/drivers/gpu/drm/nouveau/nouveau_dmem.c +@@ -128,7 +128,7 @@ static void nouveau_dmem_page_free(struct page *page) + static void nouveau_dmem_fence_done(struct nouveau_fence **fence) + { + if (fence) { +- nouveau_fence_wait(*fence, true, false); ++ nouveau_fence_wait(*fence, false); + nouveau_fence_unref(fence); + } else { + /* +diff --git a/drivers/gpu/drm/nouveau/nouveau_fence.c 
b/drivers/gpu/drm/nouveau/nouveau_fence.c +index 93f08f9479d8..ba469767a20f 100644 +--- a/drivers/gpu/drm/nouveau/nouveau_fence.c ++++ b/drivers/gpu/drm/nouveau/nouveau_fence.c +@@ -311,39 +311,11 @@ nouveau_fence_wait_legacy(struct dma_fence *f, bool intr, long wait) + return timeout - t; + } + +-static int +-nouveau_fence_wait_busy(struct nouveau_fence *fence, bool intr) +-{ +- int ret = 0; +- +- while (!nouveau_fence_done(fence)) { +- if (time_after_eq(jiffies, fence->timeout)) { +- ret = -EBUSY; +- break; +- } +- +- __set_current_state(intr ? +- TASK_INTERRUPTIBLE : +- TASK_UNINTERRUPTIBLE); +- +- if (intr && signal_pending(current)) { +- ret = -ERESTARTSYS; +- break; +- } +- } +- +- __set_current_state(TASK_RUNNING); +- return ret; +-} +- + int +-nouveau_fence_wait(struct nouveau_fence *fence, bool lazy, bool intr) ++nouveau_fence_wait(struct nouveau_fence *fence, bool intr) + { + long ret; + +- if (!lazy) +- return nouveau_fence_wait_busy(fence, intr); +- + ret = dma_fence_wait_timeout(&fence->base, intr, 15 * HZ); + if (ret < 0) + return ret; +diff --git a/drivers/gpu/drm/nouveau/nouveau_fence.h b/drivers/gpu/drm/nouveau/nouveau_fence.h +index 8bc065acfe35..1b63197b744a 100644 +--- a/drivers/gpu/drm/nouveau/nouveau_fence.h ++++ b/drivers/gpu/drm/nouveau/nouveau_fence.h +@@ -23,7 +23,7 @@ void nouveau_fence_unref(struct nouveau_fence **); + + int nouveau_fence_emit(struct nouveau_fence *); + bool nouveau_fence_done(struct nouveau_fence *); +-int nouveau_fence_wait(struct nouveau_fence *, bool lazy, bool intr); ++int nouveau_fence_wait(struct nouveau_fence *, bool intr); + int nouveau_fence_sync(struct nouveau_bo *, struct nouveau_channel *, bool exclusive, bool intr); + + struct nouveau_fence_chan { +diff --git a/drivers/gpu/drm/nouveau/nouveau_gem.c b/drivers/gpu/drm/nouveau/nouveau_gem.c +index 5a887d67dc0e..2e535caa7d6e 100644 +--- a/drivers/gpu/drm/nouveau/nouveau_gem.c ++++ b/drivers/gpu/drm/nouveau/nouveau_gem.c +@@ -928,7 +928,7 @@ 
nouveau_gem_ioctl_pushbuf(struct drm_device *dev, void *data, + } + + if (sync) { +- if (!(ret = nouveau_fence_wait(fence, false, false))) { ++ if (!(ret = nouveau_fence_wait(fence, false))) { + if ((ret = dma_fence_get_status(&fence->base)) == 1) + ret = 0; + } +diff --git a/fs/btrfs/extent_map.c b/fs/btrfs/extent_map.c +index b4c9a6aa118c..6853f043c2c1 100644 +--- a/fs/btrfs/extent_map.c ++++ b/fs/btrfs/extent_map.c +@@ -1065,8 +1065,7 @@ static long btrfs_scan_inode(struct btrfs_inode *inode, struct btrfs_em_shrink_c + return 0; + + /* +- * We want to be fast because we can be called from any path trying to +- * allocate memory, so if the lock is busy we don't want to spend time ++ * We want to be fast so if the lock is busy we don't want to spend time + * waiting for it - either some task is about to do IO for the inode or + * we may have another task shrinking extent maps, here in this code, so + * skip this inode. +@@ -1109,9 +1108,7 @@ static long btrfs_scan_inode(struct btrfs_inode *inode, struct btrfs_em_shrink_c + /* + * Stop if we need to reschedule or there's contention on the + * lock. This is to avoid slowing other tasks trying to take the +- * lock and because the shrinker might be called during a memory +- * allocation path and we want to avoid taking a very long time +- * and slowing down all sorts of tasks. ++ * lock. + */ + if (need_resched() || rwlock_needbreak(&tree->lock)) + break; +@@ -1139,12 +1136,7 @@ static long btrfs_scan_root(struct btrfs_root *root, struct btrfs_em_shrink_ctx + if (ctx->scanned >= ctx->nr_to_scan) + break; + +- /* +- * We may be called from memory allocation paths, so we don't +- * want to take too much time and slowdown tasks. 
+- */ +- if (need_resched()) +- break; ++ cond_resched(); + + inode = btrfs_find_first_inode(root, min_ino); + } +@@ -1202,14 +1194,12 @@ long btrfs_free_extent_maps(struct btrfs_fs_info *fs_info, long nr_to_scan) + ctx.last_ino); + } + +- /* +- * We may be called from memory allocation paths, so we don't want to +- * take too much time and slowdown tasks, so stop if we need reschedule. +- */ +- while (ctx.scanned < ctx.nr_to_scan && !need_resched()) { ++ while (ctx.scanned < ctx.nr_to_scan) { + struct btrfs_root *root; + unsigned long count; + ++ cond_resched(); ++ + spin_lock(&fs_info->fs_roots_radix_lock); + count = radix_tree_gang_lookup(&fs_info->fs_roots_radix, + (void **)&root, +diff --git a/fs/btrfs/super.c b/fs/btrfs/super.c +index f05cce7c8b8d..11faf5e983ea 100644 +--- a/fs/btrfs/super.c ++++ b/fs/btrfs/super.c +@@ -28,6 +28,7 @@ + #include + #include + #include ++#include + #include "messages.h" + #include "delayed-inode.h" + #include "ctree.h" +@@ -2394,6 +2395,15 @@ static long btrfs_free_cached_objects(struct super_block *sb, struct shrink_cont + const long nr_to_scan = min_t(unsigned long, LONG_MAX, sc->nr_to_scan); + struct btrfs_fs_info *fs_info = btrfs_sb(sb); + ++ /* ++ * We may be called from any task trying to allocate memory and we don't ++ * want to slow it down with scanning and dropping extent maps. It would ++ * also cause heavy lock contention if many tasks concurrently enter ++ * here. Therefore only allow kswapd tasks to scan and drop extent maps. 
++ */ ++ if (!current_is_kswapd()) ++ return 0; ++ + return btrfs_free_extent_maps(fs_info, nr_to_scan); + } + +-- +2.46.0 + +From 6b5c02d298d86e3d59dbc0b01df47ffcfa74f79f Mon Sep 17 00:00:00 2001 +From: Peter Jung +Date: Wed, 14 Aug 2024 16:19:34 +0200 +Subject: [PATCH 07/12] intel-pstate + +Signed-off-by: Peter Jung +--- + arch/x86/include/asm/topology.h | 19 ++- + arch/x86/kernel/cpu/aperfmperf.c | 87 +++++++++++- + arch/x86/kernel/itmt.c | 12 +- + arch/x86/kernel/smpboot.c | 10 +- + drivers/cpufreq/intel_pstate.c | 220 ++++++++++++++++++++++++++++++- + 5 files changed, 332 insertions(+), 16 deletions(-) + +diff --git a/arch/x86/include/asm/topology.h b/arch/x86/include/asm/topology.h +index abe3a8f22cbd..e5b203fe7956 100644 +--- a/arch/x86/include/asm/topology.h ++++ b/arch/x86/include/asm/topology.h +@@ -235,8 +235,6 @@ struct pci_bus; + int x86_pci_root_bus_node(int bus); + void x86_pci_root_bus_resources(int bus, struct list_head *resources); + +-extern bool x86_topology_update; +- + #ifdef CONFIG_SCHED_MC_PRIO + #include + +@@ -282,11 +280,28 @@ static inline long arch_scale_freq_capacity(int cpu) + } + #define arch_scale_freq_capacity arch_scale_freq_capacity + ++bool arch_enable_hybrid_capacity_scale(void); ++void arch_set_cpu_capacity(int cpu, unsigned long cap, unsigned long base_cap, ++ unsigned long max_freq, unsigned long base_freq); ++ ++unsigned long arch_scale_cpu_capacity(int cpu); ++#define arch_scale_cpu_capacity arch_scale_cpu_capacity ++ + extern void arch_set_max_freq_ratio(bool turbo_disabled); + extern void freq_invariance_set_perf_ratio(u64 ratio, bool turbo_disabled); ++ ++void arch_rebuild_sched_domains(void); + #else ++static inline bool arch_enable_hybrid_capacity_scale(void) { return false; } ++static inline void arch_set_cpu_capacity(int cpu, unsigned long cap, ++ unsigned long base_cap, ++ unsigned long max_freq, ++ unsigned long base_freq) { } ++ + static inline void arch_set_max_freq_ratio(bool turbo_disabled) { } + static inline 
void freq_invariance_set_perf_ratio(u64 ratio, bool turbo_disabled) { } ++ ++static inline void arch_rebuild_sched_domains(void) { } + #endif + + extern void arch_scale_freq_tick(void); +diff --git a/arch/x86/kernel/cpu/aperfmperf.c b/arch/x86/kernel/cpu/aperfmperf.c +index b3fa61d45352..6ff86c02fe63 100644 +--- a/arch/x86/kernel/cpu/aperfmperf.c ++++ b/arch/x86/kernel/cpu/aperfmperf.c +@@ -347,9 +347,89 @@ static DECLARE_WORK(disable_freq_invariance_work, + DEFINE_PER_CPU(unsigned long, arch_freq_scale) = SCHED_CAPACITY_SCALE; + EXPORT_PER_CPU_SYMBOL_GPL(arch_freq_scale); + ++static DEFINE_STATIC_KEY_FALSE(arch_hybrid_cap_scale_key); ++ ++struct arch_hybrid_cpu_scale { ++ unsigned long capacity; ++ unsigned long freq_ratio; ++}; ++ ++static struct arch_hybrid_cpu_scale __percpu *arch_cpu_scale; ++ ++/** ++ * arch_enable_hybrid_capacity_scale - Enable hybrid CPU capacity scaling ++ * ++ * Allocate memory for per-CPU data used by hybrid CPU capacity scaling, ++ * initialize it and set the static key controlling its code paths. ++ * ++ * Must be called before arch_set_cpu_capacity(). ++ */ ++bool arch_enable_hybrid_capacity_scale(void) ++{ ++ int cpu; ++ ++ if (static_branch_unlikely(&arch_hybrid_cap_scale_key)) { ++ WARN_ONCE(1, "Hybrid CPU capacity scaling already enabled"); ++ return true; ++ } ++ arch_cpu_scale = alloc_percpu(struct arch_hybrid_cpu_scale); ++ if (!arch_cpu_scale) ++ return false; ++ ++ for_each_possible_cpu(cpu) { ++ per_cpu_ptr(arch_cpu_scale, cpu)->capacity = SCHED_CAPACITY_SCALE; ++ per_cpu_ptr(arch_cpu_scale, cpu)->freq_ratio = arch_max_freq_ratio; ++ } ++ ++ static_branch_enable(&arch_hybrid_cap_scale_key); ++ ++ pr_info("Hybrid CPU capacity scaling enabled\n"); ++ ++ return true; ++} ++ ++/** ++ * arch_set_cpu_capacity - Set scale-invariance parameters for a CPU ++ * @cpu: Target CPU. ++ * @cap: Capacity of @cpu, relative to @base_cap, at its maximum frequency. ++ * @base_cap: System-wide maximum CPU capacity. 
++ * @max_freq: Frequency of @cpu corresponding to @cap. ++ * @base_freq: Frequency of @cpu at which MPERF counts. ++ * ++ * The units in which @cap and @base_cap are expressed do not matter, so long ++ * as they are consistent, because the former is effectively divided by the ++ * latter. Analogously for @max_freq and @base_freq. ++ * ++ * After calling this function for all CPUs, call arch_rebuild_sched_domains() ++ * to let the scheduler know that capacity-aware scheduling can be used going ++ * forward. ++ */ ++void arch_set_cpu_capacity(int cpu, unsigned long cap, unsigned long base_cap, ++ unsigned long max_freq, unsigned long base_freq) ++{ ++ if (static_branch_likely(&arch_hybrid_cap_scale_key)) { ++ WRITE_ONCE(per_cpu_ptr(arch_cpu_scale, cpu)->capacity, ++ div_u64(cap << SCHED_CAPACITY_SHIFT, base_cap)); ++ WRITE_ONCE(per_cpu_ptr(arch_cpu_scale, cpu)->freq_ratio, ++ div_u64(max_freq << SCHED_CAPACITY_SHIFT, base_freq)); ++ } else { ++ WARN_ONCE(1, "Hybrid CPU capacity scaling not enabled"); ++ } ++} ++ ++unsigned long arch_scale_cpu_capacity(int cpu) ++{ ++ if (static_branch_unlikely(&arch_hybrid_cap_scale_key)) ++ return READ_ONCE(per_cpu_ptr(arch_cpu_scale, cpu)->capacity); ++ ++ return SCHED_CAPACITY_SCALE; ++} ++EXPORT_SYMBOL_GPL(arch_scale_cpu_capacity); ++ + static void scale_freq_tick(u64 acnt, u64 mcnt) + { + u64 freq_scale; ++ u64 freq_ratio; + + if (!arch_scale_freq_invariant()) + return; +@@ -357,7 +437,12 @@ static void scale_freq_tick(u64 acnt, u64 mcnt) + if (check_shl_overflow(acnt, 2*SCHED_CAPACITY_SHIFT, &acnt)) + goto error; + +- if (check_mul_overflow(mcnt, arch_max_freq_ratio, &mcnt) || !mcnt) ++ if (static_branch_unlikely(&arch_hybrid_cap_scale_key)) ++ freq_ratio = READ_ONCE(this_cpu_ptr(arch_cpu_scale)->freq_ratio); ++ else ++ freq_ratio = arch_max_freq_ratio; ++ ++ if (check_mul_overflow(mcnt, freq_ratio, &mcnt) || !mcnt) + goto error; + + freq_scale = div64_u64(acnt, mcnt); +diff --git a/arch/x86/kernel/itmt.c 
b/arch/x86/kernel/itmt.c +index 9a7c03d47861..af2f60c094a8 100644 +--- a/arch/x86/kernel/itmt.c ++++ b/arch/x86/kernel/itmt.c +@@ -54,10 +54,8 @@ static int sched_itmt_update_handler(struct ctl_table *table, int write, + old_sysctl = sysctl_sched_itmt_enabled; + ret = proc_dointvec_minmax(table, write, buffer, lenp, ppos); + +- if (!ret && write && old_sysctl != sysctl_sched_itmt_enabled) { +- x86_topology_update = true; +- rebuild_sched_domains(); +- } ++ if (!ret && write && old_sysctl != sysctl_sched_itmt_enabled) ++ arch_rebuild_sched_domains(); + + mutex_unlock(&itmt_update_mutex); + +@@ -114,8 +112,7 @@ int sched_set_itmt_support(void) + + sysctl_sched_itmt_enabled = 1; + +- x86_topology_update = true; +- rebuild_sched_domains(); ++ arch_rebuild_sched_domains(); + + mutex_unlock(&itmt_update_mutex); + +@@ -150,8 +147,7 @@ void sched_clear_itmt_support(void) + if (sysctl_sched_itmt_enabled) { + /* disable sched_itmt if we are no longer ITMT capable */ + sysctl_sched_itmt_enabled = 0; +- x86_topology_update = true; +- rebuild_sched_domains(); ++ arch_rebuild_sched_domains(); + } + + mutex_unlock(&itmt_update_mutex); +diff --git a/arch/x86/kernel/smpboot.c b/arch/x86/kernel/smpboot.c +index 0c35207320cb..90a6fb54a128 100644 +--- a/arch/x86/kernel/smpboot.c ++++ b/arch/x86/kernel/smpboot.c +@@ -39,6 +39,7 @@ + + #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + ++#include + #include + #include + #include +@@ -125,7 +126,7 @@ static DEFINE_PER_CPU_ALIGNED(struct mwait_cpu_dead, mwait_cpu_dead); + int __read_mostly __max_smt_threads = 1; + + /* Flag to indicate if a complete sched domain rebuild is required */ +-bool x86_topology_update; ++static bool x86_topology_update; + + int arch_update_cpu_topology(void) + { +@@ -135,6 +136,13 @@ int arch_update_cpu_topology(void) + return retval; + } + ++#ifdef CONFIG_X86_64 ++void arch_rebuild_sched_domains(void) { ++ x86_topology_update = true; ++ rebuild_sched_domains(); ++} ++#endif ++ + static unsigned int 
smpboot_warm_reset_vector_count; + + static inline void smpboot_setup_warm_reset_vector(unsigned long start_eip) +diff --git a/drivers/cpufreq/intel_pstate.c b/drivers/cpufreq/intel_pstate.c +index 1035c074f36a..2ab495e1ba36 100644 +--- a/drivers/cpufreq/intel_pstate.c ++++ b/drivers/cpufreq/intel_pstate.c +@@ -16,6 +16,7 @@ + #include + #include + #include ++#include + #include + #include + #include +@@ -215,6 +216,7 @@ struct global_params { + * @hwp_req_cached: Cached value of the last HWP Request MSR + * @hwp_cap_cached: Cached value of the last HWP Capabilities MSR + * @last_io_update: Last time when IO wake flag was set ++ * @capacity_perf: Highest perf used for scale invariance + * @sched_flags: Store scheduler flags for possible cross CPU update + * @hwp_boost_min: Last HWP boosted min performance + * @suspended: Whether or not the driver has been suspended. +@@ -253,6 +255,7 @@ struct cpudata { + u64 hwp_req_cached; + u64 hwp_cap_cached; + u64 last_io_update; ++ unsigned int capacity_perf; + unsigned int sched_flags; + u32 hwp_boost_min; + bool suspended; +@@ -295,6 +298,7 @@ static int hwp_mode_bdw __ro_after_init; + static bool per_cpu_limits __ro_after_init; + static bool hwp_forced __ro_after_init; + static bool hwp_boost __read_mostly; ++static bool hwp_is_hybrid; + + static struct cpufreq_driver *intel_pstate_driver __read_mostly; + +@@ -933,6 +937,111 @@ static struct freq_attr *hwp_cpufreq_attrs[] = { + NULL, + }; + ++static struct cpudata *hybrid_max_perf_cpu __read_mostly; ++/* ++ * Protects hybrid_max_perf_cpu, the capacity_perf fields in struct cpudata, ++ * and the x86 arch scale-invariance information from concurrent updates. ++ */ ++static DEFINE_MUTEX(hybrid_capacity_lock); ++ ++static void hybrid_set_cpu_capacity(struct cpudata *cpu) ++{ ++ arch_set_cpu_capacity(cpu->cpu, cpu->capacity_perf, ++ hybrid_max_perf_cpu->capacity_perf, ++ cpu->capacity_perf, ++ cpu->pstate.max_pstate_physical); ++ ++ pr_debug("CPU%d: perf = %u, max. 
perf = %u, base perf = %d\n", cpu->cpu, ++ cpu->capacity_perf, hybrid_max_perf_cpu->capacity_perf, ++ cpu->pstate.max_pstate_physical); ++} ++ ++static void hybrid_clear_cpu_capacity(unsigned int cpunum) ++{ ++ arch_set_cpu_capacity(cpunum, 1, 1, 1, 1); ++} ++ ++static void hybrid_get_capacity_perf(struct cpudata *cpu) ++{ ++ if (READ_ONCE(global.no_turbo)) { ++ cpu->capacity_perf = cpu->pstate.max_pstate_physical; ++ return; ++ } ++ ++ cpu->capacity_perf = HWP_HIGHEST_PERF(READ_ONCE(cpu->hwp_cap_cached)); ++} ++ ++static void hybrid_set_capacity_of_cpus(void) ++{ ++ int cpunum; ++ ++ for_each_online_cpu(cpunum) { ++ struct cpudata *cpu = all_cpu_data[cpunum]; ++ ++ if (cpu) ++ hybrid_set_cpu_capacity(cpu); ++ } ++} ++ ++static void hybrid_update_cpu_scaling(void) ++{ ++ struct cpudata *max_perf_cpu = NULL; ++ unsigned int max_cap_perf = 0; ++ int cpunum; ++ ++ for_each_online_cpu(cpunum) { ++ struct cpudata *cpu = all_cpu_data[cpunum]; ++ ++ if (!cpu) ++ continue; ++ ++ /* ++ * During initialization, CPU performance at full capacity needs ++ * to be determined. ++ */ ++ if (!hybrid_max_perf_cpu) ++ hybrid_get_capacity_perf(cpu); ++ ++ /* ++ * If hybrid_max_perf_cpu is not NULL at this point, it is ++ * being replaced, so don't take it into account when looking ++ * for the new one. ++ */ ++ if (cpu == hybrid_max_perf_cpu) ++ continue; ++ ++ if (cpu->capacity_perf > max_cap_perf) { ++ max_cap_perf = cpu->capacity_perf; ++ max_perf_cpu = cpu; ++ } ++ } ++ ++ if (max_perf_cpu) { ++ hybrid_max_perf_cpu = max_perf_cpu; ++ hybrid_set_capacity_of_cpus(); ++ } else { ++ pr_info("Found no CPUs with nonzero maximum performance\n"); ++ /* Revert to the flat CPU capacity structure. 
*/ ++ for_each_online_cpu(cpunum) ++ hybrid_clear_cpu_capacity(cpunum); ++ } ++} ++ ++static void __hybrid_init_cpu_scaling(void) ++{ ++ hybrid_max_perf_cpu = NULL; ++ hybrid_update_cpu_scaling(); ++} ++ ++static void hybrid_init_cpu_scaling(void) ++{ ++ mutex_lock(&hybrid_capacity_lock); ++ ++ __hybrid_init_cpu_scaling(); ++ ++ mutex_unlock(&hybrid_capacity_lock); ++} ++ + static void __intel_pstate_get_hwp_cap(struct cpudata *cpu) + { + u64 cap; +@@ -961,6 +1070,43 @@ static void intel_pstate_get_hwp_cap(struct cpudata *cpu) + } + } + ++static void hybrid_update_capacity(struct cpudata *cpu) ++{ ++ unsigned int max_cap_perf; ++ ++ mutex_lock(&hybrid_capacity_lock); ++ ++ if (!hybrid_max_perf_cpu) ++ goto unlock; ++ ++ /* ++ * The maximum performance of the CPU may have changed, but assume ++ * that the performance of the other CPUs has not changed. ++ */ ++ max_cap_perf = hybrid_max_perf_cpu->capacity_perf; ++ ++ intel_pstate_get_hwp_cap(cpu); ++ ++ hybrid_get_capacity_perf(cpu); ++ /* Should hybrid_max_perf_cpu be replaced by this CPU? */ ++ if (cpu->capacity_perf > max_cap_perf) { ++ hybrid_max_perf_cpu = cpu; ++ hybrid_set_capacity_of_cpus(); ++ goto unlock; ++ } ++ ++ /* If this CPU is hybrid_max_perf_cpu, should it be replaced? 
*/ ++ if (cpu == hybrid_max_perf_cpu && cpu->capacity_perf < max_cap_perf) { ++ hybrid_update_cpu_scaling(); ++ goto unlock; ++ } ++ ++ hybrid_set_cpu_capacity(cpu); ++ ++unlock: ++ mutex_unlock(&hybrid_capacity_lock); ++} ++ + static void intel_pstate_hwp_set(unsigned int cpu) + { + struct cpudata *cpu_data = all_cpu_data[cpu]; +@@ -1069,6 +1215,22 @@ static void intel_pstate_hwp_offline(struct cpudata *cpu) + value |= HWP_ENERGY_PERF_PREFERENCE(HWP_EPP_POWERSAVE); + + wrmsrl_on_cpu(cpu->cpu, MSR_HWP_REQUEST, value); ++ ++ mutex_lock(&hybrid_capacity_lock); ++ ++ if (!hybrid_max_perf_cpu) { ++ mutex_unlock(&hybrid_capacity_lock); ++ ++ return; ++ } ++ ++ if (hybrid_max_perf_cpu == cpu) ++ hybrid_update_cpu_scaling(); ++ ++ mutex_unlock(&hybrid_capacity_lock); ++ ++ /* Reset the capacity of the CPU going offline to the initial value. */ ++ hybrid_clear_cpu_capacity(cpu->cpu); + } + + #define POWER_CTL_EE_ENABLE 1 +@@ -1164,21 +1326,46 @@ static void __intel_pstate_update_max_freq(struct cpudata *cpudata, + static void intel_pstate_update_limits(unsigned int cpu) + { + struct cpufreq_policy *policy = cpufreq_cpu_acquire(cpu); ++ struct cpudata *cpudata; + + if (!policy) + return; + +- __intel_pstate_update_max_freq(all_cpu_data[cpu], policy); ++ cpudata = all_cpu_data[cpu]; ++ ++ __intel_pstate_update_max_freq(cpudata, policy); ++ ++ /* Prevent the driver from being unregistered now. 
*/ ++ mutex_lock(&intel_pstate_driver_lock); + + cpufreq_cpu_release(policy); ++ ++ hybrid_update_capacity(cpudata); ++ ++ mutex_unlock(&intel_pstate_driver_lock); + } + + static void intel_pstate_update_limits_for_all(void) + { + int cpu; + +- for_each_possible_cpu(cpu) +- intel_pstate_update_limits(cpu); ++ for_each_possible_cpu(cpu) { ++ struct cpufreq_policy *policy = cpufreq_cpu_acquire(cpu); ++ ++ if (!policy) ++ continue; ++ ++ __intel_pstate_update_max_freq(all_cpu_data[cpu], policy); ++ ++ cpufreq_cpu_release(policy); ++ } ++ ++ mutex_lock(&hybrid_capacity_lock); ++ ++ if (hybrid_max_perf_cpu) ++ __hybrid_init_cpu_scaling(); ++ ++ mutex_unlock(&hybrid_capacity_lock); + } + + /************************** sysfs begin ************************/ +@@ -1617,6 +1804,13 @@ static void intel_pstate_notify_work(struct work_struct *work) + __intel_pstate_update_max_freq(cpudata, policy); + + cpufreq_cpu_release(policy); ++ ++ /* ++ * The driver will not be unregistered while this function is ++ * running, so update the capacity without acquiring the driver ++ * lock. 
++ */ ++ hybrid_update_capacity(cpudata); + } + + wrmsrl_on_cpu(cpudata->cpu, MSR_HWP_STATUS, 0); +@@ -2018,8 +2212,10 @@ static void intel_pstate_get_cpu_pstates(struct cpudata *cpu) + + if (pstate_funcs.get_cpu_scaling) { + cpu->pstate.scaling = pstate_funcs.get_cpu_scaling(cpu->cpu); +- if (cpu->pstate.scaling != perf_ctl_scaling) ++ if (cpu->pstate.scaling != perf_ctl_scaling) { + intel_pstate_hybrid_hwp_adjust(cpu); ++ hwp_is_hybrid = true; ++ } + } else { + cpu->pstate.scaling = perf_ctl_scaling; + } +@@ -2687,6 +2883,8 @@ static int intel_pstate_cpu_online(struct cpufreq_policy *policy) + */ + intel_pstate_hwp_reenable(cpu); + cpu->suspended = false; ++ ++ hybrid_update_capacity(cpu); + } + + return 0; +@@ -3129,6 +3327,20 @@ static int intel_pstate_register_driver(struct cpufreq_driver *driver) + + global.min_perf_pct = min_perf_pct_min(); + ++ /* ++ * On hybrid systems, use asym capacity instead of ITMT, but because ++ * the capacity of SMT threads is not deterministic even approximately, ++ * do not do that when SMT is in use. 
++ */ ++ if (hwp_is_hybrid && !sched_smt_active() && ++ arch_enable_hybrid_capacity_scale()) { ++ sched_clear_itmt_support(); ++ ++ hybrid_init_cpu_scaling(); ++ ++ arch_rebuild_sched_domains(); ++ } ++ + return 0; + } + +-- +2.46.0 + +From b5c1d03092fedd1e5446aec3b8d16029834ef4b7 Mon Sep 17 00:00:00 2001 +From: Peter Jung +Date: Wed, 14 Aug 2024 16:19:43 +0200 +Subject: [PATCH 08/12] ksm + +Signed-off-by: Peter Jung +--- + arch/alpha/kernel/syscalls/syscall.tbl | 3 + + arch/arm/tools/syscall.tbl | 3 + + arch/arm64/include/asm/unistd.h | 2 +- + arch/arm64/include/asm/unistd32.h | 6 + + arch/m68k/kernel/syscalls/syscall.tbl | 3 + + arch/microblaze/kernel/syscalls/syscall.tbl | 3 + + arch/mips/kernel/syscalls/syscall_n32.tbl | 3 + + arch/mips/kernel/syscalls/syscall_n64.tbl | 3 + + arch/mips/kernel/syscalls/syscall_o32.tbl | 3 + + arch/parisc/kernel/syscalls/syscall.tbl | 3 + + arch/powerpc/kernel/syscalls/syscall.tbl | 3 + + arch/s390/kernel/syscalls/syscall.tbl | 3 + + arch/sh/kernel/syscalls/syscall.tbl | 3 + + arch/sparc/kernel/syscalls/syscall.tbl | 3 + + arch/x86/entry/syscalls/syscall_32.tbl | 3 + + arch/x86/entry/syscalls/syscall_64.tbl | 3 + + arch/xtensa/kernel/syscalls/syscall.tbl | 3 + + include/linux/syscalls.h | 3 + + include/uapi/asm-generic/unistd.h | 11 +- + kernel/sys.c | 147 ++++++++++++++++++++ + kernel/sys_ni.c | 3 + + 21 files changed, 215 insertions(+), 2 deletions(-) + +diff --git a/arch/alpha/kernel/syscalls/syscall.tbl b/arch/alpha/kernel/syscalls/syscall.tbl +index 74720667fe09..e6a11f3c0a2e 100644 +--- a/arch/alpha/kernel/syscalls/syscall.tbl ++++ b/arch/alpha/kernel/syscalls/syscall.tbl +@@ -502,3 +502,6 @@ + 570 common lsm_set_self_attr sys_lsm_set_self_attr + 571 common lsm_list_modules sys_lsm_list_modules + 572 common mseal sys_mseal ++573 common process_ksm_enable sys_process_ksm_enable ++574 common process_ksm_disable sys_process_ksm_disable ++575 common process_ksm_status sys_process_ksm_status +diff --git 
a/arch/arm/tools/syscall.tbl b/arch/arm/tools/syscall.tbl +index 2ed7d229c8f9..3f59e9c5c1ff 100644 +--- a/arch/arm/tools/syscall.tbl ++++ b/arch/arm/tools/syscall.tbl +@@ -476,3 +476,6 @@ + 460 common lsm_set_self_attr sys_lsm_set_self_attr + 461 common lsm_list_modules sys_lsm_list_modules + 462 common mseal sys_mseal ++463 common process_ksm_enable sys_process_ksm_enable ++464 common process_ksm_disable sys_process_ksm_disable ++465 common process_ksm_status sys_process_ksm_status +diff --git a/arch/arm64/include/asm/unistd.h b/arch/arm64/include/asm/unistd.h +index 1346579f802f..f3a77719eb05 100644 +--- a/arch/arm64/include/asm/unistd.h ++++ b/arch/arm64/include/asm/unistd.h +@@ -39,7 +39,7 @@ + #define __ARM_NR_compat_set_tls (__ARM_NR_COMPAT_BASE + 5) + #define __ARM_NR_COMPAT_END (__ARM_NR_COMPAT_BASE + 0x800) + +-#define __NR_compat_syscalls 463 ++#define __NR_compat_syscalls 466 + #endif + + #define __ARCH_WANT_SYS_CLONE +diff --git a/arch/arm64/include/asm/unistd32.h b/arch/arm64/include/asm/unistd32.h +index 1386e8e751f2..ccdc523fa4bd 100644 +--- a/arch/arm64/include/asm/unistd32.h ++++ b/arch/arm64/include/asm/unistd32.h +@@ -931,6 +931,12 @@ __SYSCALL(__NR_lsm_set_self_attr, sys_lsm_set_self_attr) + __SYSCALL(__NR_lsm_list_modules, sys_lsm_list_modules) + #define __NR_mseal 462 + __SYSCALL(__NR_mseal, sys_mseal) ++#define __NR_process_ksm_enable 463 ++__SYSCALL(__NR_process_ksm_enable, sys_process_ksm_enable) ++#define __NR_process_ksm_disable 464 ++__SYSCALL(__NR_process_ksm_disable, sys_process_ksm_disable) ++#define __NR_process_ksm_status 465 ++__SYSCALL(__NR_process_ksm_status, sys_process_ksm_status) + + /* + * Please add new compat syscalls above this comment and update +diff --git a/arch/m68k/kernel/syscalls/syscall.tbl b/arch/m68k/kernel/syscalls/syscall.tbl +index 22a3cbd4c602..12d2c7594bf0 100644 +--- a/arch/m68k/kernel/syscalls/syscall.tbl ++++ b/arch/m68k/kernel/syscalls/syscall.tbl +@@ -462,3 +462,6 @@ + 460 common lsm_set_self_attr 
sys_lsm_set_self_attr + 461 common lsm_list_modules sys_lsm_list_modules + 462 common mseal sys_mseal ++463 common process_ksm_enable sys_process_ksm_enable ++464 common process_ksm_disable sys_process_ksm_disable ++465 common process_ksm_status sys_process_ksm_status +diff --git a/arch/microblaze/kernel/syscalls/syscall.tbl b/arch/microblaze/kernel/syscalls/syscall.tbl +index 2b81a6bd78b2..e2a93c856eed 100644 +--- a/arch/microblaze/kernel/syscalls/syscall.tbl ++++ b/arch/microblaze/kernel/syscalls/syscall.tbl +@@ -468,3 +468,6 @@ + 460 common lsm_set_self_attr sys_lsm_set_self_attr + 461 common lsm_list_modules sys_lsm_list_modules + 462 common mseal sys_mseal ++463 common process_ksm_enable sys_process_ksm_enable ++464 common process_ksm_disable sys_process_ksm_disable ++465 common process_ksm_status sys_process_ksm_status +diff --git a/arch/mips/kernel/syscalls/syscall_n32.tbl b/arch/mips/kernel/syscalls/syscall_n32.tbl +index 953f5b7dc723..b921fbf56fa6 100644 +--- a/arch/mips/kernel/syscalls/syscall_n32.tbl ++++ b/arch/mips/kernel/syscalls/syscall_n32.tbl +@@ -401,3 +401,6 @@ + 460 n32 lsm_set_self_attr sys_lsm_set_self_attr + 461 n32 lsm_list_modules sys_lsm_list_modules + 462 n32 mseal sys_mseal ++463 n32 process_ksm_enable sys_process_ksm_enable ++464 n32 process_ksm_disable sys_process_ksm_disable ++465 n32 process_ksm_status sys_process_ksm_status +diff --git a/arch/mips/kernel/syscalls/syscall_n64.tbl b/arch/mips/kernel/syscalls/syscall_n64.tbl +index 1464c6be6eb3..8d7f9ddd66f4 100644 +--- a/arch/mips/kernel/syscalls/syscall_n64.tbl ++++ b/arch/mips/kernel/syscalls/syscall_n64.tbl +@@ -377,3 +377,6 @@ + 460 n64 lsm_set_self_attr sys_lsm_set_self_attr + 461 n64 lsm_list_modules sys_lsm_list_modules + 462 n64 mseal sys_mseal ++463 n64 process_ksm_enable sys_process_ksm_enable ++464 n64 process_ksm_disable sys_process_ksm_disable ++465 n64 process_ksm_status sys_process_ksm_status +diff --git a/arch/mips/kernel/syscalls/syscall_o32.tbl 
b/arch/mips/kernel/syscalls/syscall_o32.tbl +index 2439a2491cff..9d6142739954 100644 +--- a/arch/mips/kernel/syscalls/syscall_o32.tbl ++++ b/arch/mips/kernel/syscalls/syscall_o32.tbl +@@ -450,3 +450,6 @@ + 460 o32 lsm_set_self_attr sys_lsm_set_self_attr + 461 o32 lsm_list_modules sys_lsm_list_modules + 462 o32 mseal sys_mseal ++463 o32 process_ksm_enable sys_process_ksm_enable ++464 o32 process_ksm_disable sys_process_ksm_disable ++465 o32 process_ksm_status sys_process_ksm_status +diff --git a/arch/parisc/kernel/syscalls/syscall.tbl b/arch/parisc/kernel/syscalls/syscall.tbl +index 66dc406b12e4..9d46476fd908 100644 +--- a/arch/parisc/kernel/syscalls/syscall.tbl ++++ b/arch/parisc/kernel/syscalls/syscall.tbl +@@ -461,3 +461,6 @@ + 460 common lsm_set_self_attr sys_lsm_set_self_attr + 461 common lsm_list_modules sys_lsm_list_modules + 462 common mseal sys_mseal ++463 common process_ksm_enable sys_process_ksm_enable ++464 common process_ksm_disable sys_process_ksm_disable ++465 common process_ksm_status sys_process_ksm_status +diff --git a/arch/powerpc/kernel/syscalls/syscall.tbl b/arch/powerpc/kernel/syscalls/syscall.tbl +index ebae8415dfbb..16f71bc2f6f0 100644 +--- a/arch/powerpc/kernel/syscalls/syscall.tbl ++++ b/arch/powerpc/kernel/syscalls/syscall.tbl +@@ -553,3 +553,6 @@ + 460 common lsm_set_self_attr sys_lsm_set_self_attr + 461 common lsm_list_modules sys_lsm_list_modules + 462 common mseal sys_mseal ++463 common process_ksm_enable sys_process_ksm_enable ++464 common process_ksm_disable sys_process_ksm_disable ++465 common process_ksm_status sys_process_ksm_status +diff --git a/arch/s390/kernel/syscalls/syscall.tbl b/arch/s390/kernel/syscalls/syscall.tbl +index 01071182763e..7394bad8178e 100644 +--- a/arch/s390/kernel/syscalls/syscall.tbl ++++ b/arch/s390/kernel/syscalls/syscall.tbl +@@ -465,3 +465,6 @@ + 460 common lsm_set_self_attr sys_lsm_set_self_attr sys_lsm_set_self_attr + 461 common lsm_list_modules sys_lsm_list_modules sys_lsm_list_modules + 462 common 
mseal sys_mseal sys_mseal ++463 common process_ksm_enable sys_process_ksm_enable sys_process_ksm_enable ++464 common process_ksm_disable sys_process_ksm_disable sys_process_ksm_disable ++465 common process_ksm_status sys_process_ksm_status sys_process_ksm_status +diff --git a/arch/sh/kernel/syscalls/syscall.tbl b/arch/sh/kernel/syscalls/syscall.tbl +index c55fd7696d40..b9fc31221b87 100644 +--- a/arch/sh/kernel/syscalls/syscall.tbl ++++ b/arch/sh/kernel/syscalls/syscall.tbl +@@ -466,3 +466,6 @@ + 460 common lsm_set_self_attr sys_lsm_set_self_attr + 461 common lsm_list_modules sys_lsm_list_modules + 462 common mseal sys_mseal ++463 common process_ksm_enable sys_process_ksm_enable ++464 common process_ksm_disable sys_process_ksm_disable ++465 common process_ksm_status sys_process_ksm_status +diff --git a/arch/sparc/kernel/syscalls/syscall.tbl b/arch/sparc/kernel/syscalls/syscall.tbl +index cfdfb3707c16..0d79fd772854 100644 +--- a/arch/sparc/kernel/syscalls/syscall.tbl ++++ b/arch/sparc/kernel/syscalls/syscall.tbl +@@ -508,3 +508,6 @@ + 460 common lsm_set_self_attr sys_lsm_set_self_attr + 461 common lsm_list_modules sys_lsm_list_modules + 462 common mseal sys_mseal ++463 common process_ksm_enable sys_process_ksm_enable ++464 common process_ksm_disable sys_process_ksm_disable ++465 common process_ksm_status sys_process_ksm_status +diff --git a/arch/x86/entry/syscalls/syscall_32.tbl b/arch/x86/entry/syscalls/syscall_32.tbl +index 4b71a2607bf5..812085b21a3f 100644 +--- a/arch/x86/entry/syscalls/syscall_32.tbl ++++ b/arch/x86/entry/syscalls/syscall_32.tbl +@@ -467,3 +467,6 @@ + 460 i386 lsm_set_self_attr sys_lsm_set_self_attr + 461 i386 lsm_list_modules sys_lsm_list_modules + 462 i386 mseal sys_mseal ++463 i386 process_ksm_enable sys_process_ksm_enable ++464 i386 process_ksm_disable sys_process_ksm_disable ++465 i386 process_ksm_status sys_process_ksm_status +diff --git a/arch/x86/entry/syscalls/syscall_64.tbl b/arch/x86/entry/syscalls/syscall_64.tbl +index 
a8068f937290..bc2e635389cc 100644 +--- a/arch/x86/entry/syscalls/syscall_64.tbl ++++ b/arch/x86/entry/syscalls/syscall_64.tbl +@@ -384,6 +384,9 @@ + 460 common lsm_set_self_attr sys_lsm_set_self_attr + 461 common lsm_list_modules sys_lsm_list_modules + 462 common mseal sys_mseal ++463 common process_ksm_enable sys_process_ksm_enable ++464 common process_ksm_disable sys_process_ksm_disable ++465 common process_ksm_status sys_process_ksm_status + + # + # Due to a historical design error, certain syscalls are numbered differently +diff --git a/arch/xtensa/kernel/syscalls/syscall.tbl b/arch/xtensa/kernel/syscalls/syscall.tbl +index 67083fc1b2f5..c1aecee4ad9b 100644 +--- a/arch/xtensa/kernel/syscalls/syscall.tbl ++++ b/arch/xtensa/kernel/syscalls/syscall.tbl +@@ -433,3 +433,6 @@ + 460 common lsm_set_self_attr sys_lsm_set_self_attr + 461 common lsm_list_modules sys_lsm_list_modules + 462 common mseal sys_mseal ++463 common process_ksm_enable sys_process_ksm_enable ++464 common process_ksm_disable sys_process_ksm_disable ++465 common process_ksm_status sys_process_ksm_status +diff --git a/include/linux/syscalls.h b/include/linux/syscalls.h +index fff820c3e93e..ab7d77ddc112 100644 +--- a/include/linux/syscalls.h ++++ b/include/linux/syscalls.h +@@ -818,6 +818,9 @@ asmlinkage long sys_madvise(unsigned long start, size_t len, int behavior); + asmlinkage long sys_process_madvise(int pidfd, const struct iovec __user *vec, + size_t vlen, int behavior, unsigned int flags); + asmlinkage long sys_process_mrelease(int pidfd, unsigned int flags); ++asmlinkage long sys_process_ksm_enable(int pidfd, unsigned int flags); ++asmlinkage long sys_process_ksm_disable(int pidfd, unsigned int flags); ++asmlinkage long sys_process_ksm_status(int pidfd, unsigned int flags); + asmlinkage long sys_remap_file_pages(unsigned long start, unsigned long size, + unsigned long prot, unsigned long pgoff, + unsigned long flags); +diff --git a/include/uapi/asm-generic/unistd.h 
b/include/uapi/asm-generic/unistd.h +index d4cc26932ff4..d191548f6326 100644 +--- a/include/uapi/asm-generic/unistd.h ++++ b/include/uapi/asm-generic/unistd.h +@@ -845,8 +845,17 @@ __SYSCALL(__NR_lsm_list_modules, sys_lsm_list_modules) + #define __NR_mseal 462 + __SYSCALL(__NR_mseal, sys_mseal) + ++#define __NR_process_ksm_enable 463 ++__SYSCALL(__NR_process_ksm_enable, sys_process_ksm_enable) ++ ++#define __NR_process_ksm_disable 464 ++__SYSCALL(__NR_process_ksm_disable, sys_process_ksm_disable) ++ ++#define __NR_process_ksm_status 465 ++__SYSCALL(__NR_process_ksm_status, sys_process_ksm_status) ++ + #undef __NR_syscalls +-#define __NR_syscalls 463 ++#define __NR_syscalls 466 + + /* + * 32 bit systems traditionally used different +diff --git a/kernel/sys.c b/kernel/sys.c +index 3a2df1bd9f64..86c6dd9d8c84 100644 +--- a/kernel/sys.c ++++ b/kernel/sys.c +@@ -2789,6 +2789,153 @@ SYSCALL_DEFINE5(prctl, int, option, unsigned long, arg2, unsigned long, arg3, + return error; + } + ++#ifdef CONFIG_KSM ++enum pkc_action { ++ PKSM_ENABLE = 0, ++ PKSM_DISABLE, ++ PKSM_STATUS, ++}; ++ ++static long do_process_ksm_control(int pidfd, enum pkc_action action) ++{ ++ long ret; ++ struct pid *pid; ++ struct task_struct *task; ++ struct mm_struct *mm; ++ unsigned int f_flags; ++ ++ pid = pidfd_get_pid(pidfd, &f_flags); ++ if (IS_ERR(pid)) { ++ ret = PTR_ERR(pid); ++ goto out; ++ } ++ ++ task = get_pid_task(pid, PIDTYPE_PID); ++ if (!task) { ++ ret = -ESRCH; ++ goto put_pid; ++ } ++ ++ /* Require PTRACE_MODE_READ to avoid leaking ASLR metadata. */ ++ mm = mm_access(task, PTRACE_MODE_READ_FSCREDS); ++ if (IS_ERR_OR_NULL(mm)) { ++ ret = IS_ERR(mm) ? PTR_ERR(mm) : -ESRCH; ++ goto release_task; ++ } ++ ++ /* Require CAP_SYS_NICE for influencing process performance. 
*/ ++ if (!capable(CAP_SYS_NICE)) { ++ ret = -EPERM; ++ goto release_mm; ++ } ++ ++ if (mmap_write_lock_killable(mm)) { ++ ret = -EINTR; ++ goto release_mm; ++ } ++ ++ switch (action) { ++ case PKSM_ENABLE: ++ ret = ksm_enable_merge_any(mm); ++ break; ++ case PKSM_DISABLE: ++ ret = ksm_disable_merge_any(mm); ++ break; ++ case PKSM_STATUS: ++ ret = !!test_bit(MMF_VM_MERGE_ANY, &mm->flags); ++ break; ++ } ++ ++ mmap_write_unlock(mm); ++ ++release_mm: ++ mmput(mm); ++release_task: ++ put_task_struct(task); ++put_pid: ++ put_pid(pid); ++out: ++ return ret; ++} ++#endif /* CONFIG_KSM */ ++ ++SYSCALL_DEFINE2(process_ksm_enable, int, pidfd, unsigned int, flags) ++{ ++#ifdef CONFIG_KSM ++ if (flags != 0) ++ return -EINVAL; ++ ++ return do_process_ksm_control(pidfd, PKSM_ENABLE); ++#else /* CONFIG_KSM */ ++ return -ENOSYS; ++#endif /* CONFIG_KSM */ ++} ++ ++SYSCALL_DEFINE2(process_ksm_disable, int, pidfd, unsigned int, flags) ++{ ++#ifdef CONFIG_KSM ++ if (flags != 0) ++ return -EINVAL; ++ ++ return do_process_ksm_control(pidfd, PKSM_DISABLE); ++#else /* CONFIG_KSM */ ++ return -ENOSYS; ++#endif /* CONFIG_KSM */ ++} ++ ++SYSCALL_DEFINE2(process_ksm_status, int, pidfd, unsigned int, flags) ++{ ++#ifdef CONFIG_KSM ++ if (flags != 0) ++ return -EINVAL; ++ ++ return do_process_ksm_control(pidfd, PKSM_STATUS); ++#else /* CONFIG_KSM */ ++ return -ENOSYS; ++#endif /* CONFIG_KSM */ ++} ++ ++#ifdef CONFIG_KSM ++static ssize_t process_ksm_enable_show(struct kobject *kobj, ++ struct kobj_attribute *attr, char *buf) ++{ ++ return sprintf(buf, "%u\n", __NR_process_ksm_enable); ++} ++static struct kobj_attribute process_ksm_enable_attr = __ATTR_RO(process_ksm_enable); ++ ++static ssize_t process_ksm_disable_show(struct kobject *kobj, ++ struct kobj_attribute *attr, char *buf) ++{ ++ return sprintf(buf, "%u\n", __NR_process_ksm_disable); ++} ++static struct kobj_attribute process_ksm_disable_attr = __ATTR_RO(process_ksm_disable); ++ ++static ssize_t process_ksm_status_show(struct kobject 
*kobj, ++ struct kobj_attribute *attr, char *buf) ++{ ++ return sprintf(buf, "%u\n", __NR_process_ksm_status); ++} ++static struct kobj_attribute process_ksm_status_attr = __ATTR_RO(process_ksm_status); ++ ++static struct attribute *process_ksm_sysfs_attrs[] = { ++ &process_ksm_enable_attr.attr, ++ &process_ksm_disable_attr.attr, ++ &process_ksm_status_attr.attr, ++ NULL, ++}; ++ ++static const struct attribute_group process_ksm_sysfs_attr_group = { ++ .attrs = process_ksm_sysfs_attrs, ++ .name = "process_ksm", ++}; ++ ++static int __init process_ksm_sysfs_init(void) ++{ ++ return sysfs_create_group(kernel_kobj, &process_ksm_sysfs_attr_group); ++} ++subsys_initcall(process_ksm_sysfs_init); ++#endif /* CONFIG_KSM */ ++ + SYSCALL_DEFINE3(getcpu, unsigned __user *, cpup, unsigned __user *, nodep, + struct getcpu_cache __user *, unused) + { +diff --git a/kernel/sys_ni.c b/kernel/sys_ni.c +index b696b85ac63e..cf7f3d841b1e 100644 +--- a/kernel/sys_ni.c ++++ b/kernel/sys_ni.c +@@ -188,6 +188,9 @@ COND_SYSCALL(mincore); + COND_SYSCALL(madvise); + COND_SYSCALL(process_madvise); + COND_SYSCALL(process_mrelease); ++COND_SYSCALL(process_ksm_enable); ++COND_SYSCALL(process_ksm_disable); ++COND_SYSCALL(process_ksm_status); + COND_SYSCALL(remap_file_pages); + COND_SYSCALL(mbind); + COND_SYSCALL(get_mempolicy); +-- +2.46.0 + +From f32682f3e50005e5e88aabb18aefac39fd4497a0 Mon Sep 17 00:00:00 2001 +From: Peter Jung +Date: Wed, 14 Aug 2024 16:19:55 +0200 +Subject: [PATCH 09/12] ntsync + +Signed-off-by: Peter Jung +--- + Documentation/userspace-api/index.rst | 1 + + Documentation/userspace-api/ntsync.rst | 398 +++++ + MAINTAINERS | 9 + + drivers/misc/Kconfig | 1 - + drivers/misc/ntsync.c | 989 +++++++++++- + include/uapi/linux/ntsync.h | 39 + + tools/testing/selftests/Makefile | 1 + + .../selftests/drivers/ntsync/.gitignore | 1 + + .../testing/selftests/drivers/ntsync/Makefile | 7 + + tools/testing/selftests/drivers/ntsync/config | 1 + + .../testing/selftests/drivers/ntsync/ntsync.c | 
1407 +++++++++++++++++ + 11 files changed, 2850 insertions(+), 4 deletions(-) + create mode 100644 Documentation/userspace-api/ntsync.rst + create mode 100644 tools/testing/selftests/drivers/ntsync/.gitignore + create mode 100644 tools/testing/selftests/drivers/ntsync/Makefile + create mode 100644 tools/testing/selftests/drivers/ntsync/config + create mode 100644 tools/testing/selftests/drivers/ntsync/ntsync.c + +diff --git a/Documentation/userspace-api/index.rst b/Documentation/userspace-api/index.rst +index 8a251d71fa6e..02bea81fb4bf 100644 +--- a/Documentation/userspace-api/index.rst ++++ b/Documentation/userspace-api/index.rst +@@ -64,6 +64,7 @@ Everything else + vduse + futex2 + perf_ring_buffer ++ ntsync + + .. only:: subproject and html + +diff --git a/Documentation/userspace-api/ntsync.rst b/Documentation/userspace-api/ntsync.rst +new file mode 100644 +index 000000000000..767844637a7d +--- /dev/null ++++ b/Documentation/userspace-api/ntsync.rst +@@ -0,0 +1,398 @@ ++=================================== ++NT synchronization primitive driver ++=================================== ++ ++This page documents the user-space API for the ntsync driver. ++ ++ntsync is a support driver for emulation of NT synchronization ++primitives by user-space NT emulators. It exists because implementation ++in user-space, using existing tools, cannot match Windows performance ++while offering accurate semantics. It is implemented entirely in ++software, and does not drive any hardware device. ++ ++This interface is meant as a compatibility tool only, and should not ++be used for general synchronization. Instead use generic, versatile ++interfaces such as futex(2) and poll(2). ++ ++Synchronization primitives ++========================== ++ ++The ntsync driver exposes three types of synchronization primitives: ++semaphores, mutexes, and events. ++ ++A semaphore holds a single volatile 32-bit counter, and a static 32-bit ++integer denoting the maximum value. 
It is considered signaled (that is, ++can be acquired without contention, or will wake up a waiting thread) ++when the counter is nonzero. The counter is decremented by one when a ++wait is satisfied. Both the initial and maximum count are established ++when the semaphore is created. ++ ++A mutex holds a volatile 32-bit recursion count, and a volatile 32-bit ++identifier denoting its owner. A mutex is considered signaled when its ++owner is zero (indicating that it is not owned). The recursion count is ++incremented when a wait is satisfied, and ownership is set to the given ++identifier. ++ ++A mutex also holds an internal flag denoting whether its previous owner ++has died; such a mutex is said to be abandoned. Owner death is not ++tracked automatically based on thread death, but rather must be ++communicated using ``NTSYNC_IOC_MUTEX_KILL``. An abandoned mutex is ++inherently considered unowned. ++ ++Except for the "unowned" semantics of zero, the actual value of the ++owner identifier is not interpreted by the ntsync driver at all. The ++intended use is to store a thread identifier; however, the ntsync ++driver does not actually validate that a calling thread provides ++consistent or unique identifiers. ++ ++An event is similar to a semaphore with a maximum count of one. It holds ++a volatile boolean state denoting whether it is signaled or not. There ++are two types of events, auto-reset and manual-reset. An auto-reset ++event is designaled when a wait is satisfied; a manual-reset event is ++not. The event type is specified when the event is created. ++ ++Unless specified otherwise, all operations on an object are atomic and ++totally ordered with respect to other operations on the same object. ++ ++Objects are represented by files. When all file descriptors to an ++object are closed, that object is deleted. ++ ++Char device ++=========== ++ ++The ntsync driver creates a single char device /dev/ntsync. 
Each file ++description opened on the device represents a unique instance intended ++to back an individual NT virtual machine. Objects created by one ntsync ++instance may only be used with other objects created by the same ++instance. ++ ++ioctl reference ++=============== ++ ++All operations on the device are done through ioctls. There are four ++structures used in ioctl calls:: ++ ++ struct ntsync_sem_args { ++ __u32 sem; ++ __u32 count; ++ __u32 max; ++ }; ++ ++ struct ntsync_mutex_args { ++ __u32 mutex; ++ __u32 owner; ++ __u32 count; ++ }; ++ ++ struct ntsync_event_args { ++ __u32 event; ++ __u32 signaled; ++ __u32 manual; ++ }; ++ ++ struct ntsync_wait_args { ++ __u64 timeout; ++ __u64 objs; ++ __u32 count; ++ __u32 owner; ++ __u32 index; ++ __u32 alert; ++ __u32 flags; ++ __u32 pad; ++ }; ++ ++Depending on the ioctl, members of the structure may be used as input, ++output, or not at all. All ioctls return 0 on success. ++ ++The ioctls on the device file are as follows: ++ ++.. c:macro:: NTSYNC_IOC_CREATE_SEM ++ ++ Create a semaphore object. Takes a pointer to struct ++ :c:type:`ntsync_sem_args`, which is used as follows: ++ ++ .. list-table:: ++ ++ * - ``sem`` ++ - On output, contains a file descriptor to the created semaphore. ++ * - ``count`` ++ - Initial count of the semaphore. ++ * - ``max`` ++ - Maximum count of the semaphore. ++ ++ Fails with ``EINVAL`` if ``count`` is greater than ``max``. ++ ++.. c:macro:: NTSYNC_IOC_CREATE_MUTEX ++ ++ Create a mutex object. Takes a pointer to struct ++ :c:type:`ntsync_mutex_args`, which is used as follows: ++ ++ .. list-table:: ++ ++ * - ``mutex`` ++ - On output, contains a file descriptor to the created mutex. ++ * - ``count`` ++ - Initial recursion count of the mutex. ++ * - ``owner`` ++ - Initial owner of the mutex. ++ ++ If ``owner`` is nonzero and ``count`` is zero, or if ``owner`` is ++ zero and ``count`` is nonzero, the function fails with ``EINVAL``. ++ ++.. 
c:macro:: NTSYNC_IOC_CREATE_EVENT ++ ++ Create an event object. Takes a pointer to struct ++ :c:type:`ntsync_event_args`, which is used as follows: ++ ++ .. list-table:: ++ ++ * - ``event`` ++ - On output, contains a file descriptor to the created event. ++ * - ``signaled`` ++ - If nonzero, the event is initially signaled, otherwise ++ nonsignaled. ++ * - ``manual`` ++ - If nonzero, the event is a manual-reset event, otherwise ++ auto-reset. ++ ++The ioctls on the individual objects are as follows: ++ ++.. c:macro:: NTSYNC_IOC_SEM_POST ++ ++ Post to a semaphore object. Takes a pointer to a 32-bit integer, ++ which on input holds the count to be added to the semaphore, and on ++ output contains its previous count. ++ ++ If adding to the semaphore's current count would raise the latter ++ past the semaphore's maximum count, the ioctl fails with ++ ``EOVERFLOW`` and the semaphore is not affected. If raising the ++ semaphore's count causes it to become signaled, eligible threads ++ waiting on this semaphore will be woken and the semaphore's count ++ decremented appropriately. ++ ++.. c:macro:: NTSYNC_IOC_MUTEX_UNLOCK ++ ++ Release a mutex object. Takes a pointer to struct ++ :c:type:`ntsync_mutex_args`, which is used as follows: ++ ++ .. list-table:: ++ ++ * - ``mutex`` ++ - Ignored. ++ * - ``owner`` ++ - Specifies the owner trying to release this mutex. ++ * - ``count`` ++ - On output, contains the previous recursion count. ++ ++ If ``owner`` is zero, the ioctl fails with ``EINVAL``. If ``owner`` ++ is not the current owner of the mutex, the ioctl fails with ++ ``EPERM``. ++ ++ The mutex's count will be decremented by one. If decrementing the ++ mutex's count causes it to become zero, the mutex is marked as ++ unowned and signaled, and eligible threads waiting on it will be ++ woken as appropriate. ++ ++.. c:macro:: NTSYNC_IOC_SET_EVENT ++ ++ Signal an event object. Takes a pointer to a 32-bit integer, which on ++ output contains the previous state of the event. 
++ ++ Eligible threads will be woken, and auto-reset events will be ++ designaled appropriately. ++ ++.. c:macro:: NTSYNC_IOC_RESET_EVENT ++ ++ Designal an event object. Takes a pointer to a 32-bit integer, which ++ on output contains the previous state of the event. ++ ++.. c:macro:: NTSYNC_IOC_PULSE_EVENT ++ ++ Wake threads waiting on an event object while leaving it in an ++ unsignaled state. Takes a pointer to a 32-bit integer, which on ++ output contains the previous state of the event. ++ ++ A pulse operation can be thought of as a set followed by a reset, ++ performed as a single atomic operation. If two threads are waiting on ++ an auto-reset event which is pulsed, only one will be woken. If two ++ threads are waiting a manual-reset event which is pulsed, both will ++ be woken. However, in both cases, the event will be unsignaled ++ afterwards, and a simultaneous read operation will always report the ++ event as unsignaled. ++ ++.. c:macro:: NTSYNC_IOC_READ_SEM ++ ++ Read the current state of a semaphore object. Takes a pointer to ++ struct :c:type:`ntsync_sem_args`, which is used as follows: ++ ++ .. list-table:: ++ ++ * - ``sem`` ++ - Ignored. ++ * - ``count`` ++ - On output, contains the current count of the semaphore. ++ * - ``max`` ++ - On output, contains the maximum count of the semaphore. ++ ++.. c:macro:: NTSYNC_IOC_READ_MUTEX ++ ++ Read the current state of a mutex object. Takes a pointer to struct ++ :c:type:`ntsync_mutex_args`, which is used as follows: ++ ++ .. list-table:: ++ ++ * - ``mutex`` ++ - Ignored. ++ * - ``owner`` ++ - On output, contains the current owner of the mutex, or zero ++ if the mutex is not currently owned. ++ * - ``count`` ++ - On output, contains the current recursion count of the mutex. ++ ++ If the mutex is marked as abandoned, the function fails with ++ ``EOWNERDEAD``. In this case, ``count`` and ``owner`` are set to ++ zero. ++ ++.. c:macro:: NTSYNC_IOC_READ_EVENT ++ ++ Read the current state of an event object. 
Takes a pointer to struct ++ :c:type:`ntsync_event_args`, which is used as follows: ++ ++ .. list-table:: ++ ++ * - ``event`` ++ - Ignored. ++ * - ``signaled`` ++ - On output, contains the current state of the event. ++ * - ``manual`` ++ - On output, contains 1 if the event is a manual-reset event, ++ and 0 otherwise. ++ ++.. c:macro:: NTSYNC_IOC_KILL_OWNER ++ ++ Mark a mutex as unowned and abandoned if it is owned by the given ++ owner. Takes an input-only pointer to a 32-bit integer denoting the ++ owner. If the owner is zero, the ioctl fails with ``EINVAL``. If the ++ owner does not own the mutex, the function fails with ``EPERM``. ++ ++ Eligible threads waiting on the mutex will be woken as appropriate ++ (and such waits will fail with ``EOWNERDEAD``, as described below). ++ ++.. c:macro:: NTSYNC_IOC_WAIT_ANY ++ ++ Poll on any of a list of objects, atomically acquiring at most one. ++ Takes a pointer to struct :c:type:`ntsync_wait_args`, which is ++ used as follows: ++ ++ .. list-table:: ++ ++ * - ``timeout`` ++ - Absolute timeout in nanoseconds. If ``NTSYNC_WAIT_REALTIME`` ++ is set, the timeout is measured against the REALTIME clock; ++ otherwise it is measured against the MONOTONIC clock. If the ++ timeout is equal to or earlier than the current time, the ++ function returns immediately without sleeping. If ``timeout`` ++ is U64_MAX, the function will sleep until an object is ++ signaled, and will not fail with ``ETIMEDOUT``. ++ * - ``objs`` ++ - Pointer to an array of ``count`` file descriptors ++ (specified as an integer so that the structure has the same ++ size regardless of architecture). If any object is ++ invalid, the function fails with ``EINVAL``. ++ * - ``count`` ++ - Number of objects specified in the ``objs`` array. ++ If greater than ``NTSYNC_MAX_WAIT_COUNT``, the function fails ++ with ``EINVAL``. ++ * - ``owner`` ++ - Mutex owner identifier. 
If any object in ``objs`` is a mutex, ++ the ioctl will attempt to acquire that mutex on behalf of ++ ``owner``. If ``owner`` is zero, the ioctl fails with ++ ``EINVAL``. ++ * - ``index`` ++ - On success, contains the index (into ``objs``) of the object ++ which was signaled. If ``alert`` was signaled instead, ++ this contains ``count``. ++ * - ``alert`` ++ - Optional event object file descriptor. If nonzero, this ++ specifies an "alert" event object which, if signaled, will ++ terminate the wait. If nonzero, the identifier must point to a ++ valid event. ++ * - ``flags`` ++ - Zero or more flags. Currently the only flag is ++ ``NTSYNC_WAIT_REALTIME``, which causes the timeout to be ++ measured against the REALTIME clock instead of MONOTONIC. ++ * - ``pad`` ++ - Unused, must be set to zero. ++ ++ This function attempts to acquire one of the given objects. If unable ++ to do so, it sleeps until an object becomes signaled, subsequently ++ acquiring it, or the timeout expires. In the latter case the ioctl ++ fails with ``ETIMEDOUT``. The function only acquires one object, even ++ if multiple objects are signaled. ++ ++ A semaphore is considered to be signaled if its count is nonzero, and ++ is acquired by decrementing its count by one. A mutex is considered ++ to be signaled if it is unowned or if its owner matches the ``owner`` ++ argument, and is acquired by incrementing its recursion count by one ++ and setting its owner to the ``owner`` argument. An auto-reset event ++ is acquired by designaling it; a manual-reset event is not affected ++ by acquisition. ++ ++ Acquisition is atomic and totally ordered with respect to other ++ operations on the same object. If two wait operations (with different ++ ``owner`` identifiers) are queued on the same mutex, only one is ++ signaled. If two wait operations are queued on the same semaphore, ++ and a value of one is posted to it, only one is signaled. 
++ ++ If an abandoned mutex is acquired, the ioctl fails with ++ ``EOWNERDEAD``. Although this is a failure return, the function may ++ otherwise be considered successful. The mutex is marked as owned by ++ the given owner (with a recursion count of 1) and as no longer ++ abandoned, and ``index`` is still set to the index of the mutex. ++ ++ The ``alert`` argument is an "extra" event which can terminate the ++ wait, independently of all other objects. ++ ++ It is valid to pass the same object more than once, including by ++ passing the same event in the ``objs`` array and in ``alert``. If a ++ wakeup occurs due to that object being signaled, ``index`` is set to ++ the lowest index corresponding to that object. ++ ++ The function may fail with ``EINTR`` if a signal is received. ++ ++.. c:macro:: NTSYNC_IOC_WAIT_ALL ++ ++ Poll on a list of objects, atomically acquiring all of them. Takes a ++ pointer to struct :c:type:`ntsync_wait_args`, which is used ++ identically to ``NTSYNC_IOC_WAIT_ANY``, except that ``index`` is ++ always filled with zero on success if not woken via alert. ++ ++ This function attempts to simultaneously acquire all of the given ++ objects. If unable to do so, it sleeps until all objects become ++ simultaneously signaled, subsequently acquiring them, or the timeout ++ expires. In the latter case the ioctl fails with ``ETIMEDOUT`` and no ++ objects are modified. ++ ++ Objects may become signaled and subsequently designaled (through ++ acquisition by other threads) while this thread is sleeping. Only ++ once all objects are simultaneously signaled does the ioctl acquire ++ them and return. The entire acquisition is atomic and totally ordered ++ with respect to other operations on any of the given objects. ++ ++ If an abandoned mutex is acquired, the ioctl fails with ++ ``EOWNERDEAD``. Similarly to ``NTSYNC_IOC_WAIT_ANY``, all objects are ++ nevertheless marked as acquired. 
Note that if multiple mutex objects ++ are specified, there is no way to know which were marked as ++ abandoned. ++ ++ As with "any" waits, the ``alert`` argument is an "extra" event which ++ can terminate the wait. Critically, however, an "all" wait will ++ succeed if all members in ``objs`` are signaled, *or* if ``alert`` is ++ signaled. In the latter case ``index`` will be set to ``count``. As ++ with "any" waits, if both conditions are filled, the former takes ++ priority, and objects in ``objs`` will be acquired. ++ ++ Unlike ``NTSYNC_IOC_WAIT_ANY``, it is not valid to pass the same ++ object more than once, nor is it valid to pass the same object in ++ ``objs`` and in ``alert``. If this is attempted, the function fails ++ with ``EINVAL``. +diff --git a/MAINTAINERS b/MAINTAINERS +index b27470be2e6a..4112729fc23a 100644 +--- a/MAINTAINERS ++++ b/MAINTAINERS +@@ -15983,6 +15983,15 @@ T: git https://github.com/Paragon-Software-Group/linux-ntfs3.git + F: Documentation/filesystems/ntfs3.rst + F: fs/ntfs3/ + ++NTSYNC SYNCHRONIZATION PRIMITIVE DRIVER ++M: Elizabeth Figura ++L: wine-devel@winehq.org ++S: Supported ++F: Documentation/userspace-api/ntsync.rst ++F: drivers/misc/ntsync.c ++F: include/uapi/linux/ntsync.h ++F: tools/testing/selftests/drivers/ntsync/ ++ + NUBUS SUBSYSTEM + M: Finn Thain + L: linux-m68k@lists.linux-m68k.org +diff --git a/drivers/misc/Kconfig b/drivers/misc/Kconfig +index faf983680040..2907b5c23368 100644 +--- a/drivers/misc/Kconfig ++++ b/drivers/misc/Kconfig +@@ -507,7 +507,6 @@ config OPEN_DICE + + config NTSYNC + tristate "NT synchronization primitive emulation" +- depends on BROKEN + help + This module provides kernel support for emulation of Windows NT + synchronization primitives. It is not a hardware driver. 
+diff --git a/drivers/misc/ntsync.c b/drivers/misc/ntsync.c +index 3c2f743c58b0..87a24798a5c7 100644 +--- a/drivers/misc/ntsync.c ++++ b/drivers/misc/ntsync.c +@@ -6,11 +6,17 @@ + */ + + #include ++#include + #include + #include ++#include ++#include + #include + #include ++#include + #include ++#include ++#include + #include + #include + #include +@@ -19,6 +25,8 @@ + + enum ntsync_type { + NTSYNC_TYPE_SEM, ++ NTSYNC_TYPE_MUTEX, ++ NTSYNC_TYPE_EVENT, + }; + + /* +@@ -30,10 +38,13 @@ enum ntsync_type { + * + * Both rely on struct file for reference counting. Individual + * ntsync_obj objects take a reference to the device when created. ++ * Wait operations take a reference to each object being waited on for ++ * the duration of the wait. + */ + + struct ntsync_obj { + spinlock_t lock; ++ int dev_locked; + + enum ntsync_type type; + +@@ -46,13 +57,335 @@ struct ntsync_obj { + __u32 count; + __u32 max; + } sem; ++ struct { ++ __u32 count; ++ pid_t owner; ++ bool ownerdead; ++ } mutex; ++ struct { ++ bool manual; ++ bool signaled; ++ } event; + } u; ++ ++ /* ++ * any_waiters is protected by the object lock, but all_waiters is ++ * protected by the device wait_all_lock. ++ */ ++ struct list_head any_waiters; ++ struct list_head all_waiters; ++ ++ /* ++ * Hint describing how many tasks are queued on this object in a ++ * wait-all operation. ++ * ++ * Any time we do a wake, we may need to wake "all" waiters as well as ++ * "any" waiters. In order to atomically wake "all" waiters, we must ++ * lock all of the objects, and that means grabbing the wait_all_lock ++ * below (and, due to lock ordering rules, before locking this object). ++ * However, wait-all is a rare operation, and grabbing the wait-all ++ * lock for every wake would create unnecessary contention. ++ * Therefore we first check whether all_hint is zero, and, if it is, ++ * we skip trying to wake "all" waiters. 
++ * ++ * Since wait requests must originate from user-space threads, we're ++ * limited here by PID_MAX_LIMIT, so there's no risk of overflow. ++ */ ++ atomic_t all_hint; ++}; ++ ++struct ntsync_q_entry { ++ struct list_head node; ++ struct ntsync_q *q; ++ struct ntsync_obj *obj; ++ __u32 index; ++}; ++ ++struct ntsync_q { ++ struct task_struct *task; ++ __u32 owner; ++ ++ /* ++ * Protected via atomic_try_cmpxchg(). Only the thread that wins the ++ * compare-and-swap may actually change object states and wake this ++ * task. ++ */ ++ atomic_t signaled; ++ ++ bool all; ++ bool ownerdead; ++ __u32 count; ++ struct ntsync_q_entry entries[]; + }; + + struct ntsync_device { ++ /* ++ * Wait-all operations must atomically grab all objects, and be totally ++ * ordered with respect to each other and wait-any operations. ++ * If one thread is trying to acquire several objects, another thread ++ * cannot touch the object at the same time. ++ * ++ * This device-wide lock is used to serialize wait-for-all ++ * operations, and operations on an object that is involved in a ++ * wait-for-all. ++ */ ++ struct mutex wait_all_lock; ++ + struct file *file; + }; + ++/* ++ * Single objects are locked using obj->lock. ++ * ++ * Multiple objects are 'locked' while holding dev->wait_all_lock. ++ * In this case however, individual objects are not locked by holding ++ * obj->lock, but by setting obj->dev_locked. ++ * ++ * This means that in order to lock a single object, the sequence is slightly ++ * more complicated than usual. Specifically it needs to check obj->dev_locked ++ * after acquiring obj->lock, if set, it needs to drop the lock and acquire ++ * dev->wait_all_lock in order to serialize against the multi-object operation. 
++ */ ++ ++static void dev_lock_obj(struct ntsync_device *dev, struct ntsync_obj *obj) ++{ ++ lockdep_assert_held(&dev->wait_all_lock); ++ lockdep_assert(obj->dev == dev); ++ spin_lock(&obj->lock); ++ /* ++ * By setting obj->dev_locked inside obj->lock, it is ensured that ++ * anyone holding obj->lock must see the value. ++ */ ++ obj->dev_locked = 1; ++ spin_unlock(&obj->lock); ++} ++ ++static void dev_unlock_obj(struct ntsync_device *dev, struct ntsync_obj *obj) ++{ ++ lockdep_assert_held(&dev->wait_all_lock); ++ lockdep_assert(obj->dev == dev); ++ spin_lock(&obj->lock); ++ obj->dev_locked = 0; ++ spin_unlock(&obj->lock); ++} ++ ++static void obj_lock(struct ntsync_obj *obj) ++{ ++ struct ntsync_device *dev = obj->dev; ++ ++ for (;;) { ++ spin_lock(&obj->lock); ++ if (likely(!obj->dev_locked)) ++ break; ++ ++ spin_unlock(&obj->lock); ++ mutex_lock(&dev->wait_all_lock); ++ spin_lock(&obj->lock); ++ /* ++ * obj->dev_locked should be set and released under the same ++ * wait_all_lock section, since we now own this lock, it should ++ * be clear. 
++ */ ++ lockdep_assert(!obj->dev_locked); ++ spin_unlock(&obj->lock); ++ mutex_unlock(&dev->wait_all_lock); ++ } ++} ++ ++static void obj_unlock(struct ntsync_obj *obj) ++{ ++ spin_unlock(&obj->lock); ++} ++ ++static bool ntsync_lock_obj(struct ntsync_device *dev, struct ntsync_obj *obj) ++{ ++ bool all; ++ ++ obj_lock(obj); ++ all = atomic_read(&obj->all_hint); ++ if (unlikely(all)) { ++ obj_unlock(obj); ++ mutex_lock(&dev->wait_all_lock); ++ dev_lock_obj(dev, obj); ++ } ++ ++ return all; ++} ++ ++static void ntsync_unlock_obj(struct ntsync_device *dev, struct ntsync_obj *obj, bool all) ++{ ++ if (all) { ++ dev_unlock_obj(dev, obj); ++ mutex_unlock(&dev->wait_all_lock); ++ } else { ++ obj_unlock(obj); ++ } ++} ++ ++#define ntsync_assert_held(obj) \ ++ lockdep_assert((lockdep_is_held(&(obj)->lock) != LOCK_STATE_NOT_HELD) || \ ++ ((lockdep_is_held(&(obj)->dev->wait_all_lock) != LOCK_STATE_NOT_HELD) && \ ++ (obj)->dev_locked)) ++ ++static bool is_signaled(struct ntsync_obj *obj, __u32 owner) ++{ ++ ntsync_assert_held(obj); ++ ++ switch (obj->type) { ++ case NTSYNC_TYPE_SEM: ++ return !!obj->u.sem.count; ++ case NTSYNC_TYPE_MUTEX: ++ if (obj->u.mutex.owner && obj->u.mutex.owner != owner) ++ return false; ++ return obj->u.mutex.count < UINT_MAX; ++ case NTSYNC_TYPE_EVENT: ++ return obj->u.event.signaled; ++ } ++ ++ WARN(1, "bad object type %#x\n", obj->type); ++ return false; ++} ++ ++/* ++ * "locked_obj" is an optional pointer to an object which is already locked and ++ * should not be locked again. This is necessary so that changing an object's ++ * state and waking it can be a single atomic operation. 
++ */ ++static void try_wake_all(struct ntsync_device *dev, struct ntsync_q *q, ++ struct ntsync_obj *locked_obj) ++{ ++ __u32 count = q->count; ++ bool can_wake = true; ++ int signaled = -1; ++ __u32 i; ++ ++ lockdep_assert_held(&dev->wait_all_lock); ++ if (locked_obj) ++ lockdep_assert(locked_obj->dev_locked); ++ ++ for (i = 0; i < count; i++) { ++ if (q->entries[i].obj != locked_obj) ++ dev_lock_obj(dev, q->entries[i].obj); ++ } ++ ++ for (i = 0; i < count; i++) { ++ if (!is_signaled(q->entries[i].obj, q->owner)) { ++ can_wake = false; ++ break; ++ } ++ } ++ ++ if (can_wake && atomic_try_cmpxchg(&q->signaled, &signaled, 0)) { ++ for (i = 0; i < count; i++) { ++ struct ntsync_obj *obj = q->entries[i].obj; ++ ++ switch (obj->type) { ++ case NTSYNC_TYPE_SEM: ++ obj->u.sem.count--; ++ break; ++ case NTSYNC_TYPE_MUTEX: ++ if (obj->u.mutex.ownerdead) ++ q->ownerdead = true; ++ obj->u.mutex.ownerdead = false; ++ obj->u.mutex.count++; ++ obj->u.mutex.owner = q->owner; ++ break; ++ case NTSYNC_TYPE_EVENT: ++ if (!obj->u.event.manual) ++ obj->u.event.signaled = false; ++ break; ++ } ++ } ++ wake_up_process(q->task); ++ } ++ ++ for (i = 0; i < count; i++) { ++ if (q->entries[i].obj != locked_obj) ++ dev_unlock_obj(dev, q->entries[i].obj); ++ } ++} ++ ++static void try_wake_all_obj(struct ntsync_device *dev, struct ntsync_obj *obj) ++{ ++ struct ntsync_q_entry *entry; ++ ++ lockdep_assert_held(&dev->wait_all_lock); ++ lockdep_assert(obj->dev_locked); ++ ++ list_for_each_entry(entry, &obj->all_waiters, node) ++ try_wake_all(dev, entry->q, obj); ++} ++ ++static void try_wake_any_sem(struct ntsync_obj *sem) ++{ ++ struct ntsync_q_entry *entry; ++ ++ ntsync_assert_held(sem); ++ lockdep_assert(sem->type == NTSYNC_TYPE_SEM); ++ ++ list_for_each_entry(entry, &sem->any_waiters, node) { ++ struct ntsync_q *q = entry->q; ++ int signaled = -1; ++ ++ if (!sem->u.sem.count) ++ break; ++ ++ if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) { ++ sem->u.sem.count--; ++ 
wake_up_process(q->task); ++ } ++ } ++} ++ ++static void try_wake_any_mutex(struct ntsync_obj *mutex) ++{ ++ struct ntsync_q_entry *entry; ++ ++ ntsync_assert_held(mutex); ++ lockdep_assert(mutex->type == NTSYNC_TYPE_MUTEX); ++ ++ list_for_each_entry(entry, &mutex->any_waiters, node) { ++ struct ntsync_q *q = entry->q; ++ int signaled = -1; ++ ++ if (mutex->u.mutex.count == UINT_MAX) ++ break; ++ if (mutex->u.mutex.owner && mutex->u.mutex.owner != q->owner) ++ continue; ++ ++ if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) { ++ if (mutex->u.mutex.ownerdead) ++ q->ownerdead = true; ++ mutex->u.mutex.ownerdead = false; ++ mutex->u.mutex.count++; ++ mutex->u.mutex.owner = q->owner; ++ wake_up_process(q->task); ++ } ++ } ++} ++ ++static void try_wake_any_event(struct ntsync_obj *event) ++{ ++ struct ntsync_q_entry *entry; ++ ++ ntsync_assert_held(event); ++ lockdep_assert(event->type == NTSYNC_TYPE_EVENT); ++ ++ list_for_each_entry(entry, &event->any_waiters, node) { ++ struct ntsync_q *q = entry->q; ++ int signaled = -1; ++ ++ if (!event->u.event.signaled) ++ break; ++ ++ if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) { ++ if (!event->u.event.manual) ++ event->u.event.signaled = false; ++ wake_up_process(q->task); ++ } ++ } ++} ++ + /* + * Actually change the semaphore state, returning -EOVERFLOW if it is made + * invalid. 
+@@ -61,7 +394,7 @@ static int post_sem_state(struct ntsync_obj *sem, __u32 count) + { + __u32 sum; + +- lockdep_assert_held(&sem->lock); ++ ntsync_assert_held(sem); + + if (check_add_overflow(sem->u.sem.count, count, &sum) || + sum > sem->u.sem.max) +@@ -73,9 +406,11 @@ static int post_sem_state(struct ntsync_obj *sem, __u32 count) + + static int ntsync_sem_post(struct ntsync_obj *sem, void __user *argp) + { ++ struct ntsync_device *dev = sem->dev; + __u32 __user *user_args = argp; + __u32 prev_count; + __u32 args; ++ bool all; + int ret; + + if (copy_from_user(&args, argp, sizeof(args))) +@@ -84,12 +419,17 @@ static int ntsync_sem_post(struct ntsync_obj *sem, void __user *argp) + if (sem->type != NTSYNC_TYPE_SEM) + return -EINVAL; + +- spin_lock(&sem->lock); ++ all = ntsync_lock_obj(dev, sem); + + prev_count = sem->u.sem.count; + ret = post_sem_state(sem, args); ++ if (!ret) { ++ if (all) ++ try_wake_all_obj(dev, sem); ++ try_wake_any_sem(sem); ++ } + +- spin_unlock(&sem->lock); ++ ntsync_unlock_obj(dev, sem, all); + + if (!ret && put_user(prev_count, user_args)) + ret = -EFAULT; +@@ -97,6 +437,226 @@ static int ntsync_sem_post(struct ntsync_obj *sem, void __user *argp) + return ret; + } + ++/* ++ * Actually change the mutex state, returning -EPERM if not the owner. 
++ */ ++static int unlock_mutex_state(struct ntsync_obj *mutex, ++ const struct ntsync_mutex_args *args) ++{ ++ ntsync_assert_held(mutex); ++ ++ if (mutex->u.mutex.owner != args->owner) ++ return -EPERM; ++ ++ if (!--mutex->u.mutex.count) ++ mutex->u.mutex.owner = 0; ++ return 0; ++} ++ ++static int ntsync_mutex_unlock(struct ntsync_obj *mutex, void __user *argp) ++{ ++ struct ntsync_mutex_args __user *user_args = argp; ++ struct ntsync_device *dev = mutex->dev; ++ struct ntsync_mutex_args args; ++ __u32 prev_count; ++ bool all; ++ int ret; ++ ++ if (copy_from_user(&args, argp, sizeof(args))) ++ return -EFAULT; ++ if (!args.owner) ++ return -EINVAL; ++ ++ if (mutex->type != NTSYNC_TYPE_MUTEX) ++ return -EINVAL; ++ ++ all = ntsync_lock_obj(dev, mutex); ++ ++ prev_count = mutex->u.mutex.count; ++ ret = unlock_mutex_state(mutex, &args); ++ if (!ret) { ++ if (all) ++ try_wake_all_obj(dev, mutex); ++ try_wake_any_mutex(mutex); ++ } ++ ++ ntsync_unlock_obj(dev, mutex, all); ++ ++ if (!ret && put_user(prev_count, &user_args->count)) ++ ret = -EFAULT; ++ ++ return ret; ++} ++ ++/* ++ * Actually change the mutex state to mark its owner as dead, ++ * returning -EPERM if not the owner. 
++ */ ++static int kill_mutex_state(struct ntsync_obj *mutex, __u32 owner) ++{ ++ ntsync_assert_held(mutex); ++ ++ if (mutex->u.mutex.owner != owner) ++ return -EPERM; ++ ++ mutex->u.mutex.ownerdead = true; ++ mutex->u.mutex.owner = 0; ++ mutex->u.mutex.count = 0; ++ return 0; ++} ++ ++static int ntsync_mutex_kill(struct ntsync_obj *mutex, void __user *argp) ++{ ++ struct ntsync_device *dev = mutex->dev; ++ __u32 owner; ++ bool all; ++ int ret; ++ ++ if (get_user(owner, (__u32 __user *)argp)) ++ return -EFAULT; ++ if (!owner) ++ return -EINVAL; ++ ++ if (mutex->type != NTSYNC_TYPE_MUTEX) ++ return -EINVAL; ++ ++ all = ntsync_lock_obj(dev, mutex); ++ ++ ret = kill_mutex_state(mutex, owner); ++ if (!ret) { ++ if (all) ++ try_wake_all_obj(dev, mutex); ++ try_wake_any_mutex(mutex); ++ } ++ ++ ntsync_unlock_obj(dev, mutex, all); ++ ++ return ret; ++} ++ ++static int ntsync_event_set(struct ntsync_obj *event, void __user *argp, bool pulse) ++{ ++ struct ntsync_device *dev = event->dev; ++ __u32 prev_state; ++ bool all; ++ ++ if (event->type != NTSYNC_TYPE_EVENT) ++ return -EINVAL; ++ ++ all = ntsync_lock_obj(dev, event); ++ ++ prev_state = event->u.event.signaled; ++ event->u.event.signaled = true; ++ if (all) ++ try_wake_all_obj(dev, event); ++ try_wake_any_event(event); ++ if (pulse) ++ event->u.event.signaled = false; ++ ++ ntsync_unlock_obj(dev, event, all); ++ ++ if (put_user(prev_state, (__u32 __user *)argp)) ++ return -EFAULT; ++ ++ return 0; ++} ++ ++static int ntsync_event_reset(struct ntsync_obj *event, void __user *argp) ++{ ++ struct ntsync_device *dev = event->dev; ++ __u32 prev_state; ++ bool all; ++ ++ if (event->type != NTSYNC_TYPE_EVENT) ++ return -EINVAL; ++ ++ all = ntsync_lock_obj(dev, event); ++ ++ prev_state = event->u.event.signaled; ++ event->u.event.signaled = false; ++ ++ ntsync_unlock_obj(dev, event, all); ++ ++ if (put_user(prev_state, (__u32 __user *)argp)) ++ return -EFAULT; ++ ++ return 0; ++} ++ ++static int ntsync_sem_read(struct 
ntsync_obj *sem, void __user *argp) ++{ ++ struct ntsync_sem_args __user *user_args = argp; ++ struct ntsync_device *dev = sem->dev; ++ struct ntsync_sem_args args; ++ bool all; ++ ++ if (sem->type != NTSYNC_TYPE_SEM) ++ return -EINVAL; ++ ++ args.sem = 0; ++ ++ all = ntsync_lock_obj(dev, sem); ++ ++ args.count = sem->u.sem.count; ++ args.max = sem->u.sem.max; ++ ++ ntsync_unlock_obj(dev, sem, all); ++ ++ if (copy_to_user(user_args, &args, sizeof(args))) ++ return -EFAULT; ++ return 0; ++} ++ ++static int ntsync_mutex_read(struct ntsync_obj *mutex, void __user *argp) ++{ ++ struct ntsync_mutex_args __user *user_args = argp; ++ struct ntsync_device *dev = mutex->dev; ++ struct ntsync_mutex_args args; ++ bool all; ++ int ret; ++ ++ if (mutex->type != NTSYNC_TYPE_MUTEX) ++ return -EINVAL; ++ ++ args.mutex = 0; ++ ++ all = ntsync_lock_obj(dev, mutex); ++ ++ args.count = mutex->u.mutex.count; ++ args.owner = mutex->u.mutex.owner; ++ ret = mutex->u.mutex.ownerdead ? -EOWNERDEAD : 0; ++ ++ ntsync_unlock_obj(dev, mutex, all); ++ ++ if (copy_to_user(user_args, &args, sizeof(args))) ++ return -EFAULT; ++ return ret; ++} ++ ++static int ntsync_event_read(struct ntsync_obj *event, void __user *argp) ++{ ++ struct ntsync_event_args __user *user_args = argp; ++ struct ntsync_device *dev = event->dev; ++ struct ntsync_event_args args; ++ bool all; ++ ++ if (event->type != NTSYNC_TYPE_EVENT) ++ return -EINVAL; ++ ++ args.event = 0; ++ ++ all = ntsync_lock_obj(dev, event); ++ ++ args.manual = event->u.event.manual; ++ args.signaled = event->u.event.signaled; ++ ++ ntsync_unlock_obj(dev, event, all); ++ ++ if (copy_to_user(user_args, &args, sizeof(args))) ++ return -EFAULT; ++ return 0; ++} ++ + static int ntsync_obj_release(struct inode *inode, struct file *file) + { + struct ntsync_obj *obj = file->private_data; +@@ -116,6 +676,22 @@ static long ntsync_obj_ioctl(struct file *file, unsigned int cmd, + switch (cmd) { + case NTSYNC_IOC_SEM_POST: + return ntsync_sem_post(obj, argp); 
++ case NTSYNC_IOC_SEM_READ: ++ return ntsync_sem_read(obj, argp); ++ case NTSYNC_IOC_MUTEX_UNLOCK: ++ return ntsync_mutex_unlock(obj, argp); ++ case NTSYNC_IOC_MUTEX_KILL: ++ return ntsync_mutex_kill(obj, argp); ++ case NTSYNC_IOC_MUTEX_READ: ++ return ntsync_mutex_read(obj, argp); ++ case NTSYNC_IOC_EVENT_SET: ++ return ntsync_event_set(obj, argp, false); ++ case NTSYNC_IOC_EVENT_RESET: ++ return ntsync_event_reset(obj, argp); ++ case NTSYNC_IOC_EVENT_PULSE: ++ return ntsync_event_set(obj, argp, true); ++ case NTSYNC_IOC_EVENT_READ: ++ return ntsync_event_read(obj, argp); + default: + return -ENOIOCTLCMD; + } +@@ -141,6 +717,9 @@ static struct ntsync_obj *ntsync_alloc_obj(struct ntsync_device *dev, + obj->dev = dev; + get_file(dev->file); + spin_lock_init(&obj->lock); ++ INIT_LIST_HEAD(&obj->any_waiters); ++ INIT_LIST_HEAD(&obj->all_waiters); ++ atomic_set(&obj->all_hint, 0); + + return obj; + } +@@ -191,6 +770,400 @@ static int ntsync_create_sem(struct ntsync_device *dev, void __user *argp) + return put_user(fd, &user_args->sem); + } + ++static int ntsync_create_mutex(struct ntsync_device *dev, void __user *argp) ++{ ++ struct ntsync_mutex_args __user *user_args = argp; ++ struct ntsync_mutex_args args; ++ struct ntsync_obj *mutex; ++ int fd; ++ ++ if (copy_from_user(&args, argp, sizeof(args))) ++ return -EFAULT; ++ ++ if (!args.owner != !args.count) ++ return -EINVAL; ++ ++ mutex = ntsync_alloc_obj(dev, NTSYNC_TYPE_MUTEX); ++ if (!mutex) ++ return -ENOMEM; ++ mutex->u.mutex.count = args.count; ++ mutex->u.mutex.owner = args.owner; ++ fd = ntsync_obj_get_fd(mutex); ++ if (fd < 0) { ++ kfree(mutex); ++ return fd; ++ } ++ ++ return put_user(fd, &user_args->mutex); ++} ++ ++static int ntsync_create_event(struct ntsync_device *dev, void __user *argp) ++{ ++ struct ntsync_event_args __user *user_args = argp; ++ struct ntsync_event_args args; ++ struct ntsync_obj *event; ++ int fd; ++ ++ if (copy_from_user(&args, argp, sizeof(args))) ++ return -EFAULT; ++ ++ event = 
ntsync_alloc_obj(dev, NTSYNC_TYPE_EVENT); ++ if (!event) ++ return -ENOMEM; ++ event->u.event.manual = args.manual; ++ event->u.event.signaled = args.signaled; ++ fd = ntsync_obj_get_fd(event); ++ if (fd < 0) { ++ kfree(event); ++ return fd; ++ } ++ ++ return put_user(fd, &user_args->event); ++} ++ ++static struct ntsync_obj *get_obj(struct ntsync_device *dev, int fd) ++{ ++ struct file *file = fget(fd); ++ struct ntsync_obj *obj; ++ ++ if (!file) ++ return NULL; ++ ++ if (file->f_op != &ntsync_obj_fops) { ++ fput(file); ++ return NULL; ++ } ++ ++ obj = file->private_data; ++ if (obj->dev != dev) { ++ fput(file); ++ return NULL; ++ } ++ ++ return obj; ++} ++ ++static void put_obj(struct ntsync_obj *obj) ++{ ++ fput(obj->file); ++} ++ ++static int ntsync_schedule(const struct ntsync_q *q, const struct ntsync_wait_args *args) ++{ ++ ktime_t timeout = ns_to_ktime(args->timeout); ++ clockid_t clock = CLOCK_MONOTONIC; ++ ktime_t *timeout_ptr; ++ int ret = 0; ++ ++ timeout_ptr = (args->timeout == U64_MAX ? NULL : &timeout); ++ ++ if (args->flags & NTSYNC_WAIT_REALTIME) ++ clock = CLOCK_REALTIME; ++ ++ do { ++ if (signal_pending(current)) { ++ ret = -ERESTARTSYS; ++ break; ++ } ++ ++ set_current_state(TASK_INTERRUPTIBLE); ++ if (atomic_read(&q->signaled) != -1) { ++ ret = 0; ++ break; ++ } ++ ret = schedule_hrtimeout_range_clock(timeout_ptr, 0, HRTIMER_MODE_ABS, clock); ++ } while (ret < 0); ++ __set_current_state(TASK_RUNNING); ++ ++ return ret; ++} ++ ++/* ++ * Allocate and initialize the ntsync_q structure, but do not queue us yet. 
++ */ ++static int setup_wait(struct ntsync_device *dev, ++ const struct ntsync_wait_args *args, bool all, ++ struct ntsync_q **ret_q) ++{ ++ int fds[NTSYNC_MAX_WAIT_COUNT + 1]; ++ const __u32 count = args->count; ++ struct ntsync_q *q; ++ __u32 total_count; ++ __u32 i, j; ++ ++ if (args->pad || (args->flags & ~NTSYNC_WAIT_REALTIME)) ++ return -EINVAL; ++ ++ if (args->count > NTSYNC_MAX_WAIT_COUNT) ++ return -EINVAL; ++ ++ total_count = count; ++ if (args->alert) ++ total_count++; ++ ++ if (copy_from_user(fds, u64_to_user_ptr(args->objs), ++ array_size(count, sizeof(*fds)))) ++ return -EFAULT; ++ if (args->alert) ++ fds[count] = args->alert; ++ ++ q = kmalloc(struct_size(q, entries, total_count), GFP_KERNEL); ++ if (!q) ++ return -ENOMEM; ++ q->task = current; ++ q->owner = args->owner; ++ atomic_set(&q->signaled, -1); ++ q->all = all; ++ q->ownerdead = false; ++ q->count = count; ++ ++ for (i = 0; i < total_count; i++) { ++ struct ntsync_q_entry *entry = &q->entries[i]; ++ struct ntsync_obj *obj = get_obj(dev, fds[i]); ++ ++ if (!obj) ++ goto err; ++ ++ if (all) { ++ /* Check that the objects are all distinct. 
*/ ++ for (j = 0; j < i; j++) { ++ if (obj == q->entries[j].obj) { ++ put_obj(obj); ++ goto err; ++ } ++ } ++ } ++ ++ entry->obj = obj; ++ entry->q = q; ++ entry->index = i; ++ } ++ ++ *ret_q = q; ++ return 0; ++ ++err: ++ for (j = 0; j < i; j++) ++ put_obj(q->entries[j].obj); ++ kfree(q); ++ return -EINVAL; ++} ++ ++static void try_wake_any_obj(struct ntsync_obj *obj) ++{ ++ switch (obj->type) { ++ case NTSYNC_TYPE_SEM: ++ try_wake_any_sem(obj); ++ break; ++ case NTSYNC_TYPE_MUTEX: ++ try_wake_any_mutex(obj); ++ break; ++ case NTSYNC_TYPE_EVENT: ++ try_wake_any_event(obj); ++ break; ++ } ++} ++ ++static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp) ++{ ++ struct ntsync_wait_args args; ++ __u32 i, total_count; ++ struct ntsync_q *q; ++ int signaled; ++ bool all; ++ int ret; ++ ++ if (copy_from_user(&args, argp, sizeof(args))) ++ return -EFAULT; ++ ++ ret = setup_wait(dev, &args, false, &q); ++ if (ret < 0) ++ return ret; ++ ++ total_count = args.count; ++ if (args.alert) ++ total_count++; ++ ++ /* queue ourselves */ ++ ++ for (i = 0; i < total_count; i++) { ++ struct ntsync_q_entry *entry = &q->entries[i]; ++ struct ntsync_obj *obj = entry->obj; ++ ++ all = ntsync_lock_obj(dev, obj); ++ list_add_tail(&entry->node, &obj->any_waiters); ++ ntsync_unlock_obj(dev, obj, all); ++ } ++ ++ /* ++ * Check if we are already signaled. ++ * ++ * Note that the API requires that normal objects are checked before ++ * the alert event. Hence we queue the alert event last, and check ++ * objects in order. 
++ */ ++ ++ for (i = 0; i < total_count; i++) { ++ struct ntsync_obj *obj = q->entries[i].obj; ++ ++ if (atomic_read(&q->signaled) != -1) ++ break; ++ ++ all = ntsync_lock_obj(dev, obj); ++ try_wake_any_obj(obj); ++ ntsync_unlock_obj(dev, obj, all); ++ } ++ ++ /* sleep */ ++ ++ ret = ntsync_schedule(q, &args); ++ ++ /* and finally, unqueue */ ++ ++ for (i = 0; i < total_count; i++) { ++ struct ntsync_q_entry *entry = &q->entries[i]; ++ struct ntsync_obj *obj = entry->obj; ++ ++ all = ntsync_lock_obj(dev, obj); ++ list_del(&entry->node); ++ ntsync_unlock_obj(dev, obj, all); ++ ++ put_obj(obj); ++ } ++ ++ signaled = atomic_read(&q->signaled); ++ if (signaled != -1) { ++ struct ntsync_wait_args __user *user_args = argp; ++ ++ /* even if we caught a signal, we need to communicate success */ ++ ret = q->ownerdead ? -EOWNERDEAD : 0; ++ ++ if (put_user(signaled, &user_args->index)) ++ ret = -EFAULT; ++ } else if (!ret) { ++ ret = -ETIMEDOUT; ++ } ++ ++ kfree(q); ++ return ret; ++} ++ ++static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp) ++{ ++ struct ntsync_wait_args args; ++ struct ntsync_q *q; ++ int signaled; ++ __u32 i; ++ int ret; ++ ++ if (copy_from_user(&args, argp, sizeof(args))) ++ return -EFAULT; ++ ++ ret = setup_wait(dev, &args, true, &q); ++ if (ret < 0) ++ return ret; ++ ++ /* queue ourselves */ ++ ++ mutex_lock(&dev->wait_all_lock); ++ ++ for (i = 0; i < args.count; i++) { ++ struct ntsync_q_entry *entry = &q->entries[i]; ++ struct ntsync_obj *obj = entry->obj; ++ ++ atomic_inc(&obj->all_hint); ++ ++ /* ++ * obj->all_waiters is protected by dev->wait_all_lock rather ++ * than obj->lock, so there is no need to acquire obj->lock ++ * here. 
++ */ ++ list_add_tail(&entry->node, &obj->all_waiters); ++ } ++ if (args.alert) { ++ struct ntsync_q_entry *entry = &q->entries[args.count]; ++ struct ntsync_obj *obj = entry->obj; ++ ++ dev_lock_obj(dev, obj); ++ list_add_tail(&entry->node, &obj->any_waiters); ++ dev_unlock_obj(dev, obj); ++ } ++ ++ /* check if we are already signaled */ ++ ++ try_wake_all(dev, q, NULL); ++ ++ mutex_unlock(&dev->wait_all_lock); ++ ++ /* ++ * Check if the alert event is signaled, making sure to do so only ++ * after checking if the other objects are signaled. ++ */ ++ ++ if (args.alert) { ++ struct ntsync_obj *obj = q->entries[args.count].obj; ++ ++ if (atomic_read(&q->signaled) == -1) { ++ bool all = ntsync_lock_obj(dev, obj); ++ try_wake_any_obj(obj); ++ ntsync_unlock_obj(dev, obj, all); ++ } ++ } ++ ++ /* sleep */ ++ ++ ret = ntsync_schedule(q, &args); ++ ++ /* and finally, unqueue */ ++ ++ mutex_lock(&dev->wait_all_lock); ++ ++ for (i = 0; i < args.count; i++) { ++ struct ntsync_q_entry *entry = &q->entries[i]; ++ struct ntsync_obj *obj = entry->obj; ++ ++ /* ++ * obj->all_waiters is protected by dev->wait_all_lock rather ++ * than obj->lock, so there is no need to acquire it here. ++ */ ++ list_del(&entry->node); ++ ++ atomic_dec(&obj->all_hint); ++ ++ put_obj(obj); ++ } ++ ++ mutex_unlock(&dev->wait_all_lock); ++ ++ if (args.alert) { ++ struct ntsync_q_entry *entry = &q->entries[args.count]; ++ struct ntsync_obj *obj = entry->obj; ++ bool all; ++ ++ all = ntsync_lock_obj(dev, obj); ++ list_del(&entry->node); ++ ntsync_unlock_obj(dev, obj, all); ++ ++ put_obj(obj); ++ } ++ ++ signaled = atomic_read(&q->signaled); ++ if (signaled != -1) { ++ struct ntsync_wait_args __user *user_args = argp; ++ ++ /* even if we caught a signal, we need to communicate success */ ++ ret = q->ownerdead ? 
-EOWNERDEAD : 0; ++ ++ if (put_user(signaled, &user_args->index)) ++ ret = -EFAULT; ++ } else if (!ret) { ++ ret = -ETIMEDOUT; ++ } ++ ++ kfree(q); ++ return ret; ++} ++ + static int ntsync_char_open(struct inode *inode, struct file *file) + { + struct ntsync_device *dev; +@@ -199,6 +1172,8 @@ static int ntsync_char_open(struct inode *inode, struct file *file) + if (!dev) + return -ENOMEM; + ++ mutex_init(&dev->wait_all_lock); ++ + file->private_data = dev; + dev->file = file; + return nonseekable_open(inode, file); +@@ -220,8 +1195,16 @@ static long ntsync_char_ioctl(struct file *file, unsigned int cmd, + void __user *argp = (void __user *)parm; + + switch (cmd) { ++ case NTSYNC_IOC_CREATE_EVENT: ++ return ntsync_create_event(dev, argp); ++ case NTSYNC_IOC_CREATE_MUTEX: ++ return ntsync_create_mutex(dev, argp); + case NTSYNC_IOC_CREATE_SEM: + return ntsync_create_sem(dev, argp); ++ case NTSYNC_IOC_WAIT_ALL: ++ return ntsync_wait_all(dev, argp); ++ case NTSYNC_IOC_WAIT_ANY: ++ return ntsync_wait_any(dev, argp); + default: + return -ENOIOCTLCMD; + } +diff --git a/include/uapi/linux/ntsync.h b/include/uapi/linux/ntsync.h +index dcfa38fdc93c..4a8095a3fc34 100644 +--- a/include/uapi/linux/ntsync.h ++++ b/include/uapi/linux/ntsync.h +@@ -16,8 +16,47 @@ struct ntsync_sem_args { + __u32 max; + }; + ++struct ntsync_mutex_args { ++ __u32 mutex; ++ __u32 owner; ++ __u32 count; ++}; ++ ++struct ntsync_event_args { ++ __u32 event; ++ __u32 manual; ++ __u32 signaled; ++}; ++ ++#define NTSYNC_WAIT_REALTIME 0x1 ++ ++struct ntsync_wait_args { ++ __u64 timeout; ++ __u64 objs; ++ __u32 count; ++ __u32 index; ++ __u32 flags; ++ __u32 owner; ++ __u32 alert; ++ __u32 pad; ++}; ++ ++#define NTSYNC_MAX_WAIT_COUNT 64 ++ + #define NTSYNC_IOC_CREATE_SEM _IOWR('N', 0x80, struct ntsync_sem_args) ++#define NTSYNC_IOC_WAIT_ANY _IOWR('N', 0x82, struct ntsync_wait_args) ++#define NTSYNC_IOC_WAIT_ALL _IOWR('N', 0x83, struct ntsync_wait_args) ++#define NTSYNC_IOC_CREATE_MUTEX _IOWR('N', 0x84, 
struct ntsync_mutex_args) ++#define NTSYNC_IOC_CREATE_EVENT _IOWR('N', 0x87, struct ntsync_event_args) + + #define NTSYNC_IOC_SEM_POST _IOWR('N', 0x81, __u32) ++#define NTSYNC_IOC_MUTEX_UNLOCK _IOWR('N', 0x85, struct ntsync_mutex_args) ++#define NTSYNC_IOC_MUTEX_KILL _IOW ('N', 0x86, __u32) ++#define NTSYNC_IOC_EVENT_SET _IOR ('N', 0x88, __u32) ++#define NTSYNC_IOC_EVENT_RESET _IOR ('N', 0x89, __u32) ++#define NTSYNC_IOC_EVENT_PULSE _IOR ('N', 0x8a, __u32) ++#define NTSYNC_IOC_SEM_READ _IOR ('N', 0x8b, struct ntsync_sem_args) ++#define NTSYNC_IOC_MUTEX_READ _IOR ('N', 0x8c, struct ntsync_mutex_args) ++#define NTSYNC_IOC_EVENT_READ _IOR ('N', 0x8d, struct ntsync_event_args) + + #endif +diff --git a/tools/testing/selftests/Makefile b/tools/testing/selftests/Makefile +index 9039f3709aff..d5aeaa8fe3ca 100644 +--- a/tools/testing/selftests/Makefile ++++ b/tools/testing/selftests/Makefile +@@ -16,6 +16,7 @@ TARGETS += damon + TARGETS += devices + TARGETS += dmabuf-heaps + TARGETS += drivers/dma-buf ++TARGETS += drivers/ntsync + TARGETS += drivers/s390x/uvdevice + TARGETS += drivers/net + TARGETS += drivers/net/bonding +diff --git a/tools/testing/selftests/drivers/ntsync/.gitignore b/tools/testing/selftests/drivers/ntsync/.gitignore +new file mode 100644 +index 000000000000..848573a3d3ea +--- /dev/null ++++ b/tools/testing/selftests/drivers/ntsync/.gitignore +@@ -0,0 +1 @@ ++ntsync +diff --git a/tools/testing/selftests/drivers/ntsync/Makefile b/tools/testing/selftests/drivers/ntsync/Makefile +new file mode 100644 +index 000000000000..dbf2b055c0b2 +--- /dev/null ++++ b/tools/testing/selftests/drivers/ntsync/Makefile +@@ -0,0 +1,7 @@ ++# SPDX-License-Identifier: GPL-2.0-only ++TEST_GEN_PROGS := ntsync ++ ++CFLAGS += $(KHDR_INCLUDES) ++LDLIBS += -lpthread ++ ++include ../../lib.mk +diff --git a/tools/testing/selftests/drivers/ntsync/config b/tools/testing/selftests/drivers/ntsync/config +new file mode 100644 +index 000000000000..60539c826d06 +--- /dev/null ++++
b/tools/testing/selftests/drivers/ntsync/config +@@ -0,0 +1 @@ ++CONFIG_NTSYNC=y +diff --git a/tools/testing/selftests/drivers/ntsync/ntsync.c b/tools/testing/selftests/drivers/ntsync/ntsync.c +new file mode 100644 +index 000000000000..5fa2c9a0768c +--- /dev/null ++++ b/tools/testing/selftests/drivers/ntsync/ntsync.c +@@ -0,0 +1,1407 @@ ++// SPDX-License-Identifier: GPL-2.0-or-later ++/* ++ * Various unit tests for the "ntsync" synchronization primitive driver. ++ * ++ * Copyright (C) 2021-2022 Elizabeth Figura ++ */ ++ ++#define _GNU_SOURCE ++#include <sys/ioctl.h> ++#include <sys/stat.h> ++#include <fcntl.h> ++#include <time.h> ++#include <pthread.h> ++#include <linux/ntsync.h> ++#include "../../kselftest_harness.h" ++ ++static int read_sem_state(int sem, __u32 *count, __u32 *max) ++{ ++ struct ntsync_sem_args args; ++ int ret; ++ ++ memset(&args, 0xcc, sizeof(args)); ++ ret = ioctl(sem, NTSYNC_IOC_SEM_READ, &args); ++ *count = args.count; ++ *max = args.max; ++ return ret; ++} ++ ++#define check_sem_state(sem, count, max) \ ++ ({ \ ++ __u32 __count, __max; \ ++ int ret = read_sem_state((sem), &__count, &__max); \ ++ EXPECT_EQ(0, ret); \ ++ EXPECT_EQ((count), __count); \ ++ EXPECT_EQ((max), __max); \ ++ }) ++ ++static int post_sem(int sem, __u32 *count) ++{ ++ return ioctl(sem, NTSYNC_IOC_SEM_POST, count); ++} ++ ++static int read_mutex_state(int mutex, __u32 *count, __u32 *owner) ++{ ++ struct ntsync_mutex_args args; ++ int ret; ++ ++ memset(&args, 0xcc, sizeof(args)); ++ ret = ioctl(mutex, NTSYNC_IOC_MUTEX_READ, &args); ++ *count = args.count; ++ *owner = args.owner; ++ return ret; ++} ++ ++#define check_mutex_state(mutex, count, owner) \ ++ ({ \ ++ __u32 __count, __owner; \ ++ int ret = read_mutex_state((mutex), &__count, &__owner); \ ++ EXPECT_EQ(0, ret); \ ++ EXPECT_EQ((count), __count); \ ++ EXPECT_EQ((owner), __owner); \ ++ }) ++ ++static int unlock_mutex(int mutex, __u32 owner, __u32 *count) ++{ ++ struct ntsync_mutex_args args; ++ int ret; ++ ++ args.owner = owner; ++ args.count = 0xdeadbeef; ++ ret = ioctl(mutex,
NTSYNC_IOC_MUTEX_UNLOCK, &args); ++ *count = args.count; ++ return ret; ++} ++ ++static int read_event_state(int event, __u32 *signaled, __u32 *manual) ++{ ++ struct ntsync_event_args args; ++ int ret; ++ ++ memset(&args, 0xcc, sizeof(args)); ++ ret = ioctl(event, NTSYNC_IOC_EVENT_READ, &args); ++ *signaled = args.signaled; ++ *manual = args.manual; ++ return ret; ++} ++ ++#define check_event_state(event, signaled, manual) \ ++ ({ \ ++ __u32 __signaled, __manual; \ ++ int ret = read_event_state((event), &__signaled, &__manual); \ ++ EXPECT_EQ(0, ret); \ ++ EXPECT_EQ((signaled), __signaled); \ ++ EXPECT_EQ((manual), __manual); \ ++ }) ++ ++static int wait_objs(int fd, unsigned long request, __u32 count, ++ const int *objs, __u32 owner, int alert, __u32 *index) ++{ ++ struct ntsync_wait_args args = {0}; ++ struct timespec timeout; ++ int ret; ++ ++ clock_gettime(CLOCK_MONOTONIC, &timeout); ++ ++ args.timeout = timeout.tv_sec * 1000000000 + timeout.tv_nsec; ++ args.count = count; ++ args.objs = (uintptr_t)objs; ++ args.owner = owner; ++ args.index = 0xdeadbeef; ++ args.alert = alert; ++ ret = ioctl(fd, request, &args); ++ *index = args.index; ++ return ret; ++} ++ ++static int wait_any(int fd, __u32 count, const int *objs, __u32 owner, __u32 *index) ++{ ++ return wait_objs(fd, NTSYNC_IOC_WAIT_ANY, count, objs, owner, 0, index); ++} ++ ++static int wait_all(int fd, __u32 count, const int *objs, __u32 owner, __u32 *index) ++{ ++ return wait_objs(fd, NTSYNC_IOC_WAIT_ALL, count, objs, owner, 0, index); ++} ++ ++static int wait_any_alert(int fd, __u32 count, const int *objs, ++ __u32 owner, int alert, __u32 *index) ++{ ++ return wait_objs(fd, NTSYNC_IOC_WAIT_ANY, ++ count, objs, owner, alert, index); ++} ++ ++static int wait_all_alert(int fd, __u32 count, const int *objs, ++ __u32 owner, int alert, __u32 *index) ++{ ++ return wait_objs(fd, NTSYNC_IOC_WAIT_ALL, ++ count, objs, owner, alert, index); ++} ++ ++TEST(semaphore_state) ++{ ++ struct ntsync_sem_args sem_args; ++ 
struct timespec timeout; ++ __u32 count, index; ++ int fd, ret, sem; ++ ++ clock_gettime(CLOCK_MONOTONIC, &timeout); ++ ++ fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); ++ ASSERT_LE(0, fd); ++ ++ sem_args.count = 3; ++ sem_args.max = 2; ++ sem_args.sem = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EINVAL, errno); ++ ++ sem_args.count = 2; ++ sem_args.max = 2; ++ sem_args.sem = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, sem_args.sem); ++ sem = sem_args.sem; ++ check_sem_state(sem, 2, 2); ++ ++ count = 0; ++ ret = post_sem(sem, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(2, count); ++ check_sem_state(sem, 2, 2); ++ ++ count = 1; ++ ret = post_sem(sem, &count); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EOVERFLOW, errno); ++ check_sem_state(sem, 2, 2); ++ ++ ret = wait_any(fd, 1, &sem, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_sem_state(sem, 1, 2); ++ ++ ret = wait_any(fd, 1, &sem, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_sem_state(sem, 0, 2); ++ ++ ret = wait_any(fd, 1, &sem, 123, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(ETIMEDOUT, errno); ++ ++ count = 3; ++ ret = post_sem(sem, &count); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EOVERFLOW, errno); ++ check_sem_state(sem, 0, 2); ++ ++ count = 2; ++ ret = post_sem(sem, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, count); ++ check_sem_state(sem, 2, 2); ++ ++ ret = wait_any(fd, 1, &sem, 123, &index); ++ EXPECT_EQ(0, ret); ++ ret = wait_any(fd, 1, &sem, 123, &index); ++ EXPECT_EQ(0, ret); ++ ++ count = 1; ++ ret = post_sem(sem, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, count); ++ check_sem_state(sem, 1, 2); ++ ++ count = ~0u; ++ ret = post_sem(sem, &count); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EOVERFLOW, errno); ++ check_sem_state(sem, 1, 2); ++ ++ close(sem); ++ ++ close(fd); ++} ++ ++TEST(mutex_state) ++{ ++ struct ntsync_mutex_args 
mutex_args; ++ __u32 owner, count, index; ++ struct timespec timeout; ++ int fd, ret, mutex; ++ ++ clock_gettime(CLOCK_MONOTONIC, &timeout); ++ ++ fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); ++ ASSERT_LE(0, fd); ++ ++ mutex_args.owner = 123; ++ mutex_args.count = 0; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EINVAL, errno); ++ ++ mutex_args.owner = 0; ++ mutex_args.count = 2; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EINVAL, errno); ++ ++ mutex_args.owner = 123; ++ mutex_args.count = 2; ++ mutex_args.mutex = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, mutex_args.mutex); ++ mutex = mutex_args.mutex; ++ check_mutex_state(mutex, 2, 123); ++ ++ ret = unlock_mutex(mutex, 0, &count); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EINVAL, errno); ++ ++ ret = unlock_mutex(mutex, 456, &count); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EPERM, errno); ++ check_mutex_state(mutex, 2, 123); ++ ++ ret = unlock_mutex(mutex, 123, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(2, count); ++ check_mutex_state(mutex, 1, 123); ++ ++ ret = unlock_mutex(mutex, 123, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, count); ++ check_mutex_state(mutex, 0, 0); ++ ++ ret = unlock_mutex(mutex, 123, &count); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EPERM, errno); ++ ++ ret = wait_any(fd, 1, &mutex, 456, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_mutex_state(mutex, 1, 456); ++ ++ ret = wait_any(fd, 1, &mutex, 456, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_mutex_state(mutex, 2, 456); ++ ++ ret = unlock_mutex(mutex, 456, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(2, count); ++ check_mutex_state(mutex, 1, 456); ++ ++ ret = wait_any(fd, 1, &mutex, 123, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(ETIMEDOUT, errno); ++ ++ owner = 0; ++ ret = ioctl(mutex, NTSYNC_IOC_MUTEX_KILL, &owner); ++ EXPECT_EQ(-1, ret); 
++ EXPECT_EQ(EINVAL, errno); ++ ++ owner = 123; ++ ret = ioctl(mutex, NTSYNC_IOC_MUTEX_KILL, &owner); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EPERM, errno); ++ check_mutex_state(mutex, 1, 456); ++ ++ owner = 456; ++ ret = ioctl(mutex, NTSYNC_IOC_MUTEX_KILL, &owner); ++ EXPECT_EQ(0, ret); ++ ++ memset(&mutex_args, 0xcc, sizeof(mutex_args)); ++ ret = ioctl(mutex, NTSYNC_IOC_MUTEX_READ, &mutex_args); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EOWNERDEAD, errno); ++ EXPECT_EQ(0, mutex_args.count); ++ EXPECT_EQ(0, mutex_args.owner); ++ ++ memset(&mutex_args, 0xcc, sizeof(mutex_args)); ++ ret = ioctl(mutex, NTSYNC_IOC_MUTEX_READ, &mutex_args); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EOWNERDEAD, errno); ++ EXPECT_EQ(0, mutex_args.count); ++ EXPECT_EQ(0, mutex_args.owner); ++ ++ ret = wait_any(fd, 1, &mutex, 123, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EOWNERDEAD, errno); ++ EXPECT_EQ(0, index); ++ check_mutex_state(mutex, 1, 123); ++ ++ owner = 123; ++ ret = ioctl(mutex, NTSYNC_IOC_MUTEX_KILL, &owner); ++ EXPECT_EQ(0, ret); ++ ++ memset(&mutex_args, 0xcc, sizeof(mutex_args)); ++ ret = ioctl(mutex, NTSYNC_IOC_MUTEX_READ, &mutex_args); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EOWNERDEAD, errno); ++ EXPECT_EQ(0, mutex_args.count); ++ EXPECT_EQ(0, mutex_args.owner); ++ ++ ret = wait_any(fd, 1, &mutex, 123, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EOWNERDEAD, errno); ++ EXPECT_EQ(0, index); ++ check_mutex_state(mutex, 1, 123); ++ ++ close(mutex); ++ ++ mutex_args.owner = 0; ++ mutex_args.count = 0; ++ mutex_args.mutex = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, mutex_args.mutex); ++ mutex = mutex_args.mutex; ++ check_mutex_state(mutex, 0, 0); ++ ++ ret = wait_any(fd, 1, &mutex, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_mutex_state(mutex, 1, 123); ++ ++ close(mutex); ++ ++ mutex_args.owner = 123; ++ mutex_args.count = ~0u; ++ mutex_args.mutex = 0xdeadbeef; ++ ret = ioctl(fd, 
NTSYNC_IOC_CREATE_MUTEX, &mutex_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, mutex_args.mutex); ++ mutex = mutex_args.mutex; ++ check_mutex_state(mutex, ~0u, 123); ++ ++ ret = wait_any(fd, 1, &mutex, 123, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(ETIMEDOUT, errno); ++ ++ close(mutex); ++ ++ close(fd); ++} ++ ++TEST(manual_event_state) ++{ ++ struct ntsync_event_args event_args; ++ __u32 index, signaled; ++ int fd, event, ret; ++ ++ fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); ++ ASSERT_LE(0, fd); ++ ++ event_args.manual = 1; ++ event_args.signaled = 0; ++ event_args.event = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, event_args.event); ++ event = event_args.event; ++ check_event_state(event, 0, 1); ++ ++ signaled = 0xdeadbeef; ++ ret = ioctl(event, NTSYNC_IOC_EVENT_SET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, signaled); ++ check_event_state(event, 1, 1); ++ ++ ret = ioctl(event, NTSYNC_IOC_EVENT_SET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, signaled); ++ check_event_state(event, 1, 1); ++ ++ ret = wait_any(fd, 1, &event, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_event_state(event, 1, 1); ++ ++ signaled = 0xdeadbeef; ++ ret = ioctl(event, NTSYNC_IOC_EVENT_RESET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, signaled); ++ check_event_state(event, 0, 1); ++ ++ ret = ioctl(event, NTSYNC_IOC_EVENT_RESET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, signaled); ++ check_event_state(event, 0, 1); ++ ++ ret = wait_any(fd, 1, &event, 123, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(ETIMEDOUT, errno); ++ ++ ret = ioctl(event, NTSYNC_IOC_EVENT_SET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, signaled); ++ ++ ret = ioctl(event, NTSYNC_IOC_EVENT_PULSE, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, signaled); ++ check_event_state(event, 0, 1); ++ ++ ret = ioctl(event, NTSYNC_IOC_EVENT_PULSE, &signaled); ++ EXPECT_EQ(0, ret); ++ 
EXPECT_EQ(0, signaled); ++ check_event_state(event, 0, 1); ++ ++ close(event); ++ ++ close(fd); ++} ++ ++TEST(auto_event_state) ++{ ++ struct ntsync_event_args event_args; ++ __u32 index, signaled; ++ int fd, event, ret; ++ ++ fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); ++ ASSERT_LE(0, fd); ++ ++ event_args.manual = 0; ++ event_args.signaled = 1; ++ event_args.event = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, event_args.event); ++ event = event_args.event; ++ ++ check_event_state(event, 1, 0); ++ ++ signaled = 0xdeadbeef; ++ ret = ioctl(event, NTSYNC_IOC_EVENT_SET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, signaled); ++ check_event_state(event, 1, 0); ++ ++ ret = wait_any(fd, 1, &event, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_event_state(event, 0, 0); ++ ++ signaled = 0xdeadbeef; ++ ret = ioctl(event, NTSYNC_IOC_EVENT_RESET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, signaled); ++ check_event_state(event, 0, 0); ++ ++ ret = wait_any(fd, 1, &event, 123, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(ETIMEDOUT, errno); ++ ++ ret = ioctl(event, NTSYNC_IOC_EVENT_SET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, signaled); ++ ++ ret = ioctl(event, NTSYNC_IOC_EVENT_PULSE, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, signaled); ++ check_event_state(event, 0, 0); ++ ++ ret = ioctl(event, NTSYNC_IOC_EVENT_PULSE, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, signaled); ++ check_event_state(event, 0, 0); ++ ++ close(event); ++ ++ close(fd); ++} ++ ++TEST(test_wait_any) ++{ ++ int objs[NTSYNC_MAX_WAIT_COUNT + 1], fd, ret; ++ struct ntsync_mutex_args mutex_args = {0}; ++ struct ntsync_sem_args sem_args = {0}; ++ __u32 owner, index, count, i; ++ struct timespec timeout; ++ ++ clock_gettime(CLOCK_MONOTONIC, &timeout); ++ ++ fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); ++ ASSERT_LE(0, fd); ++ ++ sem_args.count = 2; ++ sem_args.max = 3; ++ 
sem_args.sem = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, sem_args.sem); ++ ++ mutex_args.owner = 0; ++ mutex_args.count = 0; ++ mutex_args.mutex = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, mutex_args.mutex); ++ ++ objs[0] = sem_args.sem; ++ objs[1] = mutex_args.mutex; ++ ++ ret = wait_any(fd, 2, objs, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_sem_state(sem_args.sem, 1, 3); ++ check_mutex_state(mutex_args.mutex, 0, 0); ++ ++ ret = wait_any(fd, 2, objs, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_sem_state(sem_args.sem, 0, 3); ++ check_mutex_state(mutex_args.mutex, 0, 0); ++ ++ ret = wait_any(fd, 2, objs, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, index); ++ check_sem_state(sem_args.sem, 0, 3); ++ check_mutex_state(mutex_args.mutex, 1, 123); ++ ++ count = 1; ++ ret = post_sem(sem_args.sem, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, count); ++ ++ ret = wait_any(fd, 2, objs, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_sem_state(sem_args.sem, 0, 3); ++ check_mutex_state(mutex_args.mutex, 1, 123); ++ ++ ret = wait_any(fd, 2, objs, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, index); ++ check_sem_state(sem_args.sem, 0, 3); ++ check_mutex_state(mutex_args.mutex, 2, 123); ++ ++ ret = wait_any(fd, 2, objs, 456, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(ETIMEDOUT, errno); ++ ++ owner = 123; ++ ret = ioctl(mutex_args.mutex, NTSYNC_IOC_MUTEX_KILL, &owner); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_any(fd, 2, objs, 456, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EOWNERDEAD, errno); ++ EXPECT_EQ(1, index); ++ ++ ret = wait_any(fd, 2, objs, 456, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, index); ++ ++ /* test waiting on the same object twice */ ++ count = 2; ++ ret = post_sem(sem_args.sem, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, 
count); ++ ++ objs[0] = objs[1] = sem_args.sem; ++ ret = wait_any(fd, 2, objs, 456, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_sem_state(sem_args.sem, 1, 3); ++ ++ ret = wait_any(fd, 0, NULL, 456, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(ETIMEDOUT, errno); ++ ++ for (i = 0; i < NTSYNC_MAX_WAIT_COUNT + 1; ++i) ++ objs[i] = sem_args.sem; ++ ++ ret = wait_any(fd, NTSYNC_MAX_WAIT_COUNT, objs, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ ++ ret = wait_any(fd, NTSYNC_MAX_WAIT_COUNT + 1, objs, 123, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EINVAL, errno); ++ ++ ret = wait_any(fd, -1, objs, 123, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EINVAL, errno); ++ ++ close(sem_args.sem); ++ close(mutex_args.mutex); ++ ++ close(fd); ++} ++ ++TEST(test_wait_all) ++{ ++ struct ntsync_event_args event_args = {0}; ++ struct ntsync_mutex_args mutex_args = {0}; ++ struct ntsync_sem_args sem_args = {0}; ++ __u32 owner, index, count; ++ int objs[2], fd, ret; ++ ++ fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); ++ ASSERT_LE(0, fd); ++ ++ sem_args.count = 2; ++ sem_args.max = 3; ++ sem_args.sem = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, sem_args.sem); ++ ++ mutex_args.owner = 0; ++ mutex_args.count = 0; ++ mutex_args.mutex = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, mutex_args.mutex); ++ ++ event_args.manual = true; ++ event_args.signaled = true; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); ++ EXPECT_EQ(0, ret); ++ ++ objs[0] = sem_args.sem; ++ objs[1] = mutex_args.mutex; ++ ++ ret = wait_all(fd, 2, objs, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_sem_state(sem_args.sem, 1, 3); ++ check_mutex_state(mutex_args.mutex, 1, 123); ++ ++ ret = wait_all(fd, 2, objs, 456, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(ETIMEDOUT, errno); ++ check_sem_state(sem_args.sem, 1, 
3); ++ check_mutex_state(mutex_args.mutex, 1, 123); ++ ++ ret = wait_all(fd, 2, objs, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_sem_state(sem_args.sem, 0, 3); ++ check_mutex_state(mutex_args.mutex, 2, 123); ++ ++ ret = wait_all(fd, 2, objs, 123, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(ETIMEDOUT, errno); ++ check_sem_state(sem_args.sem, 0, 3); ++ check_mutex_state(mutex_args.mutex, 2, 123); ++ ++ count = 3; ++ ret = post_sem(sem_args.sem, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, count); ++ ++ ret = wait_all(fd, 2, objs, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_sem_state(sem_args.sem, 2, 3); ++ check_mutex_state(mutex_args.mutex, 3, 123); ++ ++ owner = 123; ++ ret = ioctl(mutex_args.mutex, NTSYNC_IOC_MUTEX_KILL, &owner); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_all(fd, 2, objs, 123, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EOWNERDEAD, errno); ++ check_sem_state(sem_args.sem, 1, 3); ++ check_mutex_state(mutex_args.mutex, 1, 123); ++ ++ objs[0] = sem_args.sem; ++ objs[1] = event_args.event; ++ ret = wait_all(fd, 2, objs, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ check_sem_state(sem_args.sem, 0, 3); ++ check_event_state(event_args.event, 1, 1); ++ ++ /* test waiting on the same object twice */ ++ objs[0] = objs[1] = sem_args.sem; ++ ret = wait_all(fd, 2, objs, 123, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(EINVAL, errno); ++ ++ close(sem_args.sem); ++ close(mutex_args.mutex); ++ close(event_args.event); ++ ++ close(fd); ++} ++ ++struct wake_args { ++ int fd; ++ int obj; ++}; ++ ++struct wait_args { ++ int fd; ++ unsigned long request; ++ struct ntsync_wait_args *args; ++ int ret; ++ int err; ++}; ++ ++static void *wait_thread(void *arg) ++{ ++ struct wait_args *args = arg; ++ ++ args->ret = ioctl(args->fd, args->request, args->args); ++ args->err = errno; ++ return NULL; ++} ++ ++static __u64 get_abs_timeout(unsigned int ms) ++{ ++ struct timespec timeout; ++ 
clock_gettime(CLOCK_MONOTONIC, &timeout); ++ return (timeout.tv_sec * 1000000000) + timeout.tv_nsec + (ms * 1000000); ++} ++ ++static int wait_for_thread(pthread_t thread, unsigned int ms) ++{ ++ struct timespec timeout; ++ ++ clock_gettime(CLOCK_REALTIME, &timeout); ++ timeout.tv_nsec += ms * 1000000; ++ timeout.tv_sec += (timeout.tv_nsec / 1000000000); ++ timeout.tv_nsec %= 1000000000; ++ return pthread_timedjoin_np(thread, NULL, &timeout); ++} ++ ++TEST(wake_any) ++{ ++ struct ntsync_event_args event_args = {0}; ++ struct ntsync_mutex_args mutex_args = {0}; ++ struct ntsync_wait_args wait_args = {0}; ++ struct ntsync_sem_args sem_args = {0}; ++ struct wait_args thread_args; ++ __u32 count, index, signaled; ++ int objs[2], fd, ret; ++ pthread_t thread; ++ ++ fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); ++ ASSERT_LE(0, fd); ++ ++ sem_args.count = 0; ++ sem_args.max = 3; ++ sem_args.sem = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, sem_args.sem); ++ ++ mutex_args.owner = 123; ++ mutex_args.count = 1; ++ mutex_args.mutex = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, mutex_args.mutex); ++ ++ objs[0] = sem_args.sem; ++ objs[1] = mutex_args.mutex; ++ ++ /* test waking the semaphore */ ++ ++ wait_args.timeout = get_abs_timeout(1000); ++ wait_args.objs = (uintptr_t)objs; ++ wait_args.count = 2; ++ wait_args.owner = 456; ++ wait_args.index = 0xdeadbeef; ++ thread_args.fd = fd; ++ thread_args.args = &wait_args; ++ thread_args.request = NTSYNC_IOC_WAIT_ANY; ++ ret = pthread_create(&thread, NULL, wait_thread, &thread_args); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(ETIMEDOUT, ret); ++ ++ count = 1; ++ ret = post_sem(sem_args.sem, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, count); ++ check_sem_state(sem_args.sem, 0, 3); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(0, ret); ++ 
EXPECT_EQ(0, thread_args.ret); ++ EXPECT_EQ(0, wait_args.index); ++ ++ /* test waking the mutex */ ++ ++ /* first grab it again for owner 123 */ ++ ret = wait_any(fd, 1, &mutex_args.mutex, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ ++ wait_args.timeout = get_abs_timeout(1000); ++ wait_args.owner = 456; ++ ret = pthread_create(&thread, NULL, wait_thread, &thread_args); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(ETIMEDOUT, ret); ++ ++ ret = unlock_mutex(mutex_args.mutex, 123, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(2, count); ++ ++ ret = pthread_tryjoin_np(thread, NULL); ++ EXPECT_EQ(EBUSY, ret); ++ ++ ret = unlock_mutex(mutex_args.mutex, 123, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, mutex_args.count); ++ check_mutex_state(mutex_args.mutex, 1, 456); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, thread_args.ret); ++ EXPECT_EQ(1, wait_args.index); ++ ++ /* test waking events */ ++ ++ event_args.manual = false; ++ event_args.signaled = false; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); ++ EXPECT_EQ(0, ret); ++ ++ objs[1] = event_args.event; ++ wait_args.timeout = get_abs_timeout(1000); ++ ret = pthread_create(&thread, NULL, wait_thread, &thread_args); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(ETIMEDOUT, ret); ++ ++ ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, signaled); ++ check_event_state(event_args.event, 0, 0); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, thread_args.ret); ++ EXPECT_EQ(1, wait_args.index); ++ ++ wait_args.timeout = get_abs_timeout(1000); ++ ret = pthread_create(&thread, NULL, wait_thread, &thread_args); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(ETIMEDOUT, ret); ++ ++ ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_PULSE, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, signaled); 
++ check_event_state(event_args.event, 0, 0); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, thread_args.ret); ++ EXPECT_EQ(1, wait_args.index); ++ ++ close(event_args.event); ++ ++ event_args.manual = true; ++ event_args.signaled = false; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); ++ EXPECT_EQ(0, ret); ++ ++ objs[1] = event_args.event; ++ wait_args.timeout = get_abs_timeout(1000); ++ ret = pthread_create(&thread, NULL, wait_thread, &thread_args); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(ETIMEDOUT, ret); ++ ++ ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, signaled); ++ check_event_state(event_args.event, 1, 1); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, thread_args.ret); ++ EXPECT_EQ(1, wait_args.index); ++ ++ ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_RESET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, signaled); ++ ++ wait_args.timeout = get_abs_timeout(1000); ++ ret = pthread_create(&thread, NULL, wait_thread, &thread_args); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(ETIMEDOUT, ret); ++ ++ ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_PULSE, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, signaled); ++ check_event_state(event_args.event, 0, 1); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, thread_args.ret); ++ EXPECT_EQ(1, wait_args.index); ++ ++ close(event_args.event); ++ ++ /* delete an object while it's being waited on */ ++ ++ wait_args.timeout = get_abs_timeout(200); ++ wait_args.owner = 123; ++ objs[1] = mutex_args.mutex; ++ ret = pthread_create(&thread, NULL, wait_thread, &thread_args); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(ETIMEDOUT, ret); ++ ++ close(sem_args.sem); ++ close(mutex_args.mutex); ++ ++ ret = wait_for_thread(thread, 200); ++ EXPECT_EQ(0, ret); ++ 
EXPECT_EQ(-1, thread_args.ret); ++ EXPECT_EQ(ETIMEDOUT, thread_args.err); ++ ++ close(fd); ++} ++ ++TEST(wake_all) ++{ ++ struct ntsync_event_args manual_event_args = {0}; ++ struct ntsync_event_args auto_event_args = {0}; ++ struct ntsync_mutex_args mutex_args = {0}; ++ struct ntsync_wait_args wait_args = {0}; ++ struct ntsync_sem_args sem_args = {0}; ++ struct wait_args thread_args; ++ __u32 count, index, signaled; ++ int objs[4], fd, ret; ++ pthread_t thread; ++ ++ fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); ++ ASSERT_LE(0, fd); ++ ++ sem_args.count = 0; ++ sem_args.max = 3; ++ sem_args.sem = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, sem_args.sem); ++ ++ mutex_args.owner = 123; ++ mutex_args.count = 1; ++ mutex_args.mutex = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, mutex_args.mutex); ++ ++ manual_event_args.manual = true; ++ manual_event_args.signaled = true; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &manual_event_args); ++ EXPECT_EQ(0, ret); ++ ++ auto_event_args.manual = false; ++ auto_event_args.signaled = true; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &auto_event_args); ++ EXPECT_EQ(0, ret); ++ ++ objs[0] = sem_args.sem; ++ objs[1] = mutex_args.mutex; ++ objs[2] = manual_event_args.event; ++ objs[3] = auto_event_args.event; ++ ++ wait_args.timeout = get_abs_timeout(1000); ++ wait_args.objs = (uintptr_t)objs; ++ wait_args.count = 4; ++ wait_args.owner = 456; ++ thread_args.fd = fd; ++ thread_args.args = &wait_args; ++ thread_args.request = NTSYNC_IOC_WAIT_ALL; ++ ret = pthread_create(&thread, NULL, wait_thread, &thread_args); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(ETIMEDOUT, ret); ++ ++ count = 1; ++ ret = post_sem(sem_args.sem, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, count); ++ ++ ret = pthread_tryjoin_np(thread, NULL); ++ EXPECT_EQ(EBUSY, ret); ++ ++ 
check_sem_state(sem_args.sem, 1, 3); ++ ++ ret = wait_any(fd, 1, &sem_args.sem, 123, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ ++ ret = unlock_mutex(mutex_args.mutex, 123, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, count); ++ ++ ret = pthread_tryjoin_np(thread, NULL); ++ EXPECT_EQ(EBUSY, ret); ++ ++ check_mutex_state(mutex_args.mutex, 0, 0); ++ ++ ret = ioctl(manual_event_args.event, NTSYNC_IOC_EVENT_RESET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, signaled); ++ ++ count = 2; ++ ret = post_sem(sem_args.sem, &count); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, count); ++ check_sem_state(sem_args.sem, 2, 3); ++ ++ ret = ioctl(auto_event_args.event, NTSYNC_IOC_EVENT_RESET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, signaled); ++ ++ ret = ioctl(manual_event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, signaled); ++ ++ ret = ioctl(auto_event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, signaled); ++ ++ check_sem_state(sem_args.sem, 1, 3); ++ check_mutex_state(mutex_args.mutex, 1, 456); ++ check_event_state(manual_event_args.event, 1, 1); ++ check_event_state(auto_event_args.event, 0, 0); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, thread_args.ret); ++ ++ /* delete an object while it's being waited on */ ++ ++ wait_args.timeout = get_abs_timeout(200); ++ wait_args.owner = 123; ++ ret = pthread_create(&thread, NULL, wait_thread, &thread_args); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(ETIMEDOUT, ret); ++ ++ close(sem_args.sem); ++ close(mutex_args.mutex); ++ close(manual_event_args.event); ++ close(auto_event_args.event); ++ ++ ret = wait_for_thread(thread, 200); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(-1, thread_args.ret); ++ EXPECT_EQ(ETIMEDOUT, thread_args.err); ++ ++ close(fd); ++} ++ ++TEST(alert_any) ++{ ++ struct ntsync_event_args event_args = {0}; ++ struct ntsync_wait_args wait_args = {0}; ++ 
struct ntsync_sem_args sem_args = {0}; ++ __u32 index, count, signaled; ++ struct wait_args thread_args; ++ int objs[2], fd, ret; ++ pthread_t thread; ++ ++ fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); ++ ASSERT_LE(0, fd); ++ ++ sem_args.count = 0; ++ sem_args.max = 2; ++ sem_args.sem = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, sem_args.sem); ++ objs[0] = sem_args.sem; ++ ++ sem_args.count = 1; ++ sem_args.max = 2; ++ sem_args.sem = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, sem_args.sem); ++ objs[1] = sem_args.sem; ++ ++ event_args.manual = true; ++ event_args.signaled = true; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_any_alert(fd, 0, NULL, 123, event_args.event, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ ++ ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_RESET, &signaled); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_any_alert(fd, 0, NULL, 123, event_args.event, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(ETIMEDOUT, errno); ++ ++ ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(1, index); ++ ++ ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(2, index); ++ ++ /* test wakeup via alert */ ++ ++ ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_RESET, &signaled); ++ EXPECT_EQ(0, ret); ++ ++ wait_args.timeout = get_abs_timeout(1000); ++ wait_args.objs = (uintptr_t)objs; ++ wait_args.count = 2; ++ wait_args.owner = 123; ++ wait_args.index = 0xdeadbeef; ++ wait_args.alert = event_args.event; ++ thread_args.fd = fd; ++ thread_args.args = &wait_args; ++ thread_args.request = NTSYNC_IOC_WAIT_ANY; ++ ret = pthread_create(&thread, NULL, wait_thread, &thread_args); ++ EXPECT_EQ(0, 
ret); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(ETIMEDOUT, ret); ++ ++ ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, thread_args.ret); ++ EXPECT_EQ(2, wait_args.index); ++ ++ close(event_args.event); ++ ++ /* test with an auto-reset event */ ++ ++ event_args.manual = false; ++ event_args.signaled = true; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); ++ EXPECT_EQ(0, ret); ++ ++ count = 1; ++ ret = post_sem(objs[0], &count); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ ++ ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(2, index); ++ ++ ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(ETIMEDOUT, errno); ++ ++ close(event_args.event); ++ ++ close(objs[0]); ++ close(objs[1]); ++ ++ close(fd); ++} ++ ++TEST(alert_all) ++{ ++ struct ntsync_event_args event_args = {0}; ++ struct ntsync_wait_args wait_args = {0}; ++ struct ntsync_sem_args sem_args = {0}; ++ struct wait_args thread_args; ++ __u32 index, count, signaled; ++ int objs[2], fd, ret; ++ pthread_t thread; ++ ++ fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); ++ ASSERT_LE(0, fd); ++ ++ sem_args.count = 2; ++ sem_args.max = 2; ++ sem_args.sem = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, sem_args.sem); ++ objs[0] = sem_args.sem; ++ ++ sem_args.count = 1; ++ sem_args.max = 2; ++ sem_args.sem = 0xdeadbeef; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); ++ EXPECT_EQ(0, ret); ++ EXPECT_NE(0xdeadbeef, sem_args.sem); ++ objs[1] = sem_args.sem; ++ ++ event_args.manual = true; ++ event_args.signaled = true; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); ++ EXPECT_EQ(0, ret); ++ ++ ret = 
wait_all_alert(fd, 2, objs, 123, event_args.event, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ ++ ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(2, index); ++ ++ /* test wakeup via alert */ ++ ++ ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_RESET, &signaled); ++ EXPECT_EQ(0, ret); ++ ++ wait_args.timeout = get_abs_timeout(1000); ++ wait_args.objs = (uintptr_t)objs; ++ wait_args.count = 2; ++ wait_args.owner = 123; ++ wait_args.index = 0xdeadbeef; ++ wait_args.alert = event_args.event; ++ thread_args.fd = fd; ++ thread_args.args = &wait_args; ++ thread_args.request = NTSYNC_IOC_WAIT_ALL; ++ ret = pthread_create(&thread, NULL, wait_thread, &thread_args); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(ETIMEDOUT, ret); ++ ++ ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_for_thread(thread, 100); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, thread_args.ret); ++ EXPECT_EQ(2, wait_args.index); ++ ++ close(event_args.event); ++ ++ /* test with an auto-reset event */ ++ ++ event_args.manual = false; ++ event_args.signaled = true; ++ ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); ++ EXPECT_EQ(0, ret); ++ ++ count = 2; ++ ret = post_sem(objs[1], &count); ++ EXPECT_EQ(0, ret); ++ ++ ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(0, index); ++ ++ ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); ++ EXPECT_EQ(0, ret); ++ EXPECT_EQ(2, index); ++ ++ ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); ++ EXPECT_EQ(-1, ret); ++ EXPECT_EQ(ETIMEDOUT, errno); ++ ++ close(event_args.event); ++ ++ close(objs[0]); ++ close(objs[1]); ++ ++ close(fd); ++} ++ ++#define STRESS_LOOPS 10000 ++#define STRESS_THREADS 4 ++ ++static unsigned int stress_counter; ++static int stress_device, stress_start_event, stress_mutex; ++ ++static void *stress_thread(void 
*arg) ++{ ++ struct ntsync_wait_args wait_args = {0}; ++ __u32 index, count, i; ++ int ret; ++ ++ wait_args.timeout = UINT64_MAX; ++ wait_args.count = 1; ++ wait_args.objs = (uintptr_t)&stress_start_event; ++ wait_args.owner = gettid(); ++ wait_args.index = 0xdeadbeef; ++ ++ ioctl(stress_device, NTSYNC_IOC_WAIT_ANY, &wait_args); ++ ++ wait_args.objs = (uintptr_t)&stress_mutex; ++ ++ for (i = 0; i < STRESS_LOOPS; ++i) { ++ ioctl(stress_device, NTSYNC_IOC_WAIT_ANY, &wait_args); ++ ++ ++stress_counter; ++ ++ unlock_mutex(stress_mutex, wait_args.owner, &count); ++ } ++ ++ return NULL; ++} ++ ++TEST(stress_wait) ++{ ++ struct ntsync_event_args event_args; ++ struct ntsync_mutex_args mutex_args; ++ pthread_t threads[STRESS_THREADS]; ++ __u32 signaled, i; ++ int ret; ++ ++ stress_device = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); ++ ASSERT_LE(0, stress_device); ++ ++ mutex_args.owner = 0; ++ mutex_args.count = 0; ++ ret = ioctl(stress_device, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); ++ EXPECT_EQ(0, ret); ++ stress_mutex = mutex_args.mutex; ++ ++ event_args.manual = 1; ++ event_args.signaled = 0; ++ ret = ioctl(stress_device, NTSYNC_IOC_CREATE_EVENT, &event_args); ++ EXPECT_EQ(0, ret); ++ stress_start_event = event_args.event; ++ ++ for (i = 0; i < STRESS_THREADS; ++i) ++ pthread_create(&threads[i], NULL, stress_thread, NULL); ++ ++ ret = ioctl(stress_start_event, NTSYNC_IOC_EVENT_SET, &signaled); ++ EXPECT_EQ(0, ret); ++ ++ for (i = 0; i < STRESS_THREADS; ++i) { ++ ret = pthread_join(threads[i], NULL); ++ EXPECT_EQ(0, ret); ++ } ++ ++ EXPECT_EQ(STRESS_LOOPS * STRESS_THREADS, stress_counter); ++ ++ close(stress_start_event); ++ close(stress_mutex); ++ close(stress_device); ++} ++ ++TEST_HARNESS_MAIN +-- +2.46.0 + +From 2aee49acbdd1e24099f9458a8eaed2ffe8afc683 Mon Sep 17 00:00:00 2001 +From: Peter Jung +Date: Wed, 14 Aug 2024 16:20:05 +0200 +Subject: [PATCH 10/12] perf-per-core + +Signed-off-by: Peter Jung +--- + Documentation/arch/x86/topology.rst | 4 + + 
arch/x86/events/rapl.c | 418 ++++++++++++++++++-------- + arch/x86/include/asm/processor.h | 1 + + arch/x86/include/asm/topology.h | 1 + + arch/x86/kernel/cpu/debugfs.c | 1 + + arch/x86/kernel/cpu/topology_common.c | 1 + + 6 files changed, 305 insertions(+), 121 deletions(-) + +diff --git a/Documentation/arch/x86/topology.rst b/Documentation/arch/x86/topology.rst +index 7352ab89a55a..c12837e61bda 100644 +--- a/Documentation/arch/x86/topology.rst ++++ b/Documentation/arch/x86/topology.rst +@@ -135,6 +135,10 @@ Thread-related topology information in the kernel: + The ID of the core to which a thread belongs. It is also printed in /proc/cpuinfo + "core_id." + ++ - topology_logical_core_id(); ++ ++ The logical core ID to which a thread belongs. ++ + + + System topology examples +diff --git a/arch/x86/events/rapl.c b/arch/x86/events/rapl.c +index 0c5e7a7c43ac..cd808b699ccc 100644 +--- a/arch/x86/events/rapl.c ++++ b/arch/x86/events/rapl.c +@@ -39,6 +39,10 @@ + * event: rapl_energy_psys + * perf code: 0x5 + * ++ * per_core counter: consumption of a single physical core ++ * event: rapl_energy_per_core (power_per_core PMU) ++ * perf code: 0x1 ++ * + * We manage those counters as free running (read-only). They may be + * use simultaneously by other tools, such as turbostat. 
+ * +@@ -70,18 +74,25 @@ MODULE_LICENSE("GPL"); + /* + * RAPL energy status counters + */ +-enum perf_rapl_events { ++enum perf_rapl_pkg_events { + PERF_RAPL_PP0 = 0, /* all cores */ + PERF_RAPL_PKG, /* entire package */ + PERF_RAPL_RAM, /* DRAM */ + PERF_RAPL_PP1, /* gpu */ + PERF_RAPL_PSYS, /* psys */ + +- PERF_RAPL_MAX, +- NR_RAPL_DOMAINS = PERF_RAPL_MAX, ++ PERF_RAPL_PKG_EVENTS_MAX, ++ NR_RAPL_PKG_DOMAINS = PERF_RAPL_PKG_EVENTS_MAX, ++}; ++ ++enum perf_rapl_core_events { ++ PERF_RAPL_PER_CORE = 0, /* per-core */ ++ ++ PERF_RAPL_CORE_EVENTS_MAX, ++ NR_RAPL_CORE_DOMAINS = PERF_RAPL_CORE_EVENTS_MAX, + }; + +-static const char *const rapl_domain_names[NR_RAPL_DOMAINS] __initconst = { ++static const char *const rapl_pkg_domain_names[NR_RAPL_PKG_DOMAINS] __initconst = { + "pp0-core", + "package", + "dram", +@@ -89,6 +100,10 @@ static const char *const rapl_domain_names[NR_RAPL_DOMAINS] __initconst = { + "psys", + }; + ++static const char *const rapl_core_domain_names[NR_RAPL_CORE_DOMAINS] __initconst = { ++ "per-core", ++}; ++ + /* + * event code: LSB 8 bits, passed in attr->config + * any other bit is reserved +@@ -103,6 +118,10 @@ static struct perf_pmu_events_attr event_attr_##v = { \ + .event_str = str, \ + }; + ++#define rapl_pmu_is_pkg_scope() \ ++ (boot_cpu_data.x86_vendor == X86_VENDOR_AMD || \ ++ boot_cpu_data.x86_vendor == X86_VENDOR_HYGON) ++ + struct rapl_pmu { + raw_spinlock_t lock; + int n_active; +@@ -115,8 +134,9 @@ struct rapl_pmu { + + struct rapl_pmus { + struct pmu pmu; ++ cpumask_t cpumask; + unsigned int nr_rapl_pmu; +- struct rapl_pmu *pmus[] __counted_by(nr_rapl_pmu); ++ struct rapl_pmu *rapl_pmu[] __counted_by(nr_rapl_pmu); + }; + + enum rapl_unit_quirk { +@@ -126,29 +146,45 @@ enum rapl_unit_quirk { + }; + + struct rapl_model { +- struct perf_msr *rapl_msrs; +- unsigned long events; ++ struct perf_msr *rapl_pkg_msrs; ++ struct perf_msr *rapl_core_msrs; ++ unsigned long pkg_events; ++ unsigned long core_events; + unsigned int msr_power_unit; 
+ enum rapl_unit_quirk unit_quirk; + }; + + /* 1/2^hw_unit Joule */ +-static int rapl_hw_unit[NR_RAPL_DOMAINS] __read_mostly; +-static struct rapl_pmus *rapl_pmus; +-static cpumask_t rapl_cpu_mask; +-static unsigned int rapl_cntr_mask; ++static int rapl_hw_unit[NR_RAPL_PKG_DOMAINS] __read_mostly; ++static struct rapl_pmus *rapl_pmus_pkg; ++static struct rapl_pmus *rapl_pmus_core; ++static unsigned int rapl_pkg_cntr_mask; ++static unsigned int rapl_core_cntr_mask; + static u64 rapl_timer_ms; +-static struct perf_msr *rapl_msrs; ++static struct rapl_model *rapl_model; ++ ++static inline unsigned int get_rapl_pmu_idx(int cpu) ++{ ++ return rapl_pmu_is_pkg_scope() ? topology_logical_package_id(cpu) : ++ topology_logical_die_id(cpu); ++} ++ ++static inline const struct cpumask *get_rapl_pmu_cpumask(int cpu) ++{ ++ return rapl_pmu_is_pkg_scope() ? topology_core_cpumask(cpu) : ++ topology_die_cpumask(cpu); ++} + + static inline struct rapl_pmu *cpu_to_rapl_pmu(unsigned int cpu) + { +- unsigned int rapl_pmu_idx = topology_logical_die_id(cpu); ++ unsigned int rapl_pmu_idx = get_rapl_pmu_idx(cpu); + + /* + * The unsigned check also catches the '-1' return value for non + * existent mappings in the topology map. + */ +- return rapl_pmu_idx < rapl_pmus->nr_rapl_pmu ? rapl_pmus->pmus[rapl_pmu_idx] : NULL; ++ return rapl_pmu_idx < rapl_pmus_pkg->nr_rapl_pmu ? 
++ rapl_pmus_pkg->rapl_pmu[rapl_pmu_idx] : NULL; + } + + static inline u64 rapl_read_counter(struct perf_event *event) +@@ -160,7 +196,7 @@ static inline u64 rapl_read_counter(struct perf_event *event) + + static inline u64 rapl_scale(u64 v, int cfg) + { +- if (cfg > NR_RAPL_DOMAINS) { ++ if (cfg > NR_RAPL_PKG_DOMAINS) { + pr_warn("Invalid domain %d, failed to scale data\n", cfg); + return v; + } +@@ -212,34 +248,34 @@ static void rapl_start_hrtimer(struct rapl_pmu *pmu) + + static enum hrtimer_restart rapl_hrtimer_handle(struct hrtimer *hrtimer) + { +- struct rapl_pmu *pmu = container_of(hrtimer, struct rapl_pmu, hrtimer); ++ struct rapl_pmu *rapl_pmu = container_of(hrtimer, struct rapl_pmu, hrtimer); + struct perf_event *event; + unsigned long flags; + +- if (!pmu->n_active) ++ if (!rapl_pmu->n_active) + return HRTIMER_NORESTART; + +- raw_spin_lock_irqsave(&pmu->lock, flags); ++ raw_spin_lock_irqsave(&rapl_pmu->lock, flags); + +- list_for_each_entry(event, &pmu->active_list, active_entry) ++ list_for_each_entry(event, &rapl_pmu->active_list, active_entry) + rapl_event_update(event); + +- raw_spin_unlock_irqrestore(&pmu->lock, flags); ++ raw_spin_unlock_irqrestore(&rapl_pmu->lock, flags); + +- hrtimer_forward_now(hrtimer, pmu->timer_interval); ++ hrtimer_forward_now(hrtimer, rapl_pmu->timer_interval); + + return HRTIMER_RESTART; + } + +-static void rapl_hrtimer_init(struct rapl_pmu *pmu) ++static void rapl_hrtimer_init(struct rapl_pmu *rapl_pmu) + { +- struct hrtimer *hr = &pmu->hrtimer; ++ struct hrtimer *hr = &rapl_pmu->hrtimer; + + hrtimer_init(hr, CLOCK_MONOTONIC, HRTIMER_MODE_REL); + hr->function = rapl_hrtimer_handle; + } + +-static void __rapl_pmu_event_start(struct rapl_pmu *pmu, ++static void __rapl_pmu_event_start(struct rapl_pmu *rapl_pmu, + struct perf_event *event) + { + if (WARN_ON_ONCE(!(event->hw.state & PERF_HES_STOPPED))) +@@ -247,39 +283,39 @@ static void __rapl_pmu_event_start(struct rapl_pmu *pmu, + + event->hw.state = 0; + +- 
list_add_tail(&event->active_entry, &pmu->active_list); ++ list_add_tail(&event->active_entry, &rapl_pmu->active_list); + + local64_set(&event->hw.prev_count, rapl_read_counter(event)); + +- pmu->n_active++; +- if (pmu->n_active == 1) +- rapl_start_hrtimer(pmu); ++ rapl_pmu->n_active++; ++ if (rapl_pmu->n_active == 1) ++ rapl_start_hrtimer(rapl_pmu); + } + + static void rapl_pmu_event_start(struct perf_event *event, int mode) + { +- struct rapl_pmu *pmu = event->pmu_private; ++ struct rapl_pmu *rapl_pmu = event->pmu_private; + unsigned long flags; + +- raw_spin_lock_irqsave(&pmu->lock, flags); +- __rapl_pmu_event_start(pmu, event); +- raw_spin_unlock_irqrestore(&pmu->lock, flags); ++ raw_spin_lock_irqsave(&rapl_pmu->lock, flags); ++ __rapl_pmu_event_start(rapl_pmu, event); ++ raw_spin_unlock_irqrestore(&rapl_pmu->lock, flags); + } + + static void rapl_pmu_event_stop(struct perf_event *event, int mode) + { +- struct rapl_pmu *pmu = event->pmu_private; ++ struct rapl_pmu *rapl_pmu = event->pmu_private; + struct hw_perf_event *hwc = &event->hw; + unsigned long flags; + +- raw_spin_lock_irqsave(&pmu->lock, flags); ++ raw_spin_lock_irqsave(&rapl_pmu->lock, flags); + + /* mark event as deactivated and stopped */ + if (!(hwc->state & PERF_HES_STOPPED)) { +- WARN_ON_ONCE(pmu->n_active <= 0); +- pmu->n_active--; +- if (pmu->n_active == 0) +- hrtimer_cancel(&pmu->hrtimer); ++ WARN_ON_ONCE(rapl_pmu->n_active <= 0); ++ rapl_pmu->n_active--; ++ if (rapl_pmu->n_active == 0) ++ hrtimer_cancel(&rapl_pmu->hrtimer); + + list_del(&event->active_entry); + +@@ -297,23 +333,23 @@ static void rapl_pmu_event_stop(struct perf_event *event, int mode) + hwc->state |= PERF_HES_UPTODATE; + } + +- raw_spin_unlock_irqrestore(&pmu->lock, flags); ++ raw_spin_unlock_irqrestore(&rapl_pmu->lock, flags); + } + + static int rapl_pmu_event_add(struct perf_event *event, int mode) + { +- struct rapl_pmu *pmu = event->pmu_private; ++ struct rapl_pmu *rapl_pmu = event->pmu_private; + struct hw_perf_event 
*hwc = &event->hw; + unsigned long flags; + +- raw_spin_lock_irqsave(&pmu->lock, flags); ++ raw_spin_lock_irqsave(&rapl_pmu->lock, flags); + + hwc->state = PERF_HES_UPTODATE | PERF_HES_STOPPED; + + if (mode & PERF_EF_START) +- __rapl_pmu_event_start(pmu, event); ++ __rapl_pmu_event_start(rapl_pmu, event); + +- raw_spin_unlock_irqrestore(&pmu->lock, flags); ++ raw_spin_unlock_irqrestore(&rapl_pmu->lock, flags); + + return 0; + } +@@ -327,10 +363,14 @@ static int rapl_pmu_event_init(struct perf_event *event) + { + u64 cfg = event->attr.config & RAPL_EVENT_MASK; + int bit, ret = 0; +- struct rapl_pmu *pmu; ++ struct rapl_pmu *rapl_pmu; ++ struct rapl_pmus *curr_rapl_pmus; + + /* only look at RAPL events */ +- if (event->attr.type != rapl_pmus->pmu.type) ++ if (event->attr.type == rapl_pmus_pkg->pmu.type || ++ (rapl_pmus_core && event->attr.type == rapl_pmus_core->pmu.type)) ++ curr_rapl_pmus = container_of(event->pmu, struct rapl_pmus, pmu); ++ else + return -ENOENT; + + /* check only supported bits are set */ +@@ -340,16 +380,18 @@ static int rapl_pmu_event_init(struct perf_event *event) + if (event->cpu < 0) + return -EINVAL; + +- event->event_caps |= PERF_EV_CAP_READ_ACTIVE_PKG; ++ if (curr_rapl_pmus == rapl_pmus_pkg) ++ event->event_caps |= PERF_EV_CAP_READ_ACTIVE_PKG; + +- if (!cfg || cfg >= NR_RAPL_DOMAINS + 1) ++ if (!cfg || cfg >= NR_RAPL_PKG_DOMAINS + 1) + return -EINVAL; + +- cfg = array_index_nospec((long)cfg, NR_RAPL_DOMAINS + 1); ++ cfg = array_index_nospec((long)cfg, NR_RAPL_PKG_DOMAINS + 1); + bit = cfg - 1; + + /* check event supported */ +- if (!(rapl_cntr_mask & (1 << bit))) ++ if (!(rapl_pkg_cntr_mask & (1 << bit)) && ++ !(rapl_core_cntr_mask & (1 << bit))) + return -EINVAL; + + /* unsupported modes and filters */ +@@ -357,12 +399,18 @@ static int rapl_pmu_event_init(struct perf_event *event) + return -EINVAL; + + /* must be done before validate_group */ +- pmu = cpu_to_rapl_pmu(event->cpu); +- if (!pmu) ++ if (curr_rapl_pmus == rapl_pmus_core) { ++ 
rapl_pmu = curr_rapl_pmus->rapl_pmu[topology_logical_core_id(event->cpu)]; ++ event->hw.event_base = rapl_model->rapl_core_msrs[bit].msr; ++ } else { ++ rapl_pmu = curr_rapl_pmus->rapl_pmu[get_rapl_pmu_idx(event->cpu)]; ++ event->hw.event_base = rapl_model->rapl_pkg_msrs[bit].msr; ++ } ++ ++ if (!rapl_pmu) + return -EINVAL; +- event->cpu = pmu->cpu; +- event->pmu_private = pmu; +- event->hw.event_base = rapl_msrs[bit].msr; ++ event->cpu = rapl_pmu->cpu; ++ event->pmu_private = rapl_pmu; + event->hw.config = cfg; + event->hw.idx = bit; + +@@ -377,7 +425,7 @@ static void rapl_pmu_event_read(struct perf_event *event) + static ssize_t rapl_get_attr_cpumask(struct device *dev, + struct device_attribute *attr, char *buf) + { +- return cpumap_print_to_pagebuf(true, buf, &rapl_cpu_mask); ++ return cpumap_print_to_pagebuf(true, buf, &rapl_pmus_pkg->cpumask); + } + + static DEVICE_ATTR(cpumask, S_IRUGO, rapl_get_attr_cpumask, NULL); +@@ -391,17 +439,38 @@ static struct attribute_group rapl_pmu_attr_group = { + .attrs = rapl_pmu_attrs, + }; + ++static ssize_t rapl_get_attr_per_core_cpumask(struct device *dev, ++ struct device_attribute *attr, char *buf) ++{ ++ return cpumap_print_to_pagebuf(true, buf, &rapl_pmus_core->cpumask); ++} ++ ++static struct device_attribute dev_attr_per_core_cpumask = __ATTR(cpumask, 0444, ++ rapl_get_attr_per_core_cpumask, ++ NULL); ++ ++static struct attribute *rapl_pmu_per_core_attrs[] = { ++ &dev_attr_per_core_cpumask.attr, ++ NULL, ++}; ++ ++static struct attribute_group rapl_pmu_per_core_attr_group = { ++ .attrs = rapl_pmu_per_core_attrs, ++}; ++ + RAPL_EVENT_ATTR_STR(energy-cores, rapl_cores, "event=0x01"); + RAPL_EVENT_ATTR_STR(energy-pkg , rapl_pkg, "event=0x02"); + RAPL_EVENT_ATTR_STR(energy-ram , rapl_ram, "event=0x03"); + RAPL_EVENT_ATTR_STR(energy-gpu , rapl_gpu, "event=0x04"); + RAPL_EVENT_ATTR_STR(energy-psys, rapl_psys, "event=0x05"); ++RAPL_EVENT_ATTR_STR(energy-per-core, rapl_per_core, "event=0x01"); + + 
RAPL_EVENT_ATTR_STR(energy-cores.unit, rapl_cores_unit, "Joules"); + RAPL_EVENT_ATTR_STR(energy-pkg.unit , rapl_pkg_unit, "Joules"); + RAPL_EVENT_ATTR_STR(energy-ram.unit , rapl_ram_unit, "Joules"); + RAPL_EVENT_ATTR_STR(energy-gpu.unit , rapl_gpu_unit, "Joules"); + RAPL_EVENT_ATTR_STR(energy-psys.unit, rapl_psys_unit, "Joules"); ++RAPL_EVENT_ATTR_STR(energy-per-core.unit, rapl_per_core_unit, "Joules"); + + /* + * we compute in 0.23 nJ increments regardless of MSR +@@ -411,6 +480,7 @@ RAPL_EVENT_ATTR_STR(energy-pkg.scale, rapl_pkg_scale, "2.3283064365386962890 + RAPL_EVENT_ATTR_STR(energy-ram.scale, rapl_ram_scale, "2.3283064365386962890625e-10"); + RAPL_EVENT_ATTR_STR(energy-gpu.scale, rapl_gpu_scale, "2.3283064365386962890625e-10"); + RAPL_EVENT_ATTR_STR(energy-psys.scale, rapl_psys_scale, "2.3283064365386962890625e-10"); ++RAPL_EVENT_ATTR_STR(energy-per-core.scale, rapl_per_core_scale, "2.3283064365386962890625e-10"); + + /* + * There are no default events, but we need to create +@@ -444,6 +514,13 @@ static const struct attribute_group *rapl_attr_groups[] = { + NULL, + }; + ++static const struct attribute_group *rapl_per_core_attr_groups[] = { ++ &rapl_pmu_per_core_attr_group, ++ &rapl_pmu_format_group, ++ &rapl_pmu_events_group, ++ NULL, ++}; ++ + static struct attribute *rapl_events_cores[] = { + EVENT_PTR(rapl_cores), + EVENT_PTR(rapl_cores_unit), +@@ -504,6 +581,18 @@ static struct attribute_group rapl_events_psys_group = { + .attrs = rapl_events_psys, + }; + ++static struct attribute *rapl_events_per_core[] = { ++ EVENT_PTR(rapl_per_core), ++ EVENT_PTR(rapl_per_core_unit), ++ EVENT_PTR(rapl_per_core_scale), ++ NULL, ++}; ++ ++static struct attribute_group rapl_events_per_core_group = { ++ .name = "events", ++ .attrs = rapl_events_per_core, ++}; ++ + static bool test_msr(int idx, void *data) + { + return test_bit(idx, (unsigned long *) data); +@@ -529,11 +618,11 @@ static struct perf_msr intel_rapl_spr_msrs[] = { + }; + + /* +- * Force to PERF_RAPL_MAX size 
due to: +- * - perf_msr_probe(PERF_RAPL_MAX) ++ * Force to PERF_RAPL_PKG_EVENTS_MAX size due to: ++ * - perf_msr_probe(PERF_RAPL_PKG_EVENTS_MAX) + * - want to use same event codes across both architectures + */ +-static struct perf_msr amd_rapl_msrs[] = { ++static struct perf_msr amd_rapl_pkg_msrs[] = { + [PERF_RAPL_PP0] = { 0, &rapl_events_cores_group, NULL, false, 0 }, + [PERF_RAPL_PKG] = { MSR_AMD_PKG_ENERGY_STATUS, &rapl_events_pkg_group, test_msr, false, RAPL_MSR_MASK }, + [PERF_RAPL_RAM] = { 0, &rapl_events_ram_group, NULL, false, 0 }, +@@ -541,72 +630,104 @@ static struct perf_msr amd_rapl_msrs[] = { + [PERF_RAPL_PSYS] = { 0, &rapl_events_psys_group, NULL, false, 0 }, + }; + +-static int rapl_cpu_offline(unsigned int cpu) ++static struct perf_msr amd_rapl_core_msrs[] = { ++ [PERF_RAPL_PER_CORE] = { MSR_AMD_CORE_ENERGY_STATUS, &rapl_events_per_core_group, ++ test_msr, false, RAPL_MSR_MASK }, ++}; ++ ++static int __rapl_cpu_offline(struct rapl_pmus *rapl_pmus, unsigned int rapl_pmu_idx, ++ const struct cpumask *event_cpumask, unsigned int cpu) + { +- struct rapl_pmu *pmu = cpu_to_rapl_pmu(cpu); ++ struct rapl_pmu *rapl_pmu = rapl_pmus->rapl_pmu[rapl_pmu_idx]; + int target; + + /* Check if exiting cpu is used for collecting rapl events */ +- if (!cpumask_test_and_clear_cpu(cpu, &rapl_cpu_mask)) ++ if (!cpumask_test_and_clear_cpu(cpu, &rapl_pmus->cpumask)) + return 0; + +- pmu->cpu = -1; ++ rapl_pmu->cpu = -1; + /* Find a new cpu to collect rapl events */ +- target = cpumask_any_but(topology_die_cpumask(cpu), cpu); ++ target = cpumask_any_but(event_cpumask, cpu); + + /* Migrate rapl events to the new target */ + if (target < nr_cpu_ids) { +- cpumask_set_cpu(target, &rapl_cpu_mask); +- pmu->cpu = target; +- perf_pmu_migrate_context(pmu->pmu, cpu, target); ++ cpumask_set_cpu(target, &rapl_pmus->cpumask); ++ rapl_pmu->cpu = target; ++ perf_pmu_migrate_context(rapl_pmu->pmu, cpu, target); + } + return 0; + } + +-static int rapl_cpu_online(unsigned int cpu) ++static 
int rapl_cpu_offline(unsigned int cpu) ++{ ++ int ret = __rapl_cpu_offline(rapl_pmus_pkg, get_rapl_pmu_idx(cpu), ++ get_rapl_pmu_cpumask(cpu), cpu); ++ ++ if (ret == 0 && rapl_model->core_events) ++ ret = __rapl_cpu_offline(rapl_pmus_core, topology_logical_core_id(cpu), ++ topology_sibling_cpumask(cpu), cpu); ++ ++ return ret; ++} ++ ++static int __rapl_cpu_online(struct rapl_pmus *rapl_pmus, unsigned int rapl_pmu_idx, ++ const struct cpumask *event_cpumask, unsigned int cpu) + { +- struct rapl_pmu *pmu = cpu_to_rapl_pmu(cpu); ++ struct rapl_pmu *rapl_pmu = rapl_pmus->rapl_pmu[rapl_pmu_idx]; + int target; + +- if (!pmu) { +- pmu = kzalloc_node(sizeof(*pmu), GFP_KERNEL, cpu_to_node(cpu)); +- if (!pmu) ++ if (!rapl_pmu) { ++ rapl_pmu = kzalloc_node(sizeof(*rapl_pmu), GFP_KERNEL, cpu_to_node(cpu)); ++ if (!rapl_pmu) + return -ENOMEM; + +- raw_spin_lock_init(&pmu->lock); +- INIT_LIST_HEAD(&pmu->active_list); +- pmu->pmu = &rapl_pmus->pmu; +- pmu->timer_interval = ms_to_ktime(rapl_timer_ms); +- rapl_hrtimer_init(pmu); ++ raw_spin_lock_init(&rapl_pmu->lock); ++ INIT_LIST_HEAD(&rapl_pmu->active_list); ++ rapl_pmu->pmu = &rapl_pmus->pmu; ++ rapl_pmu->timer_interval = ms_to_ktime(rapl_timer_ms); ++ rapl_hrtimer_init(rapl_pmu); + +- rapl_pmus->pmus[topology_logical_die_id(cpu)] = pmu; ++ rapl_pmus->rapl_pmu[rapl_pmu_idx] = rapl_pmu; + } + + /* + * Check if there is an online cpu in the package which collects rapl + * events already. 
+ */ +- target = cpumask_any_and(&rapl_cpu_mask, topology_die_cpumask(cpu)); ++ target = cpumask_any_and(&rapl_pmus->cpumask, event_cpumask); + if (target < nr_cpu_ids) + return 0; + +- cpumask_set_cpu(cpu, &rapl_cpu_mask); +- pmu->cpu = cpu; ++ cpumask_set_cpu(cpu, &rapl_pmus->cpumask); ++ rapl_pmu->cpu = cpu; + return 0; + } + +-static int rapl_check_hw_unit(struct rapl_model *rm) ++static int rapl_cpu_online(unsigned int cpu) ++{ ++ int ret = __rapl_cpu_online(rapl_pmus_pkg, get_rapl_pmu_idx(cpu), ++ get_rapl_pmu_cpumask(cpu), cpu); ++ ++ if (ret == 0 && rapl_model->core_events) ++ ret = __rapl_cpu_online(rapl_pmus_core, topology_logical_core_id(cpu), ++ topology_sibling_cpumask(cpu), cpu); ++ ++ return ret; ++} ++ ++ ++static int rapl_check_hw_unit(void) + { + u64 msr_rapl_power_unit_bits; + int i; + + /* protect rdmsrl() to handle virtualization */ +- if (rdmsrl_safe(rm->msr_power_unit, &msr_rapl_power_unit_bits)) ++ if (rdmsrl_safe(rapl_model->msr_power_unit, &msr_rapl_power_unit_bits)) + return -1; +- for (i = 0; i < NR_RAPL_DOMAINS; i++) ++ for (i = 0; i < NR_RAPL_PKG_DOMAINS; i++) + rapl_hw_unit[i] = (msr_rapl_power_unit_bits >> 8) & 0x1FULL; + +- switch (rm->unit_quirk) { ++ switch (rapl_model->unit_quirk) { + /* + * DRAM domain on HSW server and KNL has fixed energy unit which can be + * different than the unit from power unit MSR. 
See +@@ -645,22 +766,29 @@ static void __init rapl_advertise(void) + int i; + + pr_info("API unit is 2^-32 Joules, %d fixed counters, %llu ms ovfl timer\n", +- hweight32(rapl_cntr_mask), rapl_timer_ms); ++ hweight32(rapl_pkg_cntr_mask) + hweight32(rapl_core_cntr_mask), rapl_timer_ms); + +- for (i = 0; i < NR_RAPL_DOMAINS; i++) { +- if (rapl_cntr_mask & (1 << i)) { ++ for (i = 0; i < NR_RAPL_PKG_DOMAINS; i++) { ++ if (rapl_pkg_cntr_mask & (1 << i)) { + pr_info("hw unit of domain %s 2^-%d Joules\n", +- rapl_domain_names[i], rapl_hw_unit[i]); ++ rapl_pkg_domain_names[i], rapl_hw_unit[i]); ++ } ++ } ++ ++ for (i = 0; i < NR_RAPL_CORE_DOMAINS; i++) { ++ if (rapl_core_cntr_mask & (1 << i)) { ++ pr_info("hw unit of domain %s 2^-%d Joules\n", ++ rapl_core_domain_names[i], rapl_hw_unit[i]); + } + } + } + +-static void cleanup_rapl_pmus(void) ++static void cleanup_rapl_pmus(struct rapl_pmus *rapl_pmus) + { + int i; + + for (i = 0; i < rapl_pmus->nr_rapl_pmu; i++) +- kfree(rapl_pmus->pmus[i]); ++ kfree(rapl_pmus->rapl_pmu[i]); + kfree(rapl_pmus); + } + +@@ -673,11 +801,17 @@ static const struct attribute_group *rapl_attr_update[] = { + NULL, + }; + +-static int __init init_rapl_pmus(void) ++static const struct attribute_group *rapl_per_core_attr_update[] = { ++ &rapl_events_per_core_group, NULL, /* NULL-terminated like rapl_attr_update[] */ ++}; ++ ++static int __init init_rapl_pmus(struct rapl_pmus **rapl_pmus_ptr, int nr_rapl_pmu, ++ const struct attribute_group **rapl_attr_groups, ++ const struct attribute_group **rapl_attr_update) + { +- int nr_rapl_pmu = topology_max_packages() * topology_max_dies_per_package(); ++ struct rapl_pmus *rapl_pmus; + +- rapl_pmus = kzalloc(struct_size(rapl_pmus, pmus, nr_rapl_pmu), GFP_KERNEL); ++ rapl_pmus = kzalloc(struct_size(rapl_pmus, rapl_pmu, nr_rapl_pmu), GFP_KERNEL); + if (!rapl_pmus) + return -ENOMEM; + +@@ -693,75 +827,80 @@ static int __init init_rapl_pmus(void) + rapl_pmus->pmu.read = rapl_pmu_event_read; + rapl_pmus->pmu.module = THIS_MODULE; + rapl_pmus->pmu.capabilities = 
PERF_PMU_CAP_NO_EXCLUDE; ++ ++ *rapl_pmus_ptr = rapl_pmus; ++ + return 0; + } + + static struct rapl_model model_snb = { +- .events = BIT(PERF_RAPL_PP0) | ++ .pkg_events = BIT(PERF_RAPL_PP0) | + BIT(PERF_RAPL_PKG) | + BIT(PERF_RAPL_PP1), + .msr_power_unit = MSR_RAPL_POWER_UNIT, +- .rapl_msrs = intel_rapl_msrs, ++ .rapl_pkg_msrs = intel_rapl_msrs, + }; + + static struct rapl_model model_snbep = { +- .events = BIT(PERF_RAPL_PP0) | ++ .pkg_events = BIT(PERF_RAPL_PP0) | + BIT(PERF_RAPL_PKG) | + BIT(PERF_RAPL_RAM), + .msr_power_unit = MSR_RAPL_POWER_UNIT, +- .rapl_msrs = intel_rapl_msrs, ++ .rapl_pkg_msrs = intel_rapl_msrs, + }; + + static struct rapl_model model_hsw = { +- .events = BIT(PERF_RAPL_PP0) | ++ .pkg_events = BIT(PERF_RAPL_PP0) | + BIT(PERF_RAPL_PKG) | + BIT(PERF_RAPL_RAM) | + BIT(PERF_RAPL_PP1), + .msr_power_unit = MSR_RAPL_POWER_UNIT, +- .rapl_msrs = intel_rapl_msrs, ++ .rapl_pkg_msrs = intel_rapl_msrs, + }; + + static struct rapl_model model_hsx = { +- .events = BIT(PERF_RAPL_PP0) | ++ .pkg_events = BIT(PERF_RAPL_PP0) | + BIT(PERF_RAPL_PKG) | + BIT(PERF_RAPL_RAM), + .unit_quirk = RAPL_UNIT_QUIRK_INTEL_HSW, + .msr_power_unit = MSR_RAPL_POWER_UNIT, +- .rapl_msrs = intel_rapl_msrs, ++ .rapl_pkg_msrs = intel_rapl_msrs, + }; + + static struct rapl_model model_knl = { +- .events = BIT(PERF_RAPL_PKG) | ++ .pkg_events = BIT(PERF_RAPL_PKG) | + BIT(PERF_RAPL_RAM), + .unit_quirk = RAPL_UNIT_QUIRK_INTEL_HSW, + .msr_power_unit = MSR_RAPL_POWER_UNIT, +- .rapl_msrs = intel_rapl_msrs, ++ .rapl_pkg_msrs = intel_rapl_msrs, + }; + + static struct rapl_model model_skl = { +- .events = BIT(PERF_RAPL_PP0) | ++ .pkg_events = BIT(PERF_RAPL_PP0) | + BIT(PERF_RAPL_PKG) | + BIT(PERF_RAPL_RAM) | + BIT(PERF_RAPL_PP1) | + BIT(PERF_RAPL_PSYS), + .msr_power_unit = MSR_RAPL_POWER_UNIT, +- .rapl_msrs = intel_rapl_msrs, ++ .rapl_pkg_msrs = intel_rapl_msrs, + }; + + static struct rapl_model model_spr = { +- .events = BIT(PERF_RAPL_PP0) | ++ .pkg_events = BIT(PERF_RAPL_PP0) | + 
BIT(PERF_RAPL_PKG) | + BIT(PERF_RAPL_RAM) | + BIT(PERF_RAPL_PSYS), + .unit_quirk = RAPL_UNIT_QUIRK_INTEL_SPR, + .msr_power_unit = MSR_RAPL_POWER_UNIT, +- .rapl_msrs = intel_rapl_spr_msrs, ++ .rapl_pkg_msrs = intel_rapl_spr_msrs, + }; + + static struct rapl_model model_amd_hygon = { +- .events = BIT(PERF_RAPL_PKG), ++ .pkg_events = BIT(PERF_RAPL_PKG), ++ .core_events = BIT(PERF_RAPL_PER_CORE), + .msr_power_unit = MSR_AMD_RAPL_POWER_UNIT, +- .rapl_msrs = amd_rapl_msrs, ++ .rapl_pkg_msrs = amd_rapl_pkg_msrs, ++ .rapl_core_msrs = amd_rapl_core_msrs, + }; + + static const struct x86_cpu_id rapl_model_match[] __initconst = { +@@ -817,28 +956,47 @@ MODULE_DEVICE_TABLE(x86cpu, rapl_model_match); + static int __init rapl_pmu_init(void) + { + const struct x86_cpu_id *id; +- struct rapl_model *rm; + int ret; ++ int nr_rapl_pmu = topology_max_packages() * topology_max_dies_per_package(); ++ int nr_cores = topology_max_packages() * topology_num_cores_per_package(); ++ ++ if (rapl_pmu_is_pkg_scope()) ++ nr_rapl_pmu = topology_max_packages(); + + id = x86_match_cpu(rapl_model_match); + if (!id) + return -ENODEV; + +- rm = (struct rapl_model *) id->driver_data; +- +- rapl_msrs = rm->rapl_msrs; ++ rapl_model = (struct rapl_model *) id->driver_data; + +- rapl_cntr_mask = perf_msr_probe(rapl_msrs, PERF_RAPL_MAX, +- false, (void *) &rm->events); ++ rapl_pkg_cntr_mask = perf_msr_probe(rapl_model->rapl_pkg_msrs, PERF_RAPL_PKG_EVENTS_MAX, ++ false, (void *) &rapl_model->pkg_events); + +- ret = rapl_check_hw_unit(rm); ++ ret = rapl_check_hw_unit(); + if (ret) + return ret; + +- ret = init_rapl_pmus(); ++ ret = init_rapl_pmus(&rapl_pmus_pkg, nr_rapl_pmu, rapl_attr_groups, rapl_attr_update); + if (ret) + return ret; + ++ if (rapl_model->core_events) { ++ rapl_core_cntr_mask = perf_msr_probe(rapl_model->rapl_core_msrs, ++ PERF_RAPL_CORE_EVENTS_MAX, false, ++ (void *) &rapl_model->core_events); ++ ++ ret = init_rapl_pmus(&rapl_pmus_core, nr_cores, ++ rapl_per_core_attr_groups, 
rapl_per_core_attr_update); ++ if (ret) { ++ /* ++ * If initialization of per_core PMU fails, reset per_core ++ * flag, and continue with power PMU initialization. ++ */ ++ pr_warn("Per-core PMU initialization failed (%d)\n", ret); ++ rapl_model->core_events = 0UL; ++ } ++ } ++ + /* + * Install callbacks. Core will call them for each online cpu. + */ +@@ -848,10 +1006,24 @@ static int __init rapl_pmu_init(void) + if (ret) + goto out; + +- ret = perf_pmu_register(&rapl_pmus->pmu, "power", -1); ++ ret = perf_pmu_register(&rapl_pmus_pkg->pmu, "power", -1); + if (ret) + goto out1; + ++ if (rapl_model->core_events) { ++ ret = perf_pmu_register(&rapl_pmus_core->pmu, "power_per_core", -1); ++ if (ret) { ++ /* ++ * If registration of per_core PMU fails, cleanup per_core PMU ++ * variables, reset the per_core flag and keep the ++ * power PMU untouched. ++ */ ++ pr_warn("Per-core PMU registration failed (%d)\n", ret); ++ cleanup_rapl_pmus(rapl_pmus_core); ++ rapl_model->core_events = 0UL; ++ } ++ } ++ + rapl_advertise(); + return 0; + +@@ -859,7 +1031,7 @@ static int __init rapl_pmu_init(void) + cpuhp_remove_state(CPUHP_AP_PERF_X86_RAPL_ONLINE); + out: + pr_warn("Initialization failed (%d), disabled\n", ret); +- cleanup_rapl_pmus(); ++ cleanup_rapl_pmus(rapl_pmus_pkg); + return ret; + } + module_init(rapl_pmu_init); +@@ -867,7 +1039,11 @@ module_init(rapl_pmu_init); + static void __exit intel_rapl_exit(void) + { + cpuhp_remove_state_nocalls(CPUHP_AP_PERF_X86_RAPL_ONLINE); +- perf_pmu_unregister(&rapl_pmus->pmu); +- cleanup_rapl_pmus(); ++ perf_pmu_unregister(&rapl_pmus_pkg->pmu); ++ cleanup_rapl_pmus(rapl_pmus_pkg); ++ if (rapl_model->core_events) { ++ perf_pmu_unregister(&rapl_pmus_core->pmu); ++ cleanup_rapl_pmus(rapl_pmus_core); ++ } + } + module_exit(intel_rapl_exit); +diff --git a/arch/x86/include/asm/processor.h b/arch/x86/include/asm/processor.h +index cb4f6c513c48..1ffe4260bef6 100644 +--- a/arch/x86/include/asm/processor.h ++++ b/arch/x86/include/asm/processor.h +@@ 
-98,6 +98,7 @@ struct cpuinfo_topology { + // Logical ID mappings + u32 logical_pkg_id; + u32 logical_die_id; ++ u32 logical_core_id; + + // AMD Node ID and Nodes per Package info + u32 amd_node_id; +diff --git a/arch/x86/include/asm/topology.h b/arch/x86/include/asm/topology.h +index e5b203fe7956..8c2fea7dd065 100644 +--- a/arch/x86/include/asm/topology.h ++++ b/arch/x86/include/asm/topology.h +@@ -137,6 +137,7 @@ extern const struct cpumask *cpu_clustergroup_mask(int cpu); + #define topology_logical_package_id(cpu) (cpu_data(cpu).topo.logical_pkg_id) + #define topology_physical_package_id(cpu) (cpu_data(cpu).topo.pkg_id) + #define topology_logical_die_id(cpu) (cpu_data(cpu).topo.logical_die_id) ++#define topology_logical_core_id(cpu) (cpu_data(cpu).topo.logical_core_id) + #define topology_die_id(cpu) (cpu_data(cpu).topo.die_id) + #define topology_core_id(cpu) (cpu_data(cpu).topo.core_id) + #define topology_ppin(cpu) (cpu_data(cpu).ppin) +diff --git a/arch/x86/kernel/cpu/debugfs.c b/arch/x86/kernel/cpu/debugfs.c +index 3baf3e435834..b1eb6d7828db 100644 +--- a/arch/x86/kernel/cpu/debugfs.c ++++ b/arch/x86/kernel/cpu/debugfs.c +@@ -24,6 +24,7 @@ static int cpu_debug_show(struct seq_file *m, void *p) + seq_printf(m, "core_id: %u\n", c->topo.core_id); + seq_printf(m, "logical_pkg_id: %u\n", c->topo.logical_pkg_id); + seq_printf(m, "logical_die_id: %u\n", c->topo.logical_die_id); ++ seq_printf(m, "logical_core_id: %u\n", c->topo.logical_core_id); + seq_printf(m, "llc_id: %u\n", c->topo.llc_id); + seq_printf(m, "l2c_id: %u\n", c->topo.l2c_id); + seq_printf(m, "amd_node_id: %u\n", c->topo.amd_node_id); +diff --git a/arch/x86/kernel/cpu/topology_common.c b/arch/x86/kernel/cpu/topology_common.c +index 9a6069e7133c..23722aa21e2f 100644 +--- a/arch/x86/kernel/cpu/topology_common.c ++++ b/arch/x86/kernel/cpu/topology_common.c +@@ -151,6 +151,7 @@ static void topo_set_ids(struct topo_scan *tscan, bool early) + if (!early) { + c->topo.logical_pkg_id = 
topology_get_logical_id(apicid, TOPO_PKG_DOMAIN); + c->topo.logical_die_id = topology_get_logical_id(apicid, TOPO_DIE_DOMAIN); ++ c->topo.logical_core_id = topology_get_logical_id(apicid, TOPO_CORE_DOMAIN); + } + + /* Package relative core ID */ +-- +2.46.0 + +From f5b9118439a16a5cd3ce3f611ae27dd41f73b146 Mon Sep 17 00:00:00 2001 +From: Peter Jung +Date: Wed, 14 Aug 2024 16:20:14 +0200 +Subject: [PATCH 11/12] t2 + +Signed-off-by: Peter Jung +--- + .../ABI/testing/sysfs-driver-hid-appletb-kbd | 13 + + Documentation/core-api/printk-formats.rst | 32 + + MAINTAINERS | 12 + + drivers/acpi/video_detect.c | 16 + + drivers/firmware/efi/libstub/Makefile | 2 +- + drivers/firmware/efi/libstub/arm64.c | 3 +- + drivers/firmware/efi/libstub/efistub.h | 9 +- + drivers/firmware/efi/libstub/smbios.c | 43 +- + drivers/firmware/efi/libstub/x86-stub.c | 71 +- + drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c | 3 + + drivers/gpu/drm/drm_format_helper.c | 54 + + drivers/gpu/drm/i915/display/intel_ddi.c | 4 + + drivers/gpu/drm/i915/display/intel_fbdev.c | 6 +- + drivers/gpu/drm/i915/display/intel_quirks.c | 15 + + drivers/gpu/drm/i915/display/intel_quirks.h | 1 + + .../gpu/drm/tests/drm_format_helper_test.c | 81 ++ + drivers/gpu/drm/tiny/Kconfig | 12 + + drivers/gpu/drm/tiny/Makefile | 1 + + drivers/gpu/drm/tiny/appletbdrm.c | 624 +++++++++ + drivers/gpu/vga/vga_switcheroo.c | 7 +- + drivers/hid/Kconfig | 22 + + drivers/hid/Makefile | 2 + + drivers/hid/hid-apple.c | 87 ++ + drivers/hid/hid-appletb-bl.c | 206 +++ + drivers/hid/hid-appletb-kbd.c | 432 +++++++ + drivers/hid/hid-core.c | 25 + + drivers/hid/hid-google-hammer.c | 27 +- + drivers/hid/hid-multitouch.c | 60 +- + drivers/hid/hid-quirks.c | 8 +- + drivers/hwmon/applesmc.c | 1138 ++++++++++++----- + drivers/input/mouse/bcm5974.c | 138 ++ + drivers/pci/vgaarb.c | 1 + + drivers/platform/x86/apple-gmux.c | 18 + + drivers/staging/Kconfig | 2 + + drivers/staging/Makefile | 1 + + drivers/staging/apple-bce/Kconfig | 18 + + 
drivers/staging/apple-bce/Makefile | 28 + + drivers/staging/apple-bce/apple_bce.c | 445 +++++++ + drivers/staging/apple-bce/apple_bce.h | 38 + + drivers/staging/apple-bce/audio/audio.c | 711 ++++++++++ + drivers/staging/apple-bce/audio/audio.h | 125 ++ + drivers/staging/apple-bce/audio/description.h | 42 + + drivers/staging/apple-bce/audio/pcm.c | 308 +++++ + drivers/staging/apple-bce/audio/pcm.h | 16 + + drivers/staging/apple-bce/audio/protocol.c | 347 +++++ + drivers/staging/apple-bce/audio/protocol.h | 147 +++ + .../staging/apple-bce/audio/protocol_bce.c | 226 ++++ + .../staging/apple-bce/audio/protocol_bce.h | 72 ++ + drivers/staging/apple-bce/mailbox.c | 151 +++ + drivers/staging/apple-bce/mailbox.h | 53 + + drivers/staging/apple-bce/queue.c | 390 ++++++ + drivers/staging/apple-bce/queue.h | 177 +++ + drivers/staging/apple-bce/queue_dma.c | 220 ++++ + drivers/staging/apple-bce/queue_dma.h | 50 + + drivers/staging/apple-bce/vhci/command.h | 204 +++ + drivers/staging/apple-bce/vhci/queue.c | 268 ++++ + drivers/staging/apple-bce/vhci/queue.h | 76 ++ + drivers/staging/apple-bce/vhci/transfer.c | 661 ++++++++++ + drivers/staging/apple-bce/vhci/transfer.h | 73 ++ + drivers/staging/apple-bce/vhci/vhci.c | 759 +++++++++++ + drivers/staging/apple-bce/vhci/vhci.h | 52 + + drivers/usb/core/driver.c | 14 + + drivers/usb/storage/uas.c | 5 +- + include/drm/drm_format_helper.h | 3 + + include/linux/efi.h | 5 +- + include/linux/hid.h | 2 + + include/linux/usb.h | 3 + + lib/test_printf.c | 20 +- + lib/vsprintf.c | 36 +- + scripts/checkpatch.pl | 2 +- + 70 files changed, 8530 insertions(+), 393 deletions(-) + create mode 100644 Documentation/ABI/testing/sysfs-driver-hid-appletb-kbd + create mode 100644 drivers/gpu/drm/tiny/appletbdrm.c + create mode 100644 drivers/hid/hid-appletb-bl.c + create mode 100644 drivers/hid/hid-appletb-kbd.c + create mode 100644 drivers/staging/apple-bce/Kconfig + create mode 100644 drivers/staging/apple-bce/Makefile + create mode 100644 
drivers/staging/apple-bce/apple_bce.c + create mode 100644 drivers/staging/apple-bce/apple_bce.h + create mode 100644 drivers/staging/apple-bce/audio/audio.c + create mode 100644 drivers/staging/apple-bce/audio/audio.h + create mode 100644 drivers/staging/apple-bce/audio/description.h + create mode 100644 drivers/staging/apple-bce/audio/pcm.c + create mode 100644 drivers/staging/apple-bce/audio/pcm.h + create mode 100644 drivers/staging/apple-bce/audio/protocol.c + create mode 100644 drivers/staging/apple-bce/audio/protocol.h + create mode 100644 drivers/staging/apple-bce/audio/protocol_bce.c + create mode 100644 drivers/staging/apple-bce/audio/protocol_bce.h + create mode 100644 drivers/staging/apple-bce/mailbox.c + create mode 100644 drivers/staging/apple-bce/mailbox.h + create mode 100644 drivers/staging/apple-bce/queue.c + create mode 100644 drivers/staging/apple-bce/queue.h + create mode 100644 drivers/staging/apple-bce/queue_dma.c + create mode 100644 drivers/staging/apple-bce/queue_dma.h + create mode 100644 drivers/staging/apple-bce/vhci/command.h + create mode 100644 drivers/staging/apple-bce/vhci/queue.c + create mode 100644 drivers/staging/apple-bce/vhci/queue.h + create mode 100644 drivers/staging/apple-bce/vhci/transfer.c + create mode 100644 drivers/staging/apple-bce/vhci/transfer.h + create mode 100644 drivers/staging/apple-bce/vhci/vhci.c + create mode 100644 drivers/staging/apple-bce/vhci/vhci.h + +diff --git a/Documentation/ABI/testing/sysfs-driver-hid-appletb-kbd b/Documentation/ABI/testing/sysfs-driver-hid-appletb-kbd +new file mode 100644 +index 000000000000..2a19584d091e +--- /dev/null ++++ b/Documentation/ABI/testing/sysfs-driver-hid-appletb-kbd +@@ -0,0 +1,13 @@ ++What: /sys/bus/hid/drivers/hid-appletb-kbd//mode ++Date: September, 2023 ++KernelVersion: 6.5 ++Contact: linux-input@vger.kernel.org ++Description: ++ The set of keys displayed on the Touch Bar. 
++ Valid values are: ++ == ================= ++ 0 Escape key only ++ 1 Function keys ++ 2 Media/brightness keys ++ 3 None ++ == ================= +diff --git a/Documentation/core-api/printk-formats.rst b/Documentation/core-api/printk-formats.rst +index 4451ef501936..c726a846f752 100644 +--- a/Documentation/core-api/printk-formats.rst ++++ b/Documentation/core-api/printk-formats.rst +@@ -632,6 +632,38 @@ Examples:: + %p4cc Y10 little-endian (0x20303159) + %p4cc NV12 big-endian (0xb231564e) + ++Generic FourCC code ++------------------- ++ ++:: ++ %p4c[hrbl] gP00 (0x67503030) ++ ++Print a generic FourCC code, as both ASCII characters and its numerical ++value as hexadecimal. ++ ++The additional ``h``, ``r``, ``b``, and ``l`` specifiers are used to specify ++host, reversed, big or little endian order data respectively. Host endian ++order means the data is interpreted as a 32-bit integer and the most ++significant byte is printed first; that is, the character code as printed ++matches the byte order stored in memory on big-endian systems, and is reversed ++on little-endian systems. ++ ++Passed by reference. 
++ ++Examples for a little-endian machine, given &(u32)0x67503030:: ++ ++ %p4ch gP00 (0x67503030) ++ %p4cl gP00 (0x67503030) ++ %p4cb 00Pg (0x30305067) ++ %p4cr 00Pg (0x30305067) ++ ++Examples for a big-endian machine, given &(u32)0x67503030:: ++ ++ %p4ch gP00 (0x67503030) ++ %p4cl 00Pg (0x30305067) ++ %p4cb gP00 (0x67503030) ++ %p4cr 00Pg (0x30305067) ++ + Rust + ---- + +diff --git a/MAINTAINERS b/MAINTAINERS +index 4112729fc23a..064156d69e75 100644 +--- a/MAINTAINERS ++++ b/MAINTAINERS +@@ -6728,6 +6728,12 @@ S: Supported + T: git https://gitlab.freedesktop.org/drm/misc/kernel.git + F: drivers/gpu/drm/sun4i/sun8i* + ++DRM DRIVER FOR APPLE TOUCH BARS ++M: Kerem Karabay ++L: dri-devel@lists.freedesktop.org ++S: Maintained ++F: drivers/gpu/drm/tiny/appletbdrm.c ++ + DRM DRIVER FOR ARM PL111 CLCD + S: Orphan + T: git https://gitlab.freedesktop.org/drm/misc/kernel.git +@@ -9733,6 +9739,12 @@ F: include/linux/pm.h + F: include/linux/suspend.h + F: kernel/power/ + ++HID APPLE TOUCH BAR DRIVERS ++M: Kerem Karabay ++L: linux-input@vger.kernel.org ++S: Maintained ++F: drivers/hid/hid-appletb-* ++ + HID CORE LAYER + M: Jiri Kosina + M: Benjamin Tissoires +diff --git a/drivers/acpi/video_detect.c b/drivers/acpi/video_detect.c +index 2cc3821b2b16..c11cbe5b6eaa 100644 +--- a/drivers/acpi/video_detect.c ++++ b/drivers/acpi/video_detect.c +@@ -539,6 +539,14 @@ static const struct dmi_system_id video_detect_dmi_table[] = { + DMI_MATCH(DMI_PRODUCT_NAME, "iMac12,2"), + }, + }, ++ { ++ .callback = video_detect_force_native, ++ /* Apple MacBook Air 9,1 */ ++ .matches = { ++ DMI_MATCH(DMI_SYS_VENDOR, "Apple Inc."), ++ DMI_MATCH(DMI_PRODUCT_NAME, "MacBookAir9,1"), ++ }, ++ }, + { + /* https://bugzilla.redhat.com/show_bug.cgi?id=1217249 */ + .callback = video_detect_force_native, +@@ -548,6 +556,14 @@ static const struct dmi_system_id video_detect_dmi_table[] = { + DMI_MATCH(DMI_PRODUCT_NAME, "MacBookPro12,1"), + }, + }, ++ { ++ .callback = video_detect_force_native, ++ /* Apple MacBook 
Pro 16,2 */ ++ .matches = { ++ DMI_MATCH(DMI_SYS_VENDOR, "Apple Inc."), ++ DMI_MATCH(DMI_PRODUCT_NAME, "MacBookPro16,2"), ++ }, ++ }, + { + .callback = video_detect_force_native, + /* Dell Inspiron N4010 */ +diff --git a/drivers/firmware/efi/libstub/Makefile b/drivers/firmware/efi/libstub/Makefile +index 06f0428a723c..1f32d6cf98d6 100644 +--- a/drivers/firmware/efi/libstub/Makefile ++++ b/drivers/firmware/efi/libstub/Makefile +@@ -76,7 +76,7 @@ lib-$(CONFIG_EFI_GENERIC_STUB) += efi-stub.o string.o intrinsics.o systable.o \ + + lib-$(CONFIG_ARM) += arm32-stub.o + lib-$(CONFIG_ARM64) += kaslr.o arm64.o arm64-stub.o smbios.o +-lib-$(CONFIG_X86) += x86-stub.o ++lib-$(CONFIG_X86) += x86-stub.o smbios.o + lib-$(CONFIG_X86_64) += x86-5lvl.o + lib-$(CONFIG_RISCV) += kaslr.o riscv.o riscv-stub.o + lib-$(CONFIG_LOONGARCH) += loongarch.o loongarch-stub.o +diff --git a/drivers/firmware/efi/libstub/arm64.c b/drivers/firmware/efi/libstub/arm64.c +index 446e35eaf3d9..e57cd3de0a00 100644 +--- a/drivers/firmware/efi/libstub/arm64.c ++++ b/drivers/firmware/efi/libstub/arm64.c +@@ -39,8 +39,7 @@ static bool system_needs_vamap(void) + static char const emag[] = "eMAG"; + + default: +- version = efi_get_smbios_string(&record->header, 4, +- processor_version); ++ version = efi_get_smbios_string(record, processor_version); + if (!version || (strncmp(version, altra, sizeof(altra) - 1) && + strncmp(version, emag, sizeof(emag) - 1))) + break; +diff --git a/drivers/firmware/efi/libstub/efistub.h b/drivers/firmware/efi/libstub/efistub.h +index 27abb4ce0291..d33ccbc4a2c6 100644 +--- a/drivers/firmware/efi/libstub/efistub.h ++++ b/drivers/firmware/efi/libstub/efistub.h +@@ -1204,14 +1204,13 @@ struct efi_smbios_type4_record { + u16 thread_enabled; + }; + +-#define efi_get_smbios_string(__record, __type, __name) ({ \ +- int off = offsetof(struct efi_smbios_type ## __type ## _record, \ +- __name); \ +- __efi_get_smbios_string((__record), __type, off); \ ++#define efi_get_smbios_string(__record, 
__field) ({ \ ++ __typeof__(__record) __rec = __record; \ ++ __efi_get_smbios_string(&__rec->header, &__rec->__field); \ + }) + + const u8 *__efi_get_smbios_string(const struct efi_smbios_record *record, +- u8 type, int offset); ++ const u8 *offset); + + void efi_remap_image(unsigned long image_base, unsigned alloc_size, + unsigned long code_size); +diff --git a/drivers/firmware/efi/libstub/smbios.c b/drivers/firmware/efi/libstub/smbios.c +index c217de2cc8d5..f31410d7e7e1 100644 +--- a/drivers/firmware/efi/libstub/smbios.c ++++ b/drivers/firmware/efi/libstub/smbios.c +@@ -6,20 +6,31 @@ + + #include "efistub.h" + +-typedef struct efi_smbios_protocol efi_smbios_protocol_t; +- +-struct efi_smbios_protocol { +- efi_status_t (__efiapi *add)(efi_smbios_protocol_t *, efi_handle_t, +- u16 *, struct efi_smbios_record *); +- efi_status_t (__efiapi *update_string)(efi_smbios_protocol_t *, u16 *, +- unsigned long *, u8 *); +- efi_status_t (__efiapi *remove)(efi_smbios_protocol_t *, u16); +- efi_status_t (__efiapi *get_next)(efi_smbios_protocol_t *, u16 *, u8 *, +- struct efi_smbios_record **, +- efi_handle_t *); +- +- u8 major_version; +- u8 minor_version; ++typedef union efi_smbios_protocol efi_smbios_protocol_t; ++ ++union efi_smbios_protocol { ++ struct { ++ efi_status_t (__efiapi *add)(efi_smbios_protocol_t *, efi_handle_t, ++ u16 *, struct efi_smbios_record *); ++ efi_status_t (__efiapi *update_string)(efi_smbios_protocol_t *, u16 *, ++ unsigned long *, u8 *); ++ efi_status_t (__efiapi *remove)(efi_smbios_protocol_t *, u16); ++ efi_status_t (__efiapi *get_next)(efi_smbios_protocol_t *, u16 *, u8 *, ++ struct efi_smbios_record **, ++ efi_handle_t *); ++ ++ u8 major_version; ++ u8 minor_version; ++ }; ++ struct { ++ u32 add; ++ u32 update_string; ++ u32 remove; ++ u32 get_next; ++ ++ u8 major_version; ++ u8 minor_version; ++ } mixed_mode; + }; + + const struct efi_smbios_record *efi_get_smbios_record(u8 type) +@@ -38,7 +49,7 @@ const struct efi_smbios_record 
*efi_get_smbios_record(u8 type) + } + + const u8 *__efi_get_smbios_string(const struct efi_smbios_record *record, +- u8 type, int offset) ++ const u8 *offset) + { + const u8 *strtable; + +@@ -46,7 +57,7 @@ const u8 *__efi_get_smbios_string(const struct efi_smbios_record *record, + return NULL; + + strtable = (u8 *)record + record->length; +- for (int i = 1; i < ((u8 *)record)[offset]; i++) { ++ for (int i = 1; i < *offset; i++) { + int len = strlen(strtable); + + if (!len) +diff --git a/drivers/firmware/efi/libstub/x86-stub.c b/drivers/firmware/efi/libstub/x86-stub.c +index 99d39eda5134..0a2342c0bc16 100644 +--- a/drivers/firmware/efi/libstub/x86-stub.c ++++ b/drivers/firmware/efi/libstub/x86-stub.c +@@ -225,6 +225,68 @@ static void retrieve_apple_device_properties(struct boot_params *boot_params) + } + } + ++static bool apple_match_product_name(void) ++{ ++ static const char type1_product_matches[][15] = { ++ "MacBookPro11,3", ++ "MacBookPro11,5", ++ "MacBookPro13,3", ++ "MacBookPro14,3", ++ "MacBookPro15,1", ++ "MacBookPro15,3", ++ "MacBookPro16,1", ++ "MacBookPro16,4", ++ }; ++ const struct efi_smbios_type1_record *record; ++ const u8 *product; ++ ++ record = (struct efi_smbios_type1_record *)efi_get_smbios_record(1); ++ if (!record) ++ return false; ++ ++ product = efi_get_smbios_string(record, product_name); ++ if (!product) ++ return false; ++ ++ for (int i = 0; i < ARRAY_SIZE(type1_product_matches); i++) { ++ if (!strcmp(product, type1_product_matches[i])) ++ return true; ++ } ++ ++ return false; ++} ++ ++static void apple_set_os(void) ++{ ++ struct { ++ unsigned long version; ++ efi_status_t (__efiapi *set_os_version)(const char *); ++ efi_status_t (__efiapi *set_os_vendor)(const char *); ++ } *set_os; ++ efi_status_t status; ++ ++ if (!efi_is_64bit() || !apple_match_product_name()) ++ return; ++ ++ status = efi_bs_call(locate_protocol, &APPLE_SET_OS_PROTOCOL_GUID, NULL, ++ (void **)&set_os); ++ if (status != EFI_SUCCESS) ++ return; ++ ++ if 
(set_os->version >= 2) { ++ status = set_os->set_os_vendor("Apple Inc."); ++ if (status != EFI_SUCCESS) ++ efi_err("Failed to set OS vendor via apple_set_os\n"); ++ } ++ ++ if (set_os->version > 0) { ++ /* The version being set doesn't seem to matter */ ++ status = set_os->set_os_version("Mac OS X 10.9"); ++ if (status != EFI_SUCCESS) ++ efi_err("Failed to set OS version via apple_set_os\n"); ++ } ++} ++ + efi_status_t efi_adjust_memory_range_protection(unsigned long start, + unsigned long size) + { +@@ -335,9 +397,12 @@ static const efi_char16_t apple[] = L"Apple"; + + static void setup_quirks(struct boot_params *boot_params) + { +- if (IS_ENABLED(CONFIG_APPLE_PROPERTIES) && +- !memcmp(efistub_fw_vendor(), apple, sizeof(apple))) +- retrieve_apple_device_properties(boot_params); ++ if (!memcmp(efistub_fw_vendor(), apple, sizeof(apple))) { ++ if (IS_ENABLED(CONFIG_APPLE_PROPERTIES)) ++ retrieve_apple_device_properties(boot_params); ++ ++ apple_set_os(); ++ } + } + + /* +diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c +index bb0b636d0d75..a05ed98da785 100644 +--- a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c ++++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c +@@ -2211,6 +2211,9 @@ static int amdgpu_pci_probe(struct pci_dev *pdev, + int ret, retry = 0, i; + bool supports_atomic = false; + ++ if (vga_switcheroo_client_probe_defer(pdev)) ++ return -EPROBE_DEFER; ++ + /* skip devices which are owned by radeon */ + for (i = 0; i < ARRAY_SIZE(amdgpu_unsupported_pciidlist); i++) { + if (amdgpu_unsupported_pciidlist[i] == pdev->device) +diff --git a/drivers/gpu/drm/drm_format_helper.c b/drivers/gpu/drm/drm_format_helper.c +index b1be458ed4dd..28c0e76a1e88 100644 +--- a/drivers/gpu/drm/drm_format_helper.c ++++ b/drivers/gpu/drm/drm_format_helper.c +@@ -702,6 +702,57 @@ void drm_fb_xrgb8888_to_rgb888(struct iosys_map *dst, const unsigned int *dst_pi + } + EXPORT_SYMBOL(drm_fb_xrgb8888_to_rgb888); + ++static void 
drm_fb_xrgb8888_to_bgr888_line(void *dbuf, const void *sbuf, unsigned int pixels) ++{ ++ u8 *dbuf8 = dbuf; ++ const __le32 *sbuf32 = sbuf; ++ unsigned int x; ++ u32 pix; ++ ++ for (x = 0; x < pixels; x++) { ++ pix = le32_to_cpu(sbuf32[x]); ++ /* write red-green-blue to output in little endianness */ ++ *dbuf8++ = (pix & 0x00FF0000) >> 16; ++ *dbuf8++ = (pix & 0x0000FF00) >> 8; ++ *dbuf8++ = (pix & 0x000000FF) >> 0; ++ } ++} ++ ++/** ++ * drm_fb_xrgb8888_to_bgr888 - Convert XRGB8888 to BGR888 clip buffer ++ * @dst: Array of BGR888 destination buffers ++ * @dst_pitch: Array of numbers of bytes between the start of two consecutive scanlines ++ * within @dst; can be NULL if scanlines are stored next to each other. ++ * @src: Array of XRGB8888 source buffers ++ * @fb: DRM framebuffer ++ * @clip: Clip rectangle area to copy ++ * @state: Transform and conversion state ++ * ++ * This function copies parts of a framebuffer to display memory and converts the ++ * color format during the process. Destination and framebuffer formats must match. The ++ * parameters @dst, @dst_pitch and @src refer to arrays. Each array must have at ++ * least as many entries as there are planes in @fb's format. Each entry stores the ++ * value for the format's respective color plane at the same index. ++ * ++ * This function does not apply clipping on @dst (i.e. the destination is at the ++ * top-left corner). ++ * ++ * Drivers can use this function for BGR888 devices that don't natively ++ * support XRGB8888. 
++ */ ++void drm_fb_xrgb8888_to_bgr888(struct iosys_map *dst, const unsigned int *dst_pitch, ++ const struct iosys_map *src, const struct drm_framebuffer *fb, ++ const struct drm_rect *clip, struct drm_format_conv_state *state) ++{ ++ static const u8 dst_pixsize[DRM_FORMAT_MAX_PLANES] = { ++ 3, ++ }; ++ ++ drm_fb_xfrm(dst, dst_pitch, dst_pixsize, src, fb, clip, false, state, ++ drm_fb_xrgb8888_to_bgr888_line); ++} ++EXPORT_SYMBOL(drm_fb_xrgb8888_to_bgr888); ++ + static void drm_fb_xrgb8888_to_argb8888_line(void *dbuf, const void *sbuf, unsigned int pixels) + { + __le32 *dbuf32 = dbuf; +@@ -1035,6 +1086,9 @@ int drm_fb_blit(struct iosys_map *dst, const unsigned int *dst_pitch, uint32_t d + } else if (dst_format == DRM_FORMAT_RGB888) { + drm_fb_xrgb8888_to_rgb888(dst, dst_pitch, src, fb, clip, state); + return 0; ++ } else if (dst_format == DRM_FORMAT_BGR888) { ++ drm_fb_xrgb8888_to_bgr888(dst, dst_pitch, src, fb, clip, state); ++ return 0; + } else if (dst_format == DRM_FORMAT_ARGB8888) { + drm_fb_xrgb8888_to_argb8888(dst, dst_pitch, src, fb, clip, state); + return 0; +diff --git a/drivers/gpu/drm/i915/display/intel_ddi.c b/drivers/gpu/drm/i915/display/intel_ddi.c +index 6bff169fa8d4..8d80ae00b838 100644 +--- a/drivers/gpu/drm/i915/display/intel_ddi.c ++++ b/drivers/gpu/drm/i915/display/intel_ddi.c +@@ -4648,6 +4648,7 @@ intel_ddi_init_hdmi_connector(struct intel_digital_port *dig_port) + + static bool intel_ddi_a_force_4_lanes(struct intel_digital_port *dig_port) + { ++ struct intel_display *display = to_intel_display(dig_port); + struct drm_i915_private *dev_priv = to_i915(dig_port->base.base.dev); + + if (dig_port->base.port != PORT_A) +@@ -4656,6 +4657,9 @@ static bool intel_ddi_a_force_4_lanes(struct intel_digital_port *dig_port) + if (dig_port->saved_port_bits & DDI_A_4_LANES) + return false; + ++ if (intel_has_quirk(display, QUIRK_DDI_A_FORCE_4_LANES)) ++ return true; ++ + /* Broxton/Geminilake: Bspec says that DDI_A_4_LANES is the only + * supported 
configuration + */ +diff --git a/drivers/gpu/drm/i915/display/intel_fbdev.c b/drivers/gpu/drm/i915/display/intel_fbdev.c +index bda702c2cab8..1647e141ae78 100644 +--- a/drivers/gpu/drm/i915/display/intel_fbdev.c ++++ b/drivers/gpu/drm/i915/display/intel_fbdev.c +@@ -196,10 +196,10 @@ static int intelfb_create(struct drm_fb_helper *helper, + return ret; + + if (intel_fb && +- (sizes->fb_width > intel_fb->base.width || +- sizes->fb_height > intel_fb->base.height)) { ++ (sizes->fb_width != intel_fb->base.width || ++ sizes->fb_height != intel_fb->base.height)) { + drm_dbg_kms(&dev_priv->drm, +- "BIOS fb too small (%dx%d), we require (%dx%d)," ++ "BIOS fb not valid (%dx%d), we require (%dx%d)," + " releasing it\n", + intel_fb->base.width, intel_fb->base.height, + sizes->fb_width, sizes->fb_height); +diff --git a/drivers/gpu/drm/i915/display/intel_quirks.c b/drivers/gpu/drm/i915/display/intel_quirks.c +index 14d5fefc9c5b..727639b8f6a6 100644 +--- a/drivers/gpu/drm/i915/display/intel_quirks.c ++++ b/drivers/gpu/drm/i915/display/intel_quirks.c +@@ -59,6 +59,18 @@ static void quirk_increase_ddi_disabled_time(struct intel_display *display) + drm_info(display->drm, "Applying Increase DDI Disabled quirk\n"); + } + ++/* ++ * In some cases, the firmware might not set the lane count to 4 (for example, ++ * when booting in some dual GPU Macs with the dGPU as the default GPU), this ++ * quirk is used to force it as otherwise it might not be possible to compute a ++ * valid link configuration. 
++ */ ++static void quirk_ddi_a_force_4_lanes(struct intel_display *display) ++{ ++ intel_set_quirk(display, QUIRK_DDI_A_FORCE_4_LANES); ++ drm_info(display->drm, "Applying DDI A Forced 4 Lanes quirk\n"); ++} ++ + static void quirk_no_pps_backlight_power_hook(struct intel_display *display) + { + intel_set_quirk(display, QUIRK_NO_PPS_BACKLIGHT_POWER_HOOK); +@@ -201,6 +213,9 @@ static struct intel_quirk intel_quirks[] = { + { 0x3184, 0x1019, 0xa94d, quirk_increase_ddi_disabled_time }, + /* HP Notebook - 14-r206nv */ + { 0x0f31, 0x103c, 0x220f, quirk_invert_brightness }, ++ ++ /* Apple MacBookPro15,1 */ ++ { 0x3e9b, 0x106b, 0x0176, quirk_ddi_a_force_4_lanes }, + }; + + void intel_init_quirks(struct intel_display *display) +diff --git a/drivers/gpu/drm/i915/display/intel_quirks.h b/drivers/gpu/drm/i915/display/intel_quirks.h +index 151c8f4ae576..46e7feba88f4 100644 +--- a/drivers/gpu/drm/i915/display/intel_quirks.h ++++ b/drivers/gpu/drm/i915/display/intel_quirks.h +@@ -17,6 +17,7 @@ enum intel_quirk_id { + QUIRK_INVERT_BRIGHTNESS, + QUIRK_LVDS_SSC_DISABLE, + QUIRK_NO_PPS_BACKLIGHT_POWER_HOOK, ++ QUIRK_DDI_A_FORCE_4_LANES, + }; + + void intel_init_quirks(struct intel_display *display); +diff --git a/drivers/gpu/drm/tests/drm_format_helper_test.c b/drivers/gpu/drm/tests/drm_format_helper_test.c +index 08992636ec05..35cd3405d045 100644 +--- a/drivers/gpu/drm/tests/drm_format_helper_test.c ++++ b/drivers/gpu/drm/tests/drm_format_helper_test.c +@@ -60,6 +60,11 @@ struct convert_to_rgb888_result { + const u8 expected[TEST_BUF_SIZE]; + }; + ++struct convert_to_bgr888_result { ++ unsigned int dst_pitch; ++ const u8 expected[TEST_BUF_SIZE]; ++}; ++ + struct convert_to_argb8888_result { + unsigned int dst_pitch; + const u32 expected[TEST_BUF_SIZE]; +@@ -107,6 +112,7 @@ struct convert_xrgb8888_case { + struct convert_to_argb1555_result argb1555_result; + struct convert_to_rgba5551_result rgba5551_result; + struct convert_to_rgb888_result rgb888_result; ++ struct 
convert_to_bgr888_result bgr888_result; + struct convert_to_argb8888_result argb8888_result; + struct convert_to_xrgb2101010_result xrgb2101010_result; + struct convert_to_argb2101010_result argb2101010_result; +@@ -151,6 +157,10 @@ static struct convert_xrgb8888_case convert_xrgb8888_cases[] = { + .dst_pitch = TEST_USE_DEFAULT_PITCH, + .expected = { 0x00, 0x00, 0xFF }, + }, ++ .bgr888_result = { ++ .dst_pitch = TEST_USE_DEFAULT_PITCH, ++ .expected = { 0xFF, 0x00, 0x00 }, ++ }, + .argb8888_result = { + .dst_pitch = TEST_USE_DEFAULT_PITCH, + .expected = { 0xFFFF0000 }, +@@ -217,6 +227,10 @@ static struct convert_xrgb8888_case convert_xrgb8888_cases[] = { + .dst_pitch = TEST_USE_DEFAULT_PITCH, + .expected = { 0x00, 0x00, 0xFF }, + }, ++ .bgr888_result = { ++ .dst_pitch = TEST_USE_DEFAULT_PITCH, ++ .expected = { 0xFF, 0x00, 0x00 }, ++ }, + .argb8888_result = { + .dst_pitch = TEST_USE_DEFAULT_PITCH, + .expected = { 0xFFFF0000 }, +@@ -330,6 +344,15 @@ static struct convert_xrgb8888_case convert_xrgb8888_cases[] = { + 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, + }, + }, ++ .bgr888_result = { ++ .dst_pitch = TEST_USE_DEFAULT_PITCH, ++ .expected = { ++ 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, ++ 0xFF, 0x00, 0x00, 0x00, 0xFF, 0x00, ++ 0x00, 0x00, 0xFF, 0xFF, 0x00, 0xFF, ++ 0xFF, 0xFF, 0x00, 0x00, 0xFF, 0xFF, ++ }, ++ }, + .argb8888_result = { + .dst_pitch = TEST_USE_DEFAULT_PITCH, + .expected = { +@@ -468,6 +491,17 @@ static struct convert_xrgb8888_case convert_xrgb8888_cases[] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, ++ .bgr888_result = { ++ .dst_pitch = 15, ++ .expected = { ++ 0x0E, 0x44, 0x9C, 0x11, 0x4D, 0x05, 0xA8, 0xF3, 0x03, ++ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ++ 0x6C, 0xF0, 0x73, 0x0E, 0x44, 0x9C, 0x11, 0x4D, 0x05, ++ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ++ 0xA8, 0x03, 0x03, 0x6C, 0xF0, 0x73, 0x0E, 0x44, 0x9C, ++ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ++ }, ++ }, + .argb8888_result = { + .dst_pitch = 20, + .expected = { +@@ -914,6 +948,52 @@ static void 
drm_test_fb_xrgb8888_to_rgb888(struct kunit *test) + KUNIT_EXPECT_MEMEQ(test, buf, result->expected, dst_size); + } + ++static void drm_test_fb_xrgb8888_to_bgr888(struct kunit *test) ++{ ++ const struct convert_xrgb8888_case *params = test->param_value; ++ const struct convert_to_bgr888_result *result = &params->bgr888_result; ++ size_t dst_size; ++ u8 *buf = NULL; ++ __le32 *xrgb8888 = NULL; ++ struct iosys_map dst, src; ++ ++ struct drm_framebuffer fb = { ++ .format = drm_format_info(DRM_FORMAT_XRGB8888), ++ .pitches = { params->pitch, 0, 0 }, ++ }; ++ ++ dst_size = conversion_buf_size(DRM_FORMAT_BGR888, result->dst_pitch, ++ &params->clip, 0); ++ KUNIT_ASSERT_GT(test, dst_size, 0); ++ ++ buf = kunit_kzalloc(test, dst_size, GFP_KERNEL); ++ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, buf); ++ iosys_map_set_vaddr(&dst, buf); ++ ++ xrgb8888 = cpubuf_to_le32(test, params->xrgb8888, TEST_BUF_SIZE); ++ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, xrgb8888); ++ iosys_map_set_vaddr(&src, xrgb8888); ++ ++ /* ++ * BGR888 expected results are already in little-endian ++ * order, so there's no need to convert the test output.
++ */ ++ drm_fb_xrgb8888_to_bgr888(&dst, &result->dst_pitch, &src, &fb, &params->clip, ++ &fmtcnv_state); ++ KUNIT_EXPECT_MEMEQ(test, buf, result->expected, dst_size); ++ ++ buf = dst.vaddr; /* restore original value of buf */ ++ memset(buf, 0, dst_size); ++ ++ int blit_result = 0; ++ ++ blit_result = drm_fb_blit(&dst, &result->dst_pitch, DRM_FORMAT_BGR888, &src, &fb, &params->clip, ++ &fmtcnv_state); ++ ++ KUNIT_EXPECT_FALSE(test, blit_result); ++ KUNIT_EXPECT_MEMEQ(test, buf, result->expected, dst_size); ++} ++ + static void drm_test_fb_xrgb8888_to_argb8888(struct kunit *test) + { + const struct convert_xrgb8888_case *params = test->param_value; +@@ -1851,6 +1931,7 @@ static struct kunit_case drm_format_helper_test_cases[] = { + KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_argb1555, convert_xrgb8888_gen_params), + KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_rgba5551, convert_xrgb8888_gen_params), + KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_rgb888, convert_xrgb8888_gen_params), ++ KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_bgr888, convert_xrgb8888_gen_params), + KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_argb8888, convert_xrgb8888_gen_params), + KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_xrgb2101010, convert_xrgb8888_gen_params), + KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_argb2101010, convert_xrgb8888_gen_params), +diff --git a/drivers/gpu/drm/tiny/Kconfig b/drivers/gpu/drm/tiny/Kconfig +index f6889f649bc1..559a97bce12c 100644 +--- a/drivers/gpu/drm/tiny/Kconfig ++++ b/drivers/gpu/drm/tiny/Kconfig +@@ -1,5 +1,17 @@ + # SPDX-License-Identifier: GPL-2.0-only + ++config DRM_APPLETBDRM ++ tristate "DRM support for Apple Touch Bars" ++ depends on DRM && USB && MMU ++ select DRM_KMS_HELPER ++ select DRM_GEM_SHMEM_HELPER ++ help ++ Say Y here if you want support for the display of Touch Bars on x86 ++ MacBook Pros. ++ ++ To compile this driver as a module, choose M here: the ++ module will be called appletbdrm.
++ + config DRM_ARCPGU + tristate "ARC PGU" + depends on DRM && OF +diff --git a/drivers/gpu/drm/tiny/Makefile b/drivers/gpu/drm/tiny/Makefile +index 76dde89a044b..9a1b412e764a 100644 +--- a/drivers/gpu/drm/tiny/Makefile ++++ b/drivers/gpu/drm/tiny/Makefile +@@ -1,5 +1,6 @@ + # SPDX-License-Identifier: GPL-2.0-only + ++obj-$(CONFIG_DRM_APPLETBDRM) += appletbdrm.o + obj-$(CONFIG_DRM_ARCPGU) += arcpgu.o + obj-$(CONFIG_DRM_BOCHS) += bochs.o + obj-$(CONFIG_DRM_CIRRUS_QEMU) += cirrus.o +diff --git a/drivers/gpu/drm/tiny/appletbdrm.c b/drivers/gpu/drm/tiny/appletbdrm.c +new file mode 100644 +index 000000000000..b9440ce0064e +--- /dev/null ++++ b/drivers/gpu/drm/tiny/appletbdrm.c +@@ -0,0 +1,624 @@ ++// SPDX-License-Identifier: GPL-2.0 ++/* ++ * Apple Touch Bar DRM Driver ++ * ++ * Copyright (c) 2023 Kerem Karabay ++ */ ++ ++#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt ++ ++#include ++ ++#include ++#include ++ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++ ++#define _APPLETBDRM_FOURCC(s) (((s)[0] << 24) | ((s)[1] << 16) | ((s)[2] << 8) | (s)[3]) ++#define APPLETBDRM_FOURCC(s) _APPLETBDRM_FOURCC(#s) ++ ++#define APPLETBDRM_PIXEL_FORMAT APPLETBDRM_FOURCC(RGBA) /* The actual format is BGR888 */ ++#define APPLETBDRM_BITS_PER_PIXEL 24 ++ ++#define APPLETBDRM_MSG_CLEAR_DISPLAY APPLETBDRM_FOURCC(CLRD) ++#define APPLETBDRM_MSG_GET_INFORMATION APPLETBDRM_FOURCC(GINF) ++#define APPLETBDRM_MSG_UPDATE_COMPLETE APPLETBDRM_FOURCC(UDCL) ++#define APPLETBDRM_MSG_SIGNAL_READINESS APPLETBDRM_FOURCC(REDY) ++ ++#define APPLETBDRM_BULK_MSG_TIMEOUT 1000 ++ ++#define drm_to_adev(_drm) container_of(_drm, struct appletbdrm_device, drm) ++#define adev_to_udev(adev) interface_to_usbdev(to_usb_interface(adev->dev)) ++ ++struct appletbdrm_device { ++ struct device *dev; ++ ++ u8 in_ep; ++ u8 out_ep; ++ ++ u32 width; ++ u32 height; ++ ++ struct drm_device drm; ++ struct drm_display_mode mode; ++ struct drm_connector connector; ++ struct 
drm_simple_display_pipe pipe; ++ ++ bool readiness_signal_received; ++}; ++ ++struct appletbdrm_request_header { ++ __le16 unk_00; ++ __le16 unk_02; ++ __le32 unk_04; ++ __le32 unk_08; ++ __le32 size; ++} __packed; ++ ++struct appletbdrm_response_header { ++ u8 unk_00[16]; ++ u32 msg; ++} __packed; ++ ++struct appletbdrm_simple_request { ++ struct appletbdrm_request_header header; ++ u32 msg; ++ u8 unk_14[8]; ++ __le32 size; ++} __packed; ++ ++struct appletbdrm_information { ++ struct appletbdrm_response_header header; ++ u8 unk_14[12]; ++ __le32 width; ++ __le32 height; ++ u8 bits_per_pixel; ++ __le32 bytes_per_row; ++ __le32 orientation; ++ __le32 bitmap_info; ++ u32 pixel_format; ++ __le32 width_inches; /* floating point */ ++ __le32 height_inches; /* floating point */ ++} __packed; ++ ++struct appletbdrm_frame { ++ __le16 begin_x; ++ __le16 begin_y; ++ __le16 width; ++ __le16 height; ++ __le32 buf_size; ++ u8 buf[]; ++} __packed; ++ ++struct appletbdrm_fb_request_footer { ++ u8 unk_00[12]; ++ __le32 unk_0c; ++ u8 unk_10[12]; ++ __le32 unk_1c; ++ __le64 timestamp; ++ u8 unk_28[12]; ++ __le32 unk_34; ++ u8 unk_38[20]; ++ __le32 unk_4c; ++} __packed; ++ ++struct appletbdrm_fb_request { ++ struct appletbdrm_request_header header; ++ __le16 unk_10; ++ u8 msg_id; ++ u8 unk_13[29]; ++ /* ++ * Contents of `data`: ++ * - struct appletbdrm_frame frames[]; ++ * - struct appletbdrm_fb_request_footer footer; ++ * - padding to make the total size a multiple of 16 ++ */ ++ u8 data[]; ++} __packed; ++ ++struct appletbdrm_fb_request_response { ++ struct appletbdrm_response_header header; ++ u8 unk_14[12]; ++ __le64 timestamp; ++} __packed; ++ ++static int appletbdrm_send_request(struct appletbdrm_device *adev, ++ struct appletbdrm_request_header *request, size_t size) ++{ ++ struct usb_device *udev = adev_to_udev(adev); ++ struct drm_device *drm = &adev->drm; ++ int ret, actual_size; ++ ++ ret = usb_bulk_msg(udev, usb_sndbulkpipe(udev, adev->out_ep), ++ request, size, 
&actual_size, APPLETBDRM_BULK_MSG_TIMEOUT); ++ if (ret) { ++ drm_err(drm, "Failed to send message (%pe)\n", ERR_PTR(ret)); ++ return ret; ++ } ++ ++ if (actual_size != size) { ++ drm_err(drm, "Actual size (%d) doesn't match expected size (%zu)\n", ++ actual_size, size); ++ return -EIO; ++ } ++ ++ return ret; ++} ++ ++static int appletbdrm_read_response(struct appletbdrm_device *adev, ++ struct appletbdrm_response_header *response, ++ size_t size, u32 expected_response) ++{ ++ struct usb_device *udev = adev_to_udev(adev); ++ struct drm_device *drm = &adev->drm; ++ int ret, actual_size; ++ ++retry: ++ ret = usb_bulk_msg(udev, usb_rcvbulkpipe(udev, adev->in_ep), ++ response, size, &actual_size, APPLETBDRM_BULK_MSG_TIMEOUT); ++ if (ret) { ++ drm_err(drm, "Failed to read response (%pe)\n", ERR_PTR(ret)); ++ return ret; ++ } ++ ++ /* ++ * The device responds to the first request sent in a particular ++ * timeframe after the USB device configuration is set with a readiness ++ * signal, in which case the response should be read again ++ */ ++ if (response->msg == APPLETBDRM_MSG_SIGNAL_READINESS) { ++ if (!adev->readiness_signal_received) { ++ adev->readiness_signal_received = true; ++ goto retry; ++ } ++ ++ drm_err(drm, "Encountered unexpected readiness signal\n"); ++ return -EIO; ++ } ++ ++ if (actual_size != size) { ++ drm_err(drm, "Actual size (%d) doesn't match expected size (%zu)\n", ++ actual_size, size); ++ return -EIO; ++ } ++ ++ if (response->msg != expected_response) { ++ drm_err(drm, "Unexpected response from device (expected %p4ch found %p4ch)\n", ++ &expected_response, &response->msg); ++ return -EIO; ++ } ++ ++ return 0; ++} ++ ++static int appletbdrm_send_msg(struct appletbdrm_device *adev, u32 msg) ++{ ++ struct appletbdrm_simple_request *request; ++ int ret; ++ ++ request = kzalloc(sizeof(*request), GFP_KERNEL); ++ if (!request) ++ return -ENOMEM; ++ ++ request->header.unk_00 = cpu_to_le16(2); ++ request->header.unk_02 = cpu_to_le16(0x1512); ++
request->header.size = cpu_to_le32(sizeof(*request) - sizeof(request->header)); ++ request->msg = msg; ++ request->size = request->header.size; ++ ++ ret = appletbdrm_send_request(adev, &request->header, sizeof(*request)); ++ ++ kfree(request); ++ ++ return ret; ++} ++ ++static int appletbdrm_clear_display(struct appletbdrm_device *adev) ++{ ++ return appletbdrm_send_msg(adev, APPLETBDRM_MSG_CLEAR_DISPLAY); ++} ++ ++static int appletbdrm_signal_readiness(struct appletbdrm_device *adev) ++{ ++ return appletbdrm_send_msg(adev, APPLETBDRM_MSG_SIGNAL_READINESS); ++} ++ ++static int appletbdrm_get_information(struct appletbdrm_device *adev) ++{ ++ struct appletbdrm_information *info; ++ struct drm_device *drm = &adev->drm; ++ u8 bits_per_pixel; ++ u32 pixel_format; ++ int ret; ++ ++ info = kzalloc(sizeof(*info), GFP_KERNEL); ++ if (!info) ++ return -ENOMEM; ++ ++ ret = appletbdrm_send_msg(adev, APPLETBDRM_MSG_GET_INFORMATION); ++ if (ret) ++ goto free_info; /* info is already allocated; release it on this path too */ ++ ++ ret = appletbdrm_read_response(adev, &info->header, sizeof(*info), ++ APPLETBDRM_MSG_GET_INFORMATION); ++ if (ret) ++ goto free_info; ++ ++ bits_per_pixel = info->bits_per_pixel; ++ pixel_format = get_unaligned(&info->pixel_format); ++ ++ adev->width = get_unaligned_le32(&info->width); ++ adev->height = get_unaligned_le32(&info->height); ++ ++ if (bits_per_pixel != APPLETBDRM_BITS_PER_PIXEL) { ++ drm_err(drm, "Encountered unexpected bits per pixel value (%d)\n", bits_per_pixel); ++ ret = -EINVAL; ++ goto free_info; ++ } ++ ++ if (pixel_format != APPLETBDRM_PIXEL_FORMAT) { ++ drm_err(drm, "Encountered unknown pixel format (%p4ch)\n", &pixel_format); ++ ret = -EINVAL; ++ goto free_info; ++ } ++ ++free_info: ++ kfree(info); ++ ++ return ret; ++} ++ ++static u32 rect_size(struct drm_rect *rect) ++{ ++ return drm_rect_width(rect) * drm_rect_height(rect) * (APPLETBDRM_BITS_PER_PIXEL / 8); ++} ++ ++static int appletbdrm_flush_damage(struct appletbdrm_device *adev, ++ struct drm_plane_state *old_state, ++ struct
drm_plane_state *state) ++{ ++ struct drm_shadow_plane_state *shadow_plane_state = to_drm_shadow_plane_state(state); ++ struct appletbdrm_fb_request_response *response; ++ struct appletbdrm_fb_request_footer *footer; ++ struct drm_atomic_helper_damage_iter iter; ++ struct drm_framebuffer *fb = state->fb; ++ struct appletbdrm_fb_request *request; ++ struct drm_device *drm = &adev->drm; ++ struct appletbdrm_frame *frame; ++ u64 timestamp = ktime_get_ns(); ++ struct drm_rect damage; ++ size_t frames_size = 0; ++ size_t request_size; ++ int ret; ++ ++ drm_atomic_helper_damage_iter_init(&iter, old_state, state); ++ drm_atomic_for_each_plane_damage(&iter, &damage) { ++ frames_size += struct_size(frame, buf, rect_size(&damage)); ++ } ++ ++ if (!frames_size) ++ return 0; ++ ++ request_size = ALIGN(sizeof(*request) + frames_size + sizeof(*footer), 16); ++ ++ request = kzalloc(request_size, GFP_KERNEL); ++ if (!request) ++ return -ENOMEM; ++ ++ response = kzalloc(sizeof(*response), GFP_KERNEL); ++ if (!response) { ++ ret = -ENOMEM; ++ goto free_request; ++ } ++ ++ ret = drm_gem_fb_begin_cpu_access(fb, DMA_FROM_DEVICE); ++ if (ret) { ++ drm_err(drm, "Failed to start CPU framebuffer access (%pe)\n", ERR_PTR(ret)); ++ goto free_response; ++ } ++ ++ request->header.unk_00 = cpu_to_le16(2); ++ request->header.unk_02 = cpu_to_le16(0x12); ++ request->header.unk_04 = cpu_to_le32(9); ++ request->header.size = cpu_to_le32(request_size - sizeof(request->header)); ++ request->unk_10 = cpu_to_le16(1); ++ request->msg_id = timestamp & 0xff; ++ ++ frame = (struct appletbdrm_frame *)request->data; ++ ++ drm_atomic_helper_damage_iter_init(&iter, old_state, state); ++ drm_atomic_for_each_plane_damage(&iter, &damage) { ++ struct iosys_map dst = IOSYS_MAP_INIT_VADDR(frame->buf); ++ u32 buf_size = rect_size(&damage); ++ ++ /* ++ * The coordinates need to be translated to the coordinate ++ * system the device expects, see the comment in ++ * appletbdrm_setup_mode_config ++ */ ++ frame->begin_x = 
cpu_to_le16(damage.y1); ++ frame->begin_y = cpu_to_le16(adev->height - damage.x2); ++ frame->width = cpu_to_le16(drm_rect_height(&damage)); ++ frame->height = cpu_to_le16(drm_rect_width(&damage)); ++ frame->buf_size = cpu_to_le32(buf_size); ++ ++ ret = drm_fb_blit(&dst, NULL, DRM_FORMAT_BGR888, ++ &shadow_plane_state->data[0], fb, &damage, &shadow_plane_state->fmtcnv_state); ++ if (ret) { ++ drm_err(drm, "Failed to copy damage clip (%pe)\n", ERR_PTR(ret)); ++ goto end_fb_cpu_access; ++ } ++ ++ frame = (void *)frame + struct_size(frame, buf, buf_size); ++ } ++ ++ footer = (struct appletbdrm_fb_request_footer *)&request->data[frames_size]; ++ ++ footer->unk_0c = cpu_to_le32(0xfffe); ++ footer->unk_1c = cpu_to_le32(0x80001); ++ footer->unk_34 = cpu_to_le32(0x80002); ++ footer->unk_4c = cpu_to_le32(0xffff); ++ footer->timestamp = cpu_to_le64(timestamp); ++ ++ ret = appletbdrm_send_request(adev, &request->header, request_size); ++ if (ret) ++ goto end_fb_cpu_access; ++ ++ ret = appletbdrm_read_response(adev, &response->header, sizeof(*response), ++ APPLETBDRM_MSG_UPDATE_COMPLETE); ++ if (ret) ++ goto end_fb_cpu_access; ++ ++ if (response->timestamp != footer->timestamp) { ++ drm_err(drm, "Response timestamp (%llu) doesn't match request timestamp (%llu)\n", ++ le64_to_cpu(response->timestamp), timestamp); ++ goto end_fb_cpu_access; ++ } ++ ++end_fb_cpu_access: ++ drm_gem_fb_end_cpu_access(fb, DMA_FROM_DEVICE); ++free_response: ++ kfree(response); ++free_request: ++ kfree(request); ++ ++ return ret; ++} ++ ++static int appletbdrm_connector_helper_get_modes(struct drm_connector *connector) ++{ ++ struct appletbdrm_device *adev = drm_to_adev(connector->dev); ++ ++ return drm_connector_helper_get_modes_fixed(connector, &adev->mode); ++} ++ ++static enum drm_mode_status appletbdrm_pipe_mode_valid(struct drm_simple_display_pipe *pipe, ++ const struct drm_display_mode *mode) ++{ ++ struct drm_crtc *crtc = &pipe->crtc; ++ struct appletbdrm_device *adev = drm_to_adev(crtc->dev); 
++ ++ return drm_crtc_helper_mode_valid_fixed(crtc, mode, &adev->mode); ++} ++ ++static void appletbdrm_pipe_disable(struct drm_simple_display_pipe *pipe) ++{ ++ struct appletbdrm_device *adev = drm_to_adev(pipe->crtc.dev); ++ int idx; ++ ++ if (!drm_dev_enter(&adev->drm, &idx)) ++ return; ++ ++ appletbdrm_clear_display(adev); ++ ++ drm_dev_exit(idx); ++} ++ ++static void appletbdrm_pipe_update(struct drm_simple_display_pipe *pipe, ++ struct drm_plane_state *old_state) ++{ ++ struct drm_crtc *crtc = &pipe->crtc; ++ struct appletbdrm_device *adev = drm_to_adev(crtc->dev); ++ int idx; ++ ++ if (!crtc->state->active || !drm_dev_enter(&adev->drm, &idx)) ++ return; ++ ++ appletbdrm_flush_damage(adev, old_state, pipe->plane.state); ++ ++ drm_dev_exit(idx); ++} ++ ++static const u32 appletbdrm_formats[] = { ++ DRM_FORMAT_BGR888, ++ DRM_FORMAT_XRGB8888, /* emulated */ ++}; ++ ++static const struct drm_mode_config_funcs appletbdrm_mode_config_funcs = { ++ .fb_create = drm_gem_fb_create_with_dirty, ++ .atomic_check = drm_atomic_helper_check, ++ .atomic_commit = drm_atomic_helper_commit, ++}; ++ ++static const struct drm_connector_funcs appletbdrm_connector_funcs = { ++ .reset = drm_atomic_helper_connector_reset, ++ .destroy = drm_connector_cleanup, ++ .fill_modes = drm_helper_probe_single_connector_modes, ++ .atomic_destroy_state = drm_atomic_helper_connector_destroy_state, ++ .atomic_duplicate_state = drm_atomic_helper_connector_duplicate_state, ++}; ++ ++static const struct drm_connector_helper_funcs appletbdrm_connector_helper_funcs = { ++ .get_modes = appletbdrm_connector_helper_get_modes, ++}; ++ ++static const struct drm_simple_display_pipe_funcs appletbdrm_pipe_funcs = { ++ DRM_GEM_SIMPLE_DISPLAY_PIPE_SHADOW_PLANE_FUNCS, ++ .update = appletbdrm_pipe_update, ++ .disable = appletbdrm_pipe_disable, ++ .mode_valid = appletbdrm_pipe_mode_valid, ++}; ++ ++DEFINE_DRM_GEM_FOPS(appletbdrm_drm_fops); ++ ++static const struct drm_driver appletbdrm_drm_driver = { ++ 
DRM_GEM_SHMEM_DRIVER_OPS, ++ .name = "appletbdrm", ++ .desc = "Apple Touch Bar DRM Driver", ++ .date = "20230910", ++ .major = 1, ++ .minor = 0, ++ .driver_features = DRIVER_MODESET | DRIVER_GEM | DRIVER_ATOMIC, ++ .fops = &appletbdrm_drm_fops, ++}; ++ ++static int appletbdrm_setup_mode_config(struct appletbdrm_device *adev) ++{ ++ struct drm_connector *connector = &adev->connector; ++ struct drm_device *drm = &adev->drm; ++ struct device *dev = adev->dev; ++ int ret; ++ ++ ret = drmm_mode_config_init(drm); ++ if (ret) ++ return dev_err_probe(dev, ret, "Failed to initialize mode configuration\n"); ++ ++ /* ++ * The coordinate system used by the device is different from the ++ * coordinate system of the framebuffer in that the x and y axes are ++ * swapped, and that the y axis is inverted; so what the device reports ++ * as the height is actually the width of the framebuffer and vice ++ * versa ++ */ ++ drm->mode_config.min_width = 0; ++ drm->mode_config.min_height = 0; ++ drm->mode_config.max_width = max(adev->height, DRM_SHADOW_PLANE_MAX_WIDTH); ++ drm->mode_config.max_height = max(adev->width, DRM_SHADOW_PLANE_MAX_HEIGHT); ++ drm->mode_config.preferred_depth = APPLETBDRM_BITS_PER_PIXEL; ++ drm->mode_config.funcs = &appletbdrm_mode_config_funcs; ++ ++ adev->mode = (struct drm_display_mode) { ++ DRM_MODE_INIT(60, adev->height, adev->width, ++ DRM_MODE_RES_MM(adev->height, 218), ++ DRM_MODE_RES_MM(adev->width, 218)) ++ }; ++ ++ ret = drm_connector_init(drm, connector, ++ &appletbdrm_connector_funcs, DRM_MODE_CONNECTOR_USB); ++ if (ret) ++ return dev_err_probe(dev, ret, "Failed to initialize connector\n"); ++ ++ drm_connector_helper_add(connector, &appletbdrm_connector_helper_funcs); ++ ++ ret = drm_connector_set_panel_orientation(connector, ++ DRM_MODE_PANEL_ORIENTATION_RIGHT_UP); ++ if (ret) ++ return dev_err_probe(dev, ret, "Failed to set panel orientation\n"); ++ ++ connector->display_info.non_desktop = true; ++ ret = 
drm_object_property_set_value(&connector->base, ++ drm->mode_config.non_desktop_property, true); ++ if (ret) ++ return dev_err_probe(dev, ret, "Failed to set non-desktop property\n"); ++ ++ ret = drm_simple_display_pipe_init(drm, &adev->pipe, &appletbdrm_pipe_funcs, ++ appletbdrm_formats, ARRAY_SIZE(appletbdrm_formats), ++ NULL, &adev->connector); ++ if (ret) ++ return dev_err_probe(dev, ret, "Failed to initialize simple display pipe\n"); ++ ++ drm_plane_enable_fb_damage_clips(&adev->pipe.plane); ++ ++ drm_mode_config_reset(drm); ++ ++ ret = drm_dev_register(drm, 0); ++ if (ret) ++ return dev_err_probe(dev, ret, "Failed to register DRM device\n"); ++ ++ return 0; ++} ++ ++static int appletbdrm_probe(struct usb_interface *intf, ++ const struct usb_device_id *id) ++{ ++ struct usb_endpoint_descriptor *bulk_in, *bulk_out; ++ struct device *dev = &intf->dev; ++ struct appletbdrm_device *adev; ++ int ret; ++ ++ ret = usb_find_common_endpoints(intf->cur_altsetting, &bulk_in, &bulk_out, NULL, NULL); ++ if (ret) ++ return dev_err_probe(dev, ret, "Failed to find bulk endpoints\n"); ++ ++ adev = devm_drm_dev_alloc(dev, &appletbdrm_drm_driver, struct appletbdrm_device, drm); ++ if (IS_ERR(adev)) ++ return PTR_ERR(adev); ++ ++ adev->dev = dev; ++ adev->in_ep = bulk_in->bEndpointAddress; ++ adev->out_ep = bulk_out->bEndpointAddress; ++ ++ usb_set_intfdata(intf, adev); ++ ++ ret = appletbdrm_get_information(adev); ++ if (ret) ++ return dev_err_probe(dev, ret, "Failed to get display information\n"); ++ ++ ret = appletbdrm_signal_readiness(adev); ++ if (ret) ++ return dev_err_probe(dev, ret, "Failed to signal readiness\n"); ++ ++ ret = appletbdrm_clear_display(adev); ++ if (ret) ++ return dev_err_probe(dev, ret, "Failed to clear display\n"); ++ ++ return appletbdrm_setup_mode_config(adev); ++} ++ ++static void appletbdrm_disconnect(struct usb_interface *intf) ++{ ++ struct appletbdrm_device *adev = usb_get_intfdata(intf); ++ struct drm_device *drm = &adev->drm; ++ ++ 
drm_dev_unplug(drm); ++ drm_atomic_helper_shutdown(drm); ++} ++ ++static void appletbdrm_shutdown(struct usb_interface *intf) ++{ ++ struct appletbdrm_device *adev = usb_get_intfdata(intf); ++ ++ /* ++ * The framebuffer needs to be cleared on shutdown since its content ++ * persists across boots ++ */ ++ drm_atomic_helper_shutdown(&adev->drm); ++} ++ ++static const struct usb_device_id appletbdrm_usb_id_table[] = { ++ { USB_DEVICE_INTERFACE_CLASS(0x05ac, 0x8302, USB_CLASS_AUDIO_VIDEO) }, ++ {} ++}; ++MODULE_DEVICE_TABLE(usb, appletbdrm_usb_id_table); ++ ++static struct usb_driver appletbdrm_usb_driver = { ++ .name = "appletbdrm", ++ .probe = appletbdrm_probe, ++ .disconnect = appletbdrm_disconnect, ++ .shutdown = appletbdrm_shutdown, ++ .id_table = appletbdrm_usb_id_table, ++}; ++module_usb_driver(appletbdrm_usb_driver); ++ ++MODULE_AUTHOR("Kerem Karabay "); ++MODULE_DESCRIPTION("Apple Touch Bar DRM Driver"); ++MODULE_LICENSE("GPL"); +diff --git a/drivers/gpu/vga/vga_switcheroo.c b/drivers/gpu/vga/vga_switcheroo.c +index 365e6ddbe90f..cf357cd3389d 100644 +--- a/drivers/gpu/vga/vga_switcheroo.c ++++ b/drivers/gpu/vga/vga_switcheroo.c +@@ -438,12 +438,7 @@ find_active_client(struct list_head *head) + bool vga_switcheroo_client_probe_defer(struct pci_dev *pdev) + { + if ((pdev->class >> 16) == PCI_BASE_CLASS_DISPLAY) { +- /* +- * apple-gmux is needed on pre-retina MacBook Pro +- * to probe the panel if pdev is the inactive GPU. +- */ +- if (apple_gmux_present() && pdev != vga_default_device() && +- !vgasr_priv.handler_flags) ++ if (apple_gmux_present() && !vgasr_priv.handler_flags) + return true; + } + +diff --git a/drivers/hid/Kconfig b/drivers/hid/Kconfig +index 08446c89eff6..35ef5d4ef068 100644 +--- a/drivers/hid/Kconfig ++++ b/drivers/hid/Kconfig +@@ -148,6 +148,27 @@ config HID_APPLEIR + + Say Y here if you want support for Apple infrared remote control. 
+ ++config HID_APPLETB_BL ++ tristate "Apple Touch Bar Backlight" ++ depends on BACKLIGHT_CLASS_DEVICE ++ help ++ Say Y here if you want support for the backlight of Touch Bars on x86 ++ MacBook Pros. ++ ++ To compile this driver as a module, choose M here: the ++ module will be called hid-appletb-bl. ++ ++config HID_APPLETB_KBD ++ tristate "Apple Touch Bar Keyboard Mode" ++ depends on USB_HID ++ help ++ Say Y here if you want support for the keyboard mode (escape, ++ function, media and brightness keys) of Touch Bars on x86 MacBook ++ Pros. ++ ++ To compile this driver as a module, choose M here: the ++ module will be called hid-appletb-kbd. ++ + config HID_ASUS + tristate "Asus" + depends on USB_HID +@@ -723,6 +744,7 @@ config HID_MULTITOUCH + Say Y here if you have one of the following devices: + - 3M PCT touch screens + - ActionStar dual touch panels ++ - Touch Bars on x86 MacBook Pros + - Atmel panels + - Cando dual touch panels + - Chunghwa panels +diff --git a/drivers/hid/Makefile b/drivers/hid/Makefile +index ce71b53ea6c5..fecec1d61393 100644 +--- a/drivers/hid/Makefile ++++ b/drivers/hid/Makefile +@@ -29,6 +29,8 @@ obj-$(CONFIG_HID_ALPS) += hid-alps.o + obj-$(CONFIG_HID_ACRUX) += hid-axff.o + obj-$(CONFIG_HID_APPLE) += hid-apple.o + obj-$(CONFIG_HID_APPLEIR) += hid-appleir.o ++obj-$(CONFIG_HID_APPLETB_BL) += hid-appletb-bl.o ++obj-$(CONFIG_HID_APPLETB_KBD) += hid-appletb-kbd.o + obj-$(CONFIG_HID_CREATIVE_SB0540) += hid-creative-sb0540.o + obj-$(CONFIG_HID_ASUS) += hid-asus.o + obj-$(CONFIG_HID_AUREAL) += hid-aureal.o +diff --git a/drivers/hid/hid-apple.c b/drivers/hid/hid-apple.c +index bd022e004356..6dedb84d7cc3 100644 +--- a/drivers/hid/hid-apple.c ++++ b/drivers/hid/hid-apple.c +@@ -8,6 +8,8 @@ + * Copyright (c) 2006-2007 Jiri Kosina + * Copyright (c) 2008 Jiri Slaby + * Copyright (c) 2019 Paul Pawlowski ++ * Copyright (c) 2023 Orlando Chamberlain ++ * Copyright (c) 2024 Aditya Garg + */ + + /* +@@ -23,6 +25,7 @@ + #include + #include + #include 
++#include + + #include "hid-ids.h" + +@@ -38,12 +41,17 @@ + #define APPLE_RDESC_BATTERY BIT(9) + #define APPLE_BACKLIGHT_CTL BIT(10) + #define APPLE_IS_NON_APPLE BIT(11) ++#define APPLE_MAGIC_BACKLIGHT BIT(12) + + #define APPLE_FLAG_FKEY 0x01 + + #define HID_COUNTRY_INTERNATIONAL_ISO 13 + #define APPLE_BATTERY_TIMEOUT_MS 60000 + ++#define HID_USAGE_MAGIC_BL 0xff00000f ++#define APPLE_MAGIC_REPORT_ID_POWER 3 ++#define APPLE_MAGIC_REPORT_ID_BRIGHTNESS 1 ++ + static unsigned int fnmode = 3; + module_param(fnmode, uint, 0644); + MODULE_PARM_DESC(fnmode, "Mode of fn key on Apple keyboards (0 = disabled, " +@@ -81,6 +89,12 @@ struct apple_sc_backlight { + struct hid_device *hdev; + }; + ++struct apple_magic_backlight { ++ struct led_classdev cdev; ++ struct hid_report *brightness; ++ struct hid_report *power; ++}; ++ + struct apple_sc { + struct hid_device *hdev; + unsigned long quirks; +@@ -822,6 +836,66 @@ static int apple_backlight_init(struct hid_device *hdev) + return ret; + } + ++static void apple_magic_backlight_report_set(struct hid_report *rep, s32 value, u8 rate) ++{ ++ rep->field[0]->value[0] = value; ++ rep->field[1]->value[0] = 0x5e; /* Mimic Windows */ ++ rep->field[1]->value[0] |= rate << 8; ++ ++ hid_hw_request(rep->device, rep, HID_REQ_SET_REPORT); ++} ++ ++static void apple_magic_backlight_set(struct apple_magic_backlight *backlight, ++ int brightness, char rate) ++{ ++ apple_magic_backlight_report_set(backlight->power, brightness ? 
1 : 0, rate); ++ if (brightness) ++ apple_magic_backlight_report_set(backlight->brightness, brightness, rate); ++} ++ ++static int apple_magic_backlight_led_set(struct led_classdev *led_cdev, ++ enum led_brightness brightness) ++{ ++ struct apple_magic_backlight *backlight = container_of(led_cdev, ++ struct apple_magic_backlight, cdev); ++ ++ apple_magic_backlight_set(backlight, brightness, 1); ++ return 0; ++} ++ ++static int apple_magic_backlight_init(struct hid_device *hdev) ++{ ++ struct apple_magic_backlight *backlight; ++ struct hid_report_enum *report_enum; ++ ++ /* ++ * Ensure this usb endpoint is for the keyboard backlight, not touchbar ++ * backlight. ++ */ ++ if (hdev->collection[0].usage != HID_USAGE_MAGIC_BL) ++ return -ENODEV; ++ ++ backlight = devm_kzalloc(&hdev->dev, sizeof(*backlight), GFP_KERNEL); ++ if (!backlight) ++ return -ENOMEM; ++ ++ report_enum = &hdev->report_enum[HID_FEATURE_REPORT]; ++ backlight->brightness = report_enum->report_id_hash[APPLE_MAGIC_REPORT_ID_BRIGHTNESS]; ++ backlight->power = report_enum->report_id_hash[APPLE_MAGIC_REPORT_ID_POWER]; ++ ++ if (!backlight->brightness || !backlight->power) ++ return -ENODEV; ++ ++ backlight->cdev.name = ":white:" LED_FUNCTION_KBD_BACKLIGHT; ++ backlight->cdev.max_brightness = backlight->brightness->field[0]->logical_maximum; ++ backlight->cdev.brightness_set_blocking = apple_magic_backlight_led_set; ++ ++ apple_magic_backlight_set(backlight, 0, 0); ++ ++ return devm_led_classdev_register(&hdev->dev, &backlight->cdev); ++ ++} ++ + static int apple_probe(struct hid_device *hdev, + const struct hid_device_id *id) + { +@@ -860,7 +934,18 @@ static int apple_probe(struct hid_device *hdev, + if (quirks & APPLE_BACKLIGHT_CTL) + apple_backlight_init(hdev); + ++ if (quirks & APPLE_MAGIC_BACKLIGHT) { ++ ret = apple_magic_backlight_init(hdev); ++ if (ret) ++ goto out_err; ++ } ++ + return 0; ++ ++out_err: ++ del_timer_sync(&asc->battery_timer); ++ hid_hw_stop(hdev); ++ return ret; + } + + static void 
apple_remove(struct hid_device *hdev) +@@ -1073,6 +1158,8 @@ static const struct hid_device_id apple_devices[] = { + .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK | APPLE_RDESC_BATTERY }, + { HID_BLUETOOTH_DEVICE(BT_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_MAGIC_KEYBOARD_NUMPAD_2021), + .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK }, ++ { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_BACKLIGHT), ++ .driver_data = APPLE_MAGIC_BACKLIGHT }, + + { } + }; +diff --git a/drivers/hid/hid-appletb-bl.c b/drivers/hid/hid-appletb-bl.c +new file mode 100644 +index 000000000000..00bbe45df4fa +--- /dev/null ++++ b/drivers/hid/hid-appletb-bl.c +@@ -0,0 +1,206 @@ ++// SPDX-License-Identifier: GPL-2.0 ++/* ++ * Apple Touch Bar Backlight Driver ++ * ++ * Copyright (c) 2017-2018 Ronald Tschalär ++ * Copyright (c) 2022-2023 Kerem Karabay ++ */ ++ ++#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt ++ ++#include ++#include ++ ++#include "hid-ids.h" ++ ++#define APPLETB_BL_ON 1 ++#define APPLETB_BL_DIM 3 ++#define APPLETB_BL_OFF 4 ++ ++#define HID_UP_APPLEVENDOR_TB_BL 0xff120000 ++ ++#define HID_VD_APPLE_TB_BRIGHTNESS 0xff120001 ++#define HID_USAGE_AUX1 0xff120020 ++#define HID_USAGE_BRIGHTNESS 0xff120021 ++ ++static int appletb_bl_def_brightness = 2; ++module_param_named(brightness, appletb_bl_def_brightness, int, 0444); ++MODULE_PARM_DESC(brightness, "Default brightness:\n" ++ " 0 - Touchbar is off\n" ++ " 1 - Dim brightness\n" ++ " [2] - Full brightness"); ++ ++struct appletb_bl { ++ struct hid_field *aux1_field, *brightness_field; ++ struct backlight_device *bdev; ++ ++ bool full_on; ++}; ++ ++const u8 appletb_bl_brightness_map[] = { ++ APPLETB_BL_OFF, ++ APPLETB_BL_DIM, ++ APPLETB_BL_ON ++}; ++ ++static int appletb_bl_set_brightness(struct appletb_bl *bl, u8 brightness) ++{ ++ struct hid_report *report = bl->brightness_field->report; ++ struct hid_device *hdev = report->device; ++ int ret; ++ ++ ret = hid_set_field(bl->aux1_field, 0, 1); ++ if (ret) { ++ 
hid_err(hdev, "Failed to set auxiliary field (%pe)\n", ERR_PTR(ret)); ++ return ret; ++ } ++ ++ ret = hid_set_field(bl->brightness_field, 0, brightness); ++ if (ret) { ++ hid_err(hdev, "Failed to set brightness field (%pe)\n", ERR_PTR(ret)); ++ return ret; ++ } ++ ++ if (!bl->full_on) { ++ ret = hid_hw_power(hdev, PM_HINT_FULLON); ++ if (ret < 0) { ++ hid_err(hdev, "Device didn't power on (%pe)\n", ERR_PTR(ret)); ++ return ret; ++ } ++ ++ bl->full_on = true; ++ } ++ ++ hid_hw_request(hdev, report, HID_REQ_SET_REPORT); ++ ++ if (brightness == APPLETB_BL_OFF) { ++ hid_hw_power(hdev, PM_HINT_NORMAL); ++ bl->full_on = false; ++ } ++ ++ return 0; ++} ++ ++static int appletb_bl_update_status(struct backlight_device *bdev) ++{ ++ struct appletb_bl *bl = bl_get_data(bdev); ++ u16 brightness; ++ ++ if (bdev->props.state & BL_CORE_SUSPENDED) ++ brightness = 0; ++ else ++ brightness = backlight_get_brightness(bdev); ++ ++ return appletb_bl_set_brightness(bl, appletb_bl_brightness_map[brightness]); ++} ++ ++static const struct backlight_ops appletb_bl_backlight_ops = { ++ .options = BL_CORE_SUSPENDRESUME, ++ .update_status = appletb_bl_update_status, ++}; ++ ++static int appletb_bl_probe(struct hid_device *hdev, const struct hid_device_id *id) ++{ ++ struct hid_field *aux1_field, *brightness_field; ++ struct backlight_properties bl_props = { 0 }; ++ struct device *dev = &hdev->dev; ++ struct appletb_bl *bl; ++ int ret; ++ ++ ret = hid_parse(hdev); ++ if (ret) ++ return dev_err_probe(dev, ret, "HID parse failed\n"); ++ ++ aux1_field = hid_find_field(hdev, HID_FEATURE_REPORT, ++ HID_VD_APPLE_TB_BRIGHTNESS, HID_USAGE_AUX1); ++ ++ brightness_field = hid_find_field(hdev, HID_FEATURE_REPORT, ++ HID_VD_APPLE_TB_BRIGHTNESS, HID_USAGE_BRIGHTNESS); ++ ++ if (!aux1_field || !brightness_field) ++ return -ENODEV; ++ ++ if (aux1_field->report != brightness_field->report) ++ return dev_err_probe(dev, -ENODEV, "Encountered unexpected report structure\n"); ++ ++ bl = devm_kzalloc(dev, 
sizeof(*bl), GFP_KERNEL); ++ if (!bl) ++ return -ENOMEM; ++ ++ ret = hid_hw_start(hdev, HID_CONNECT_DRIVER); ++ if (ret) ++ return dev_err_probe(dev, ret, "HID hardware start failed\n"); ++ ++ ret = hid_hw_open(hdev); ++ if (ret) { ++ dev_err_probe(dev, ret, "HID hardware open failed\n"); ++ goto stop_hw; ++ } ++ ++ bl->aux1_field = aux1_field; ++ bl->brightness_field = brightness_field; ++ ++ if (appletb_bl_def_brightness == 0) ++ ret = appletb_bl_set_brightness(bl, APPLETB_BL_OFF); ++ else if (appletb_bl_def_brightness == 1) ++ ret = appletb_bl_set_brightness(bl, APPLETB_BL_DIM); ++ else ++ ret = appletb_bl_set_brightness(bl, APPLETB_BL_ON); ++ ++ if (ret) { ++ dev_err_probe(dev, ret, "Failed to set touch bar brightness to off\n"); ++ goto close_hw; ++ } ++ ++ bl_props.type = BACKLIGHT_RAW; ++ bl_props.max_brightness = ARRAY_SIZE(appletb_bl_brightness_map) - 1; ++ ++ bl->bdev = devm_backlight_device_register(dev, "appletb_backlight", dev, bl, ++ &appletb_bl_backlight_ops, &bl_props); ++ if (IS_ERR(bl->bdev)) { ++ ret = PTR_ERR(bl->bdev); ++ dev_err_probe(dev, ret, "Failed to register backlight device\n"); ++ goto close_hw; ++ } ++ ++ hid_set_drvdata(hdev, bl); ++ ++ return 0; ++ ++close_hw: ++ hid_hw_close(hdev); ++stop_hw: ++ hid_hw_stop(hdev); ++ ++ return ret; ++} ++ ++static void appletb_bl_remove(struct hid_device *hdev) ++{ ++ struct appletb_bl *bl = hid_get_drvdata(hdev); ++ ++ appletb_bl_set_brightness(bl, APPLETB_BL_OFF); ++ ++ hid_hw_close(hdev); ++ hid_hw_stop(hdev); ++} ++ ++static const struct hid_device_id appletb_bl_hid_ids[] = { ++ /* MacBook Pro's 2018, 2019, with T2 chip: iBridge DFR Brightness */ ++ { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_BACKLIGHT) }, ++ { } ++}; ++MODULE_DEVICE_TABLE(hid, appletb_bl_hid_ids); ++ ++static struct hid_driver appletb_bl_hid_driver = { ++ .name = "hid-appletb-bl", ++ .id_table = appletb_bl_hid_ids, ++ .probe = appletb_bl_probe, ++ .remove = appletb_bl_remove, ++}; 
++module_hid_driver(appletb_bl_hid_driver); ++ ++MODULE_AUTHOR("Ronald Tschalär"); ++MODULE_AUTHOR("Kerem Karabay "); ++MODULE_DESCRIPTION("MacBookPro Touch Bar Backlight Driver"); ++MODULE_LICENSE("GPL"); +diff --git a/drivers/hid/hid-appletb-kbd.c b/drivers/hid/hid-appletb-kbd.c +new file mode 100644 +index 000000000000..ec8051dcf4db +--- /dev/null ++++ b/drivers/hid/hid-appletb-kbd.c +@@ -0,0 +1,432 @@ ++// SPDX-License-Identifier: GPL-2.0 ++/* ++ * Apple Touch Bar Keyboard Mode Driver ++ * ++ * Copyright (c) 2017-2018 Ronald Tschalär ++ * Copyright (c) 2022-2023 Kerem Karabay ++ * Copyright (c) 2024 Aditya Garg ++ */ ++ ++#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt ++ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++ ++#include "hid-ids.h" ++ ++#define APPLETB_KBD_MODE_ESC 0 ++#define APPLETB_KBD_MODE_FN 1 ++#define APPLETB_KBD_MODE_SPCL 2 ++#define APPLETB_KBD_MODE_OFF 3 ++#define APPLETB_KBD_MODE_MAX APPLETB_KBD_MODE_OFF ++ ++#define APPLETB_DEVID_KEYBOARD 1 ++ ++#define HID_USAGE_MODE 0x00ff0004 ++ ++static int appletb_tb_def_mode = APPLETB_KBD_MODE_FN; ++module_param_named(mode, appletb_tb_def_mode, int, 0444); ++MODULE_PARM_DESC(mode, "Default touchbar mode:\n" ++ " 0 - escape key only\n" ++ " [1] - function-keys only\n" ++ " 2 - special keys only"); ++ ++static bool appletb_tb_fn_toggle = true; ++module_param_named(fntoggle, appletb_tb_fn_toggle, bool, 0644); ++MODULE_PARM_DESC(fntoggle, "Switch between Fn and media controls on pressing Fn key"); ++ ++struct appletb_kbd { ++ struct hid_field *mode_field; ++ ++ u8 saved_mode; ++ u8 current_mode; ++ struct input_handler inp_handler; ++ struct input_handle kbd_handle; ++ ++}; ++ ++static const struct key_entry appletb_kbd_keymap[] = { ++ { KE_KEY, KEY_ESC, { KEY_ESC } }, ++ { KE_KEY, KEY_F1, { KEY_BRIGHTNESSDOWN } }, ++ { KE_KEY, KEY_F2, { KEY_BRIGHTNESSUP } }, ++ { KE_KEY, KEY_F3, { KEY_RESERVED } }, ++ { KE_KEY, KEY_F4, { KEY_RESERVED } }, ++ { KE_KEY, KEY_F5, { 
KEY_KBDILLUMDOWN } }, ++ { KE_KEY, KEY_F6, { KEY_KBDILLUMUP } }, ++ { KE_KEY, KEY_F7, { KEY_PREVIOUSSONG } }, ++ { KE_KEY, KEY_F8, { KEY_PLAYPAUSE } }, ++ { KE_KEY, KEY_F9, { KEY_NEXTSONG } }, ++ { KE_KEY, KEY_F10, { KEY_MUTE } }, ++ { KE_KEY, KEY_F11, { KEY_VOLUMEDOWN } }, ++ { KE_KEY, KEY_F12, { KEY_VOLUMEUP } }, ++ { KE_END, 0 } ++}; ++ ++static int appletb_kbd_set_mode(struct appletb_kbd *kbd, u8 mode) ++{ ++ struct hid_report *report = kbd->mode_field->report; ++ struct hid_device *hdev = report->device; ++ int ret; ++ ++ ret = hid_hw_power(hdev, PM_HINT_FULLON); ++ if (ret) { ++ hid_err(hdev, "Device didn't resume (%pe)\n", ERR_PTR(ret)); ++ return ret; ++ } ++ ++ ret = hid_set_field(kbd->mode_field, 0, mode); ++ if (ret) { ++ hid_err(hdev, "Failed to set mode field to %u (%pe)\n", mode, ERR_PTR(ret)); ++ goto power_normal; ++ } ++ ++ hid_hw_request(hdev, report, HID_REQ_SET_REPORT); ++ ++ kbd->current_mode = mode; ++ ++power_normal: ++ hid_hw_power(hdev, PM_HINT_NORMAL); ++ ++ return ret; ++} ++ ++static ssize_t mode_show(struct device *dev, ++ struct device_attribute *attr, char *buf) ++{ ++ struct appletb_kbd *kbd = dev_get_drvdata(dev); ++ ++ return sysfs_emit(buf, "%d\n", kbd->current_mode); ++} ++ ++static ssize_t mode_store(struct device *dev, ++ struct device_attribute *attr, ++ const char *buf, size_t size) ++{ ++ struct appletb_kbd *kbd = dev_get_drvdata(dev); ++ u8 mode; ++ int ret; ++ ++ ret = kstrtou8(buf, 0, &mode); ++ if (ret) ++ return ret; ++ ++ if (mode > APPLETB_KBD_MODE_MAX) ++ return -EINVAL; ++ ++ ret = appletb_kbd_set_mode(kbd, mode); ++ ++ return ret < 0 ? ret : size; ++} ++static DEVICE_ATTR_RW(mode); ++ ++struct attribute *appletb_kbd_attrs[] = { ++ &dev_attr_mode.attr, ++ NULL ++}; ++ATTRIBUTE_GROUPS(appletb_kbd); ++ ++static int appletb_tb_key_to_slot(unsigned int code) ++{ ++ switch (code) { ++ case KEY_ESC: ++ return 0; ++ case KEY_F1 ... KEY_F10: ++ return code - KEY_F1 + 1; ++ case KEY_F11 ... 
KEY_F12: ++ return code - KEY_F11 + 11; ++ ++ default: ++ return -EINVAL; ++ } ++} ++ ++static int appletb_kbd_hid_event(struct hid_device *hdev, struct hid_field *field, ++ struct hid_usage *usage, __s32 value) ++{ ++ struct appletb_kbd *kbd = hid_get_drvdata(hdev); ++ struct key_entry *translation; ++ struct input_dev *input; ++ int slot; ++ ++ if ((usage->hid & HID_USAGE_PAGE) != HID_UP_KEYBOARD || usage->type != EV_KEY) ++ return 0; ++ ++ input = field->hidinput->input; ++ ++ /* ++ * Skip non-touch-bar keys. ++ * ++ * Either the touch bar itself or usbhid generate a slew of key-down ++ * events for all the meta keys. None of which we're at all interested ++ * in. ++ */ ++ slot = appletb_tb_key_to_slot(usage->code); ++ if (slot < 0) ++ return 0; ++ ++ translation = sparse_keymap_entry_from_scancode(input, usage->code); ++ ++ if (translation && kbd->current_mode == APPLETB_KBD_MODE_SPCL) { ++ input_event(input, usage->type, translation->keycode, value); ++ ++ return 1; ++ } ++ ++ return kbd->current_mode == APPLETB_KBD_MODE_OFF; ++} ++ ++static void appletb_kbd_inp_event(struct input_handle *handle, unsigned int type, ++ unsigned int code, int value) ++{ ++ struct appletb_kbd *kbd = handle->private; ++ ++ if (type == EV_KEY && code == KEY_FN && appletb_tb_fn_toggle) { ++ if (value == 1) { ++ kbd->saved_mode = kbd->current_mode; ++ if (kbd->current_mode == APPLETB_KBD_MODE_SPCL) ++ appletb_kbd_set_mode(kbd, APPLETB_KBD_MODE_FN); ++ else if (kbd->current_mode == APPLETB_KBD_MODE_FN) ++ appletb_kbd_set_mode(kbd, APPLETB_KBD_MODE_SPCL); ++ } else if (value == 0) { ++ if (kbd->saved_mode != kbd->current_mode) ++ appletb_kbd_set_mode(kbd, kbd->saved_mode); ++ } ++ } ++} ++ ++static int appletb_kbd_inp_connect(struct input_handler *handler, ++ struct input_dev *dev, ++ const struct input_device_id *id) ++{ ++ struct appletb_kbd *kbd = handler->private; ++ struct input_handle *handle; ++ int rc; ++ ++ if (id->driver_info == APPLETB_DEVID_KEYBOARD) { ++ handle = 
&kbd->kbd_handle; ++ handle->name = "tbkbd"; ++ } else { ++ return -ENOENT; ++ } ++ ++ if (handle->dev) ++ return -EEXIST; ++ ++ handle->open = 0; ++ handle->dev = input_get_device(dev); ++ handle->handler = handler; ++ handle->private = kbd; ++ ++ rc = input_register_handle(handle); ++ if (rc) ++ goto err_free_dev; ++ ++ rc = input_open_device(handle); ++ if (rc) ++ goto err_unregister_handle; ++ ++ return 0; ++ ++ err_unregister_handle: ++ input_unregister_handle(handle); ++ err_free_dev: ++ input_put_device(handle->dev); ++ handle->dev = NULL; ++ return rc; ++} ++ ++static void appletb_kbd_inp_disconnect(struct input_handle *handle) ++{ ++ input_close_device(handle); ++ input_unregister_handle(handle); ++ ++ input_put_device(handle->dev); ++ handle->dev = NULL; ++} ++ ++static int appletb_kbd_input_configured(struct hid_device *hdev, struct hid_input *hidinput) ++{ ++ int idx; ++ struct input_dev *input = hidinput->input; ++ ++ /* ++ * Clear various input capabilities that are blindly set by the hid ++ * driver (usbkbd.c) ++ */ ++ memset(input->evbit, 0, sizeof(input->evbit)); ++ memset(input->keybit, 0, sizeof(input->keybit)); ++ memset(input->ledbit, 0, sizeof(input->ledbit)); ++ ++ __set_bit(EV_REP, input->evbit); ++ ++ sparse_keymap_setup(input, appletb_kbd_keymap, NULL); ++ ++ for (idx = 0; appletb_kbd_keymap[idx].type != KE_END; idx++) { ++ input_set_capability(input, EV_KEY, appletb_kbd_keymap[idx].code); ++ } ++ ++ return 0; ++} ++ ++static const struct input_device_id appletb_kbd_input_devices[] = { ++ { ++ .flags = INPUT_DEVICE_ID_MATCH_BUS | ++ INPUT_DEVICE_ID_MATCH_VENDOR | ++ INPUT_DEVICE_ID_MATCH_KEYBIT, ++ .bustype = BUS_USB, ++ .vendor = USB_VENDOR_ID_APPLE, ++ .keybit = { [BIT_WORD(KEY_FN)] = BIT_MASK(KEY_FN) }, ++ .driver_info = APPLETB_DEVID_KEYBOARD, ++ }, ++ { } ++}; ++ ++static bool appletb_kbd_match_internal_device(struct input_handler *handler, ++ struct input_dev *inp_dev) ++{ ++ struct device *dev = &inp_dev->dev; ++ ++ /* in kernel: 
dev && !is_usb_device(dev) */ ++ while (dev && !(dev->type && dev->type->name && ++ !strcmp(dev->type->name, "usb_device"))) ++ dev = dev->parent; ++ ++ /* ++ * Apple labels all their internal keyboards and trackpads as such, ++ * instead of maintaining an ever expanding list of product-id's we ++ * just look at the device's product name. ++ */ ++ if (dev) ++ return !!strstr(to_usb_device(dev)->product, "Internal Keyboard"); ++ ++ return false; ++} ++ ++static int appletb_kbd_probe(struct hid_device *hdev, const struct hid_device_id *id) ++{ ++ struct appletb_kbd *kbd; ++ struct device *dev = &hdev->dev; ++ struct hid_field *mode_field; ++ int ret; ++ ++ ret = hid_parse(hdev); ++ if (ret) ++ return dev_err_probe(dev, ret, "HID parse failed\n"); ++ ++ mode_field = hid_find_field(hdev, HID_OUTPUT_REPORT, ++ HID_GD_KEYBOARD, HID_USAGE_MODE); ++ if (!mode_field) ++ return -ENODEV; ++ ++ kbd = devm_kzalloc(dev, sizeof(*kbd), GFP_KERNEL); ++ if (!kbd) ++ return -ENOMEM; ++ ++ kbd->mode_field = mode_field; ++ ++ ret = hid_hw_start(hdev, HID_CONNECT_HIDINPUT); ++ if (ret) ++ return dev_err_probe(dev, ret, "HID hw start failed\n"); ++ ++ ret = hid_hw_open(hdev); ++ if (ret) { ++ dev_err_probe(dev, ret, "HID hw open failed\n"); ++ goto stop_hw; ++ } ++ ++ kbd->inp_handler.event = appletb_kbd_inp_event; ++ kbd->inp_handler.connect = appletb_kbd_inp_connect; ++ kbd->inp_handler.disconnect = appletb_kbd_inp_disconnect; ++ kbd->inp_handler.name = "appletb"; ++ kbd->inp_handler.id_table = appletb_kbd_input_devices; ++ kbd->inp_handler.match = appletb_kbd_match_internal_device; ++ kbd->inp_handler.private = kbd; ++ ++ ret = input_register_handler(&kbd->inp_handler); ++ if (ret) { ++ dev_err_probe(dev, ret, "Unable to register keyboard handler\n"); ++ goto close_hw; ++ } ++ ++ ret = appletb_kbd_set_mode(kbd, appletb_tb_def_mode); ++ if (ret) { ++ dev_err_probe(dev, ret, "Failed to set touchbar mode\n"); ++ goto close_hw; ++ } ++ ++ hid_set_drvdata(hdev, kbd); ++ ++ return 0; ++ 
++close_hw: ++ hid_hw_close(hdev); ++stop_hw: ++ hid_hw_stop(hdev); ++ return ret; ++} ++ ++static void appletb_kbd_remove(struct hid_device *hdev) ++{ ++ struct appletb_kbd *kbd = hid_get_drvdata(hdev); ++ ++ appletb_kbd_set_mode(kbd, APPLETB_KBD_MODE_OFF); ++ ++ input_unregister_handler(&kbd->inp_handler); ++ ++ hid_hw_close(hdev); ++ hid_hw_stop(hdev); ++} ++ ++#ifdef CONFIG_PM ++static int appletb_kbd_suspend(struct hid_device *hdev, pm_message_t msg) ++{ ++ struct appletb_kbd *kbd = hid_get_drvdata(hdev); ++ ++ kbd->saved_mode = kbd->current_mode; ++ appletb_kbd_set_mode(kbd, APPLETB_KBD_MODE_OFF); ++ ++ return 0; ++} ++ ++static int appletb_kbd_reset_resume(struct hid_device *hdev) ++{ ++ struct appletb_kbd *kbd = hid_get_drvdata(hdev); ++ ++ appletb_kbd_set_mode(kbd, kbd->saved_mode); ++ ++ return 0; ++} ++#endif ++ ++static const struct hid_device_id appletb_kbd_hid_ids[] = { ++ /* MacBook Pro's 2018, 2019, with T2 chip: iBridge Display */ ++ { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_DISPLAY) }, ++ { } ++}; ++MODULE_DEVICE_TABLE(hid, appletb_kbd_hid_ids); ++ ++static struct hid_driver appletb_kbd_hid_driver = { ++ .name = "hid-appletb-kbd", ++ .id_table = appletb_kbd_hid_ids, ++ .probe = appletb_kbd_probe, ++ .remove = appletb_kbd_remove, ++ .event = appletb_kbd_hid_event, ++ .input_configured = appletb_kbd_input_configured, ++#ifdef CONFIG_PM ++ .suspend = appletb_kbd_suspend, ++ .reset_resume = appletb_kbd_reset_resume, ++#endif ++ .driver.dev_groups = appletb_kbd_groups, ++}; ++module_hid_driver(appletb_kbd_hid_driver); ++ ++MODULE_AUTHOR("Ronald Tschalär"); ++MODULE_AUTHOR("Kerem Karabay "); ++MODULE_DESCRIPTION("MacBookPro Touch Bar Keyboard Mode Driver"); ++MODULE_LICENSE("GPL"); +diff --git a/drivers/hid/hid-core.c b/drivers/hid/hid-core.c +index 74efda212c55..f4379efdbf30 100644 +--- a/drivers/hid/hid-core.c ++++ b/drivers/hid/hid-core.c +@@ -1912,6 +1912,31 @@ int hid_set_field(struct hid_field *field, unsigned offset, 
__s32 value) + } + EXPORT_SYMBOL_GPL(hid_set_field); + ++struct hid_field *hid_find_field(struct hid_device *hdev, unsigned int report_type, ++ unsigned int application, unsigned int usage) ++{ ++ struct list_head *report_list = &hdev->report_enum[report_type].report_list; ++ struct hid_report *report; ++ int i, j; ++ ++ list_for_each_entry(report, report_list, list) { ++ if (report->application != application) ++ continue; ++ ++ for (i = 0; i < report->maxfield; i++) { ++ struct hid_field *field = report->field[i]; ++ ++ for (j = 0; j < field->maxusage; j++) { ++ if (field->usage[j].hid == usage) ++ return field; ++ } ++ } ++ } ++ ++ return NULL; ++} ++EXPORT_SYMBOL_GPL(hid_find_field); ++ + static struct hid_report *hid_get_report(struct hid_report_enum *report_enum, + const u8 *data) + { +diff --git a/drivers/hid/hid-google-hammer.c b/drivers/hid/hid-google-hammer.c +index 25331695ae32..3380694ba18c 100644 +--- a/drivers/hid/hid-google-hammer.c ++++ b/drivers/hid/hid-google-hammer.c +@@ -418,38 +418,15 @@ static int hammer_event(struct hid_device *hid, struct hid_field *field, + return 0; + } + +-static bool hammer_has_usage(struct hid_device *hdev, unsigned int report_type, +- unsigned application, unsigned usage) +-{ +- struct hid_report_enum *re = &hdev->report_enum[report_type]; +- struct hid_report *report; +- int i, j; +- +- list_for_each_entry(report, &re->report_list, list) { +- if (report->application != application) +- continue; +- +- for (i = 0; i < report->maxfield; i++) { +- struct hid_field *field = report->field[i]; +- +- for (j = 0; j < field->maxusage; j++) +- if (field->usage[j].hid == usage) +- return true; +- } +- } +- +- return false; +-} +- + static bool hammer_has_folded_event(struct hid_device *hdev) + { +- return hammer_has_usage(hdev, HID_INPUT_REPORT, ++ return !!hid_find_field(hdev, HID_INPUT_REPORT, + HID_GD_KEYBOARD, HID_USAGE_KBD_FOLDED); + } + + static bool hammer_has_backlight_control(struct hid_device *hdev) + { +- return 
hammer_has_usage(hdev, HID_OUTPUT_REPORT, ++ return !!hid_find_field(hdev, HID_OUTPUT_REPORT, + HID_GD_KEYBOARD, HID_AD_BRIGHTNESS); + } + +diff --git a/drivers/hid/hid-multitouch.c b/drivers/hid/hid-multitouch.c +index 56fc78841f24..0fed955364c3 100644 +--- a/drivers/hid/hid-multitouch.c ++++ b/drivers/hid/hid-multitouch.c +@@ -72,6 +72,7 @@ MODULE_LICENSE("GPL"); + #define MT_QUIRK_FORCE_MULTI_INPUT BIT(20) + #define MT_QUIRK_DISABLE_WAKEUP BIT(21) + #define MT_QUIRK_ORIENTATION_INVERT BIT(22) ++#define MT_QUIRK_TOUCH_IS_TIPSTATE BIT(23) + + #define MT_INPUTMODE_TOUCHSCREEN 0x02 + #define MT_INPUTMODE_TOUCHPAD 0x03 +@@ -145,6 +146,7 @@ struct mt_class { + __s32 sn_height; /* Signal/noise ratio for height events */ + __s32 sn_pressure; /* Signal/noise ratio for pressure events */ + __u8 maxcontacts; ++ bool is_direct; /* true for touchscreens */ + bool is_indirect; /* true for touchpads */ + bool export_all_inputs; /* do not ignore mouse, keyboards, etc... */ + }; +@@ -212,6 +214,7 @@ static void mt_post_parse(struct mt_device *td, struct mt_application *app); + #define MT_CLS_GOOGLE 0x0111 + #define MT_CLS_RAZER_BLADE_STEALTH 0x0112 + #define MT_CLS_SMART_TECH 0x0113 ++#define MT_CLS_APPLE_TOUCHBAR 0x0114 + + #define MT_DEFAULT_MAXCONTACT 10 + #define MT_MAX_MAXCONTACT 250 +@@ -396,6 +399,13 @@ static const struct mt_class mt_classes[] = { + MT_QUIRK_CONTACT_CNT_ACCURATE | + MT_QUIRK_SEPARATE_APP_REPORT, + }, ++ { .name = MT_CLS_APPLE_TOUCHBAR, ++ .quirks = MT_QUIRK_HOVERING | ++ MT_QUIRK_TOUCH_IS_TIPSTATE | ++ MT_QUIRK_SLOT_IS_CONTACTID_MINUS_ONE, ++ .is_direct = true, ++ .maxcontacts = 11, ++ }, + { } + }; + +@@ -489,9 +499,6 @@ static void mt_feature_mapping(struct hid_device *hdev, + if (!td->maxcontacts && + field->logical_maximum <= MT_MAX_MAXCONTACT) + td->maxcontacts = field->logical_maximum; +- if (td->mtclass.maxcontacts) +- /* check if the maxcontacts is given by the class */ +- td->maxcontacts = td->mtclass.maxcontacts; + + break; + case 
HID_DG_BUTTONTYPE: +@@ -565,13 +572,13 @@ static struct mt_application *mt_allocate_application(struct mt_device *td, + mt_application->application = application; + INIT_LIST_HEAD(&mt_application->mt_usages); + +- if (application == HID_DG_TOUCHSCREEN) ++ if (application == HID_DG_TOUCHSCREEN && !td->mtclass.is_indirect) + mt_application->mt_flags |= INPUT_MT_DIRECT; + + /* + * Model touchscreens providing buttons as touchpads. + */ +- if (application == HID_DG_TOUCHPAD) { ++ if (application == HID_DG_TOUCHPAD && !td->mtclass.is_direct) { + mt_application->mt_flags |= INPUT_MT_POINTER; + td->inputmode_value = MT_INPUTMODE_TOUCHPAD; + } +@@ -635,7 +642,9 @@ static struct mt_report_data *mt_allocate_report_data(struct mt_device *td, + + if (field->logical == HID_DG_FINGER || td->hdev->group != HID_GROUP_MULTITOUCH_WIN_8) { + for (n = 0; n < field->report_count; n++) { +- if (field->usage[n].hid == HID_DG_CONTACTID) { ++ unsigned int hid = field->usage[n].hid; ++ ++ if (hid == HID_DG_CONTACTID || hid == HID_DG_TRANSDUCER_INDEX) { + rdata->is_mt_collection = true; + break; + } +@@ -807,6 +816,15 @@ static int mt_touch_input_mapping(struct hid_device *hdev, struct hid_input *hi, + + MT_STORE_FIELD(confidence_state); + return 1; ++ case HID_DG_TOUCH: ++ /* ++ * Legacy devices use TIPSWITCH and not TOUCH. ++ * Let's just ignore this field unless the quirk is set. 
++ */ ++ if (!(cls->quirks & MT_QUIRK_TOUCH_IS_TIPSTATE)) ++ return -1; ++ ++ fallthrough; + case HID_DG_TIPSWITCH: + if (field->application != HID_GD_SYSTEM_MULTIAXIS) + input_set_capability(hi->input, +@@ -814,6 +832,7 @@ static int mt_touch_input_mapping(struct hid_device *hdev, struct hid_input *hi, + MT_STORE_FIELD(tip_state); + return 1; + case HID_DG_CONTACTID: ++ case HID_DG_TRANSDUCER_INDEX: + MT_STORE_FIELD(contactid); + app->touches_by_report++; + return 1; +@@ -869,10 +888,6 @@ static int mt_touch_input_mapping(struct hid_device *hdev, struct hid_input *hi, + case HID_DG_CONTACTMAX: + /* contact max are global to the report */ + return -1; +- case HID_DG_TOUCH: +- /* Legacy devices use TIPSWITCH and not TOUCH. +- * Let's just ignore this field. */ +- return -1; + } + /* let hid-input decide for the others */ + return 0; +@@ -1300,6 +1315,10 @@ static int mt_touch_input_configured(struct hid_device *hdev, + struct input_dev *input = hi->input; + int ret; + ++ /* check if the maxcontacts is given by the class */ ++ if (cls->maxcontacts) ++ td->maxcontacts = cls->maxcontacts; ++ + if (!td->maxcontacts) + td->maxcontacts = MT_DEFAULT_MAXCONTACT; + +@@ -1307,6 +1326,9 @@ static int mt_touch_input_configured(struct hid_device *hdev, + if (td->serial_maybe) + mt_post_parse_default_settings(td, app); + ++ if (cls->is_direct) ++ app->mt_flags |= INPUT_MT_DIRECT; ++ + if (cls->is_indirect) + app->mt_flags |= INPUT_MT_POINTER; + +@@ -1733,6 +1755,15 @@ static int mt_probe(struct hid_device *hdev, const struct hid_device_id *id) + } + } + ++ ret = hid_parse(hdev); ++ if (ret != 0) ++ return ret; ++ ++ if (mtclass->name == MT_CLS_APPLE_TOUCHBAR && ++ !hid_find_field(hdev, HID_INPUT_REPORT, ++ HID_DG_TOUCHPAD, HID_DG_TRANSDUCER_INDEX)) ++ return -ENODEV; ++ + td = devm_kzalloc(&hdev->dev, sizeof(struct mt_device), GFP_KERNEL); + if (!td) { + dev_err(&hdev->dev, "cannot allocate multitouch data\n"); +@@ -1780,10 +1811,6 @@ static int mt_probe(struct hid_device *hdev, 
const struct hid_device_id *id) + + timer_setup(&td->release_timer, mt_expired_timeout, 0); + +- ret = hid_parse(hdev); +- if (ret != 0) +- return ret; +- + if (mtclass->quirks & MT_QUIRK_FIX_CONST_CONTACT_ID) + mt_fix_const_fields(hdev, HID_DG_CONTACTID); + +@@ -2235,6 +2262,11 @@ static const struct hid_device_id mt_devices[] = { + MT_USB_DEVICE(USB_VENDOR_ID_XIROKU, + USB_DEVICE_ID_XIROKU_CSR2) }, + ++ /* Apple Touch Bars */ ++ { .driver_data = MT_CLS_APPLE_TOUCHBAR, ++ HID_USB_DEVICE(USB_VENDOR_ID_APPLE, ++ USB_DEVICE_ID_APPLE_TOUCHBAR_DISPLAY) }, ++ + /* Google MT devices */ + { .driver_data = MT_CLS_GOOGLE, + HID_DEVICE(HID_BUS_ANY, HID_GROUP_ANY, USB_VENDOR_ID_GOOGLE, +diff --git a/drivers/hid/hid-quirks.c b/drivers/hid/hid-quirks.c +index e0bbf0c6345d..7c576d6540fe 100644 +--- a/drivers/hid/hid-quirks.c ++++ b/drivers/hid/hid-quirks.c +@@ -328,8 +328,6 @@ static const struct hid_device_id hid_have_special_driver[] = { + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_GEYSER1_TP_ONLY) }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_MAGIC_KEYBOARD_2021) }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_MAGIC_KEYBOARD_FINGERPRINT_2021) }, +- { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_BACKLIGHT) }, +- { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_DISPLAY) }, + #endif + #if IS_ENABLED(CONFIG_HID_APPLEIR) + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_IRCONTROL) }, +@@ -338,6 +336,12 @@ static const struct hid_device_id hid_have_special_driver[] = { + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_IRCONTROL4) }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_IRCONTROL5) }, + #endif ++#if IS_ENABLED(CONFIG_HID_APPLETB_BL) ++ { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_BACKLIGHT) }, ++#endif ++#if IS_ENABLED(CONFIG_HID_APPLETB_KBD) ++ { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_DISPLAY) }, ++#endif + 
#if IS_ENABLED(CONFIG_HID_ASUS) + { HID_I2C_DEVICE(USB_VENDOR_ID_ASUSTEK, USB_DEVICE_ID_ASUSTEK_I2C_KEYBOARD) }, + { HID_I2C_DEVICE(USB_VENDOR_ID_ASUSTEK, USB_DEVICE_ID_ASUSTEK_I2C_TOUCHPAD) }, +diff --git a/drivers/hwmon/applesmc.c b/drivers/hwmon/applesmc.c +index fc6d6a9053ce..698f44794453 100644 +--- a/drivers/hwmon/applesmc.c ++++ b/drivers/hwmon/applesmc.c +@@ -6,6 +6,7 @@ + * + * Copyright (C) 2007 Nicolas Boichat + * Copyright (C) 2010 Henrik Rydberg ++ * Copyright (C) 2019 Paul Pawlowski + * + * Based on hdaps.c driver: + * Copyright (C) 2005 Robert Love +@@ -18,7 +19,7 @@ + #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + + #include +-#include ++#include + #include + #include + #include +@@ -35,12 +36,24 @@ + #include + + /* data port used by Apple SMC */ +-#define APPLESMC_DATA_PORT 0x300 ++#define APPLESMC_DATA_PORT 0 + /* command/status port used by Apple SMC */ +-#define APPLESMC_CMD_PORT 0x304 ++#define APPLESMC_CMD_PORT 4 + + #define APPLESMC_NR_PORTS 32 /* 0x300-0x31f */ + ++#define APPLESMC_IOMEM_KEY_DATA 0 ++#define APPLESMC_IOMEM_KEY_STATUS 0x4005 ++#define APPLESMC_IOMEM_KEY_NAME 0x78 ++#define APPLESMC_IOMEM_KEY_DATA_LEN 0x7D ++#define APPLESMC_IOMEM_KEY_SMC_ID 0x7E ++#define APPLESMC_IOMEM_KEY_CMD 0x7F ++#define APPLESMC_IOMEM_MIN_SIZE 0x4006 ++ ++#define APPLESMC_IOMEM_KEY_TYPE_CODE 0 ++#define APPLESMC_IOMEM_KEY_TYPE_DATA_LEN 5 ++#define APPLESMC_IOMEM_KEY_TYPE_FLAGS 6 ++ + #define APPLESMC_MAX_DATA_LENGTH 32 + + /* Apple SMC status bits */ +@@ -74,6 +87,7 @@ + #define FAN_ID_FMT "F%dID" /* r-o char[16] */ + + #define TEMP_SENSOR_TYPE "sp78" ++#define FLOAT_TYPE "flt " + + /* List of keys used to read/write fan speeds */ + static const char *const fan_speed_fmt[] = { +@@ -83,6 +97,7 @@ static const char *const fan_speed_fmt[] = { + "F%dSf", /* safe speed - not all models */ + "F%dTg", /* target speed (manual: rw) */ + }; ++#define FAN_MANUAL_FMT "F%dMd" + + #define INIT_TIMEOUT_MSECS 5000 /* wait up to 5s for device init ... 
*/ + #define INIT_WAIT_MSECS 50 /* ... in 50ms increments */ +@@ -119,7 +134,7 @@ struct applesmc_entry { + }; + + /* Register lookup and registers common to all SMCs */ +-static struct applesmc_registers { ++struct applesmc_registers { + struct mutex mutex; /* register read/write mutex */ + unsigned int key_count; /* number of SMC registers */ + unsigned int fan_count; /* number of fans */ +@@ -133,26 +148,38 @@ static struct applesmc_registers { + bool init_complete; /* true when fully initialized */ + struct applesmc_entry *cache; /* cached key entries */ + const char **index; /* temperature key index */ +-} smcreg = { +- .mutex = __MUTEX_INITIALIZER(smcreg.mutex), + }; + +-static const int debug; +-static struct platform_device *pdev; +-static s16 rest_x; +-static s16 rest_y; +-static u8 backlight_state[2]; ++struct applesmc_device { ++ struct acpi_device *dev; ++ struct device *ldev; ++ struct applesmc_registers reg; + +-static struct device *hwmon_dev; +-static struct input_dev *applesmc_idev; ++ bool port_base_set, iomem_base_set; ++ u16 port_base; ++ u8 *__iomem iomem_base; ++ u32 iomem_base_addr, iomem_base_size; + +-/* +- * Last index written to key_at_index sysfs file, and value to use for all other +- * key_at_index_* sysfs files. +- */ +-static unsigned int key_at_index; ++ s16 rest_x; ++ s16 rest_y; ++ ++ u8 backlight_state[2]; ++ ++ struct device *hwmon_dev; ++ struct input_dev *idev; ++ ++ /* ++ * Last index written to key_at_index sysfs file, and value to use for all other ++ * key_at_index_* sysfs files. ++ */ ++ unsigned int key_at_index; + +-static struct workqueue_struct *applesmc_led_wq; ++ struct workqueue_struct *backlight_wq; ++ struct work_struct backlight_work; ++ struct led_classdev backlight_dev; ++}; ++ ++static const int debug; + + /* + * Wait for specific status bits with a mask on the SMC. +@@ -162,7 +189,7 @@ static struct workqueue_struct *applesmc_led_wq; + * run out past 500ms. 
+ */ + +-static int wait_status(u8 val, u8 mask) ++static int port_wait_status(struct applesmc_device *smc, u8 val, u8 mask) + { + u8 status; + int us; +@@ -170,7 +197,7 @@ static int wait_status(u8 val, u8 mask) + + us = APPLESMC_MIN_WAIT; + for (i = 0; i < 24 ; i++) { +- status = inb(APPLESMC_CMD_PORT); ++ status = inb(smc->port_base + APPLESMC_CMD_PORT); + if ((status & mask) == val) + return 0; + usleep_range(us, us * 2); +@@ -180,13 +207,13 @@ static int wait_status(u8 val, u8 mask) + return -EIO; + } + +-/* send_byte - Write to SMC data port. Callers must hold applesmc_lock. */ ++/* port_send_byte - Write to SMC data port. Callers must hold applesmc_lock. */ + +-static int send_byte(u8 cmd, u16 port) ++static int port_send_byte(struct applesmc_device *smc, u8 cmd, u16 port) + { + int status; + +- status = wait_status(0, SMC_STATUS_IB_CLOSED); ++ status = port_wait_status(smc, 0, SMC_STATUS_IB_CLOSED); + if (status) + return status; + /* +@@ -195,24 +222,25 @@ static int send_byte(u8 cmd, u16 port) + * this extra read may not happen if status returns both + * simultaneously and this would appear to be required. + */ +- status = wait_status(SMC_STATUS_BUSY, SMC_STATUS_BUSY); ++ status = port_wait_status(smc, SMC_STATUS_BUSY, SMC_STATUS_BUSY); + if (status) + return status; + +- outb(cmd, port); ++ outb(cmd, smc->port_base + port); + return 0; + } + +-/* send_command - Write a command to the SMC. Callers must hold applesmc_lock. */ ++/* port_send_command - Write a command to the SMC. Callers must hold applesmc_lock. 
*/ + +-static int send_command(u8 cmd) ++static int port_send_command(struct applesmc_device *smc, u8 cmd) + { + int ret; + +- ret = wait_status(0, SMC_STATUS_IB_CLOSED); ++ ret = port_wait_status(smc, 0, SMC_STATUS_IB_CLOSED); + if (ret) + return ret; +- outb(cmd, APPLESMC_CMD_PORT); ++ ++ outb(cmd, smc->port_base + APPLESMC_CMD_PORT); + return 0; + } + +@@ -222,110 +250,304 @@ static int send_command(u8 cmd) + * If busy is stuck high after the command then the SMC is jammed. + */ + +-static int smc_sane(void) ++static int port_smc_sane(struct applesmc_device *smc) + { + int ret; + +- ret = wait_status(0, SMC_STATUS_BUSY); ++ ret = port_wait_status(smc, 0, SMC_STATUS_BUSY); + if (!ret) + return ret; +- ret = send_command(APPLESMC_READ_CMD); ++ ret = port_send_command(smc, APPLESMC_READ_CMD); + if (ret) + return ret; +- return wait_status(0, SMC_STATUS_BUSY); ++ return port_wait_status(smc, 0, SMC_STATUS_BUSY); + } + +-static int send_argument(const char *key) ++static int port_send_argument(struct applesmc_device *smc, const char *key) + { + int i; + + for (i = 0; i < 4; i++) +- if (send_byte(key[i], APPLESMC_DATA_PORT)) ++ if (port_send_byte(smc, key[i], APPLESMC_DATA_PORT)) + return -EIO; + return 0; + } + +-static int read_smc(u8 cmd, const char *key, u8 *buffer, u8 len) ++static int port_read_smc(struct applesmc_device *smc, u8 cmd, const char *key, ++ u8 *buffer, u8 len) + { + u8 status, data = 0; + int i; + int ret; + +- ret = smc_sane(); ++ ret = port_smc_sane(smc); + if (ret) + return ret; + +- if (send_command(cmd) || send_argument(key)) { ++ if (port_send_command(smc, cmd) || port_send_argument(smc, key)) { + pr_warn("%.4s: read arg fail\n", key); + return -EIO; + } + + /* This has no effect on newer (2012) SMCs */ +- if (send_byte(len, APPLESMC_DATA_PORT)) { ++ if (port_send_byte(smc, len, APPLESMC_DATA_PORT)) { + pr_warn("%.4s: read len fail\n", key); + return -EIO; + } + + for (i = 0; i < len; i++) { +- if (wait_status(SMC_STATUS_AWAITING_DATA | 
SMC_STATUS_BUSY, ++ if (port_wait_status(smc, ++ SMC_STATUS_AWAITING_DATA | SMC_STATUS_BUSY, + SMC_STATUS_AWAITING_DATA | SMC_STATUS_BUSY)) { + pr_warn("%.4s: read data[%d] fail\n", key, i); + return -EIO; + } +- buffer[i] = inb(APPLESMC_DATA_PORT); ++ buffer[i] = inb(smc->port_base + APPLESMC_DATA_PORT); + } + + /* Read the data port until bit0 is cleared */ + for (i = 0; i < 16; i++) { + udelay(APPLESMC_MIN_WAIT); +- status = inb(APPLESMC_CMD_PORT); ++ status = inb(smc->port_base + APPLESMC_CMD_PORT); + if (!(status & SMC_STATUS_AWAITING_DATA)) + break; +- data = inb(APPLESMC_DATA_PORT); ++ data = inb(smc->port_base + APPLESMC_DATA_PORT); + } + if (i) + pr_warn("flushed %d bytes, last value is: %d\n", i, data); + +- return wait_status(0, SMC_STATUS_BUSY); ++ return port_wait_status(smc, 0, SMC_STATUS_BUSY); + } + +-static int write_smc(u8 cmd, const char *key, const u8 *buffer, u8 len) ++static int port_write_smc(struct applesmc_device *smc, u8 cmd, const char *key, ++ const u8 *buffer, u8 len) + { + int i; + int ret; + +- ret = smc_sane(); ++ ret = port_smc_sane(smc); + if (ret) + return ret; + +- if (send_command(cmd) || send_argument(key)) { ++ if (port_send_command(smc, cmd) || port_send_argument(smc, key)) { + pr_warn("%s: write arg fail\n", key); + return -EIO; + } + +- if (send_byte(len, APPLESMC_DATA_PORT)) { ++ if (port_send_byte(smc, len, APPLESMC_DATA_PORT)) { + pr_warn("%.4s: write len fail\n", key); + return -EIO; + } + + for (i = 0; i < len; i++) { +- if (send_byte(buffer[i], APPLESMC_DATA_PORT)) { ++ if (port_send_byte(smc, buffer[i], APPLESMC_DATA_PORT)) { + pr_warn("%s: write data fail\n", key); + return -EIO; + } + } + +- return wait_status(0, SMC_STATUS_BUSY); ++ return port_wait_status(smc, 0, SMC_STATUS_BUSY); + } + +-static int read_register_count(unsigned int *count) ++static int port_get_smc_key_info(struct applesmc_device *smc, ++ const char *key, struct applesmc_entry *info) + { +- __be32 be; + int ret; ++ u8 raw[6]; + +- ret = 
read_smc(APPLESMC_READ_CMD, KEY_COUNT_KEY, (u8 *)&be, 4); ++ ret = port_read_smc(smc, APPLESMC_GET_KEY_TYPE_CMD, key, raw, 6); + if (ret) + return ret; ++ info->len = raw[0]; ++ memcpy(info->type, &raw[1], 4); ++ info->flags = raw[5]; ++ return 0; ++} ++ ++ ++/* ++ * MMIO based communication. ++ * TODO: Use updated mechanism for cmd timeout/retry ++ */ ++ ++static void iomem_clear_status(struct applesmc_device *smc) ++{ ++ if (ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_STATUS)) ++ iowrite8(0, smc->iomem_base + APPLESMC_IOMEM_KEY_STATUS); ++} ++ ++static int iomem_wait_read(struct applesmc_device *smc) ++{ ++ u8 status; ++ int us; ++ int i; ++ ++ us = APPLESMC_MIN_WAIT; ++ for (i = 0; i < 24 ; i++) { ++ status = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_STATUS); ++ if (status & 0x20) ++ return 0; ++ usleep_range(us, us * 2); ++ if (i > 9) ++ us <<= 1; ++ } ++ ++ dev_warn(smc->ldev, "%s... timeout\n", __func__); ++ return -EIO; ++} ++ ++static int iomem_read_smc(struct applesmc_device *smc, u8 cmd, const char *key, ++ u8 *buffer, u8 len) ++{ ++ u8 err, remote_len; ++ u32 key_int = *((u32 *) key); ++ ++ iomem_clear_status(smc); ++ iowrite32(key_int, smc->iomem_base + APPLESMC_IOMEM_KEY_NAME); ++ iowrite32(0, smc->iomem_base + APPLESMC_IOMEM_KEY_SMC_ID); ++ iowrite32(cmd, smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); ++ ++ if (iomem_wait_read(smc)) ++ return -EIO; ++ ++ err = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); ++ if (err != 0) { ++ dev_warn(smc->ldev, "read_smc_mmio(%x %8x/%.4s) failed: %u\n", ++ cmd, key_int, key, err); ++ return -EIO; ++ } ++ ++ if (cmd == APPLESMC_READ_CMD) { ++ remote_len = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_DATA_LEN); ++ if (remote_len != len) { ++ dev_warn(smc->ldev, ++ "read_smc_mmio(%x %8x/%.4s) failed: buffer length mismatch (remote = %u, requested = %u)\n", ++ cmd, key_int, key, remote_len, len); ++ return -EINVAL; ++ } ++ } else { ++ remote_len = len; ++ } ++ ++ memcpy_fromio(buffer, smc->iomem_base + 
APPLESMC_IOMEM_KEY_DATA, ++ remote_len); ++ ++ dev_dbg(smc->ldev, "read_smc_mmio(%x %8x/%.4s): buflen=%u reslen=%u\n", ++ cmd, key_int, key, len, remote_len); ++ print_hex_dump_bytes("read_smc_mmio(): ", DUMP_PREFIX_NONE, buffer, remote_len); ++ return 0; ++} ++ ++static int iomem_get_smc_key_type(struct applesmc_device *smc, const char *key, ++ struct applesmc_entry *e) ++{ ++ u8 err; ++ u8 cmd = APPLESMC_GET_KEY_TYPE_CMD; ++ u32 key_int = *((u32 *) key); ++ ++ iomem_clear_status(smc); ++ iowrite32(key_int, smc->iomem_base + APPLESMC_IOMEM_KEY_NAME); ++ iowrite32(0, smc->iomem_base + APPLESMC_IOMEM_KEY_SMC_ID); ++ iowrite32(cmd, smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); ++ ++ if (iomem_wait_read(smc)) ++ return -EIO; ++ ++ err = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); ++ if (err != 0) { ++ dev_warn(smc->ldev, "get_smc_key_type_mmio(%.4s) failed: %u\n", key, err); ++ return -EIO; ++ } ++ ++ e->len = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_TYPE_DATA_LEN); ++ *((uint32_t *) e->type) = ioread32( ++ smc->iomem_base + APPLESMC_IOMEM_KEY_TYPE_CODE); ++ e->flags = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_TYPE_FLAGS); ++ ++ dev_dbg(smc->ldev, "get_smc_key_type_mmio(%.4s): len=%u type=%.4s flags=%x\n", ++ key, e->len, e->type, e->flags); ++ return 0; ++} ++ ++static int iomem_write_smc(struct applesmc_device *smc, u8 cmd, const char *key, ++ const u8 *buffer, u8 len) ++{ ++ u8 err; ++ u32 key_int = *((u32 *) key); ++ ++ iomem_clear_status(smc); ++ iowrite32(key_int, smc->iomem_base + APPLESMC_IOMEM_KEY_NAME); ++ memcpy_toio(smc->iomem_base + APPLESMC_IOMEM_KEY_DATA, buffer, len); ++ iowrite32(len, smc->iomem_base + APPLESMC_IOMEM_KEY_DATA_LEN); ++ iowrite32(0, smc->iomem_base + APPLESMC_IOMEM_KEY_SMC_ID); ++ iowrite32(cmd, smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); ++ ++ if (iomem_wait_read(smc)) ++ return -EIO; ++ ++ err = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); ++ if (err != 0) { ++ dev_warn(smc->ldev, "write_smc_mmio(%x %.4s) failed: 
%u\n", cmd, key, err); ++ print_hex_dump_bytes("write_smc_mmio(): ", DUMP_PREFIX_NONE, buffer, len); ++ return -EIO; ++ } ++ ++ dev_dbg(smc->ldev, "write_smc_mmio(%x %.4s): buflen=%u\n", cmd, key, len); ++ print_hex_dump_bytes("write_smc_mmio(): ", DUMP_PREFIX_NONE, buffer, len); ++ return 0; ++} ++ ++ ++static int read_smc(struct applesmc_device *smc, const char *key, ++ u8 *buffer, u8 len) ++{ ++ if (smc->iomem_base_set) ++ return iomem_read_smc(smc, APPLESMC_READ_CMD, key, buffer, len); ++ else ++ return port_read_smc(smc, APPLESMC_READ_CMD, key, buffer, len); ++} ++ ++static int write_smc(struct applesmc_device *smc, const char *key, ++ const u8 *buffer, u8 len) ++{ ++ if (smc->iomem_base_set) ++ return iomem_write_smc(smc, APPLESMC_WRITE_CMD, key, buffer, len); ++ else ++ return port_write_smc(smc, APPLESMC_WRITE_CMD, key, buffer, len); ++} ++ ++static int get_smc_key_by_index(struct applesmc_device *smc, ++ unsigned int index, char *key) ++{ ++ __be32 be; ++ ++ be = cpu_to_be32(index); ++ if (smc->iomem_base_set) ++ return iomem_read_smc(smc, APPLESMC_GET_KEY_BY_INDEX_CMD, ++ (const char *) &be, (u8 *) key, 4); ++ else ++ return port_read_smc(smc, APPLESMC_GET_KEY_BY_INDEX_CMD, ++ (const char *) &be, (u8 *) key, 4); ++} ++ ++static int get_smc_key_info(struct applesmc_device *smc, const char *key, ++ struct applesmc_entry *info) ++{ ++ if (smc->iomem_base_set) ++ return iomem_get_smc_key_type(smc, key, info); ++ else ++ return port_get_smc_key_info(smc, key, info); ++} ++ ++static int read_register_count(struct applesmc_device *smc, ++ unsigned int *count) ++{ ++ __be32 be; ++ int ret; ++ ++ ret = read_smc(smc, KEY_COUNT_KEY, (u8 *)&be, 4); ++ if (ret < 0) ++ return ret; + + *count = be32_to_cpu(be); + return 0; +@@ -338,76 +560,73 @@ static int read_register_count(unsigned int *count) + * All functions below are concurrency safe - callers should NOT hold lock. 
+ */ + +-static int applesmc_read_entry(const struct applesmc_entry *entry, +- u8 *buf, u8 len) ++static int applesmc_read_entry(struct applesmc_device *smc, ++ const struct applesmc_entry *entry, u8 *buf, u8 len) + { + int ret; + + if (entry->len != len) + return -EINVAL; +- mutex_lock(&smcreg.mutex); +- ret = read_smc(APPLESMC_READ_CMD, entry->key, buf, len); +- mutex_unlock(&smcreg.mutex); ++ mutex_lock(&smc->reg.mutex); ++ ret = read_smc(smc, entry->key, buf, len); ++ mutex_unlock(&smc->reg.mutex); + + return ret; + } + +-static int applesmc_write_entry(const struct applesmc_entry *entry, +- const u8 *buf, u8 len) ++static int applesmc_write_entry(struct applesmc_device *smc, ++ const struct applesmc_entry *entry, const u8 *buf, u8 len) + { + int ret; + + if (entry->len != len) + return -EINVAL; +- mutex_lock(&smcreg.mutex); +- ret = write_smc(APPLESMC_WRITE_CMD, entry->key, buf, len); +- mutex_unlock(&smcreg.mutex); ++ mutex_lock(&smc->reg.mutex); ++ ret = write_smc(smc, entry->key, buf, len); ++ mutex_unlock(&smc->reg.mutex); + return ret; + } + +-static const struct applesmc_entry *applesmc_get_entry_by_index(int index) ++static const struct applesmc_entry *applesmc_get_entry_by_index( ++ struct applesmc_device *smc, int index) + { +- struct applesmc_entry *cache = &smcreg.cache[index]; +- u8 key[4], info[6]; +- __be32 be; ++ struct applesmc_entry *cache = &smc->reg.cache[index]; ++ char key[4]; + int ret = 0; + + if (cache->valid) + return cache; + +- mutex_lock(&smcreg.mutex); ++ mutex_lock(&smc->reg.mutex); + + if (cache->valid) + goto out; +- be = cpu_to_be32(index); +- ret = read_smc(APPLESMC_GET_KEY_BY_INDEX_CMD, (u8 *)&be, key, 4); ++ ret = get_smc_key_by_index(smc, index, key); + if (ret) + goto out; +- ret = read_smc(APPLESMC_GET_KEY_TYPE_CMD, key, info, 6); ++ memcpy(cache->key, key, 4); ++ ++ ret = get_smc_key_info(smc, key, cache); + if (ret) + goto out; +- +- memcpy(cache->key, key, 4); +- cache->len = info[0]; +- memcpy(cache->type, &info[1], 
4); +- cache->flags = info[5]; + cache->valid = true; + + out: +- mutex_unlock(&smcreg.mutex); ++ mutex_unlock(&smc->reg.mutex); + if (ret) + return ERR_PTR(ret); + return cache; + } + +-static int applesmc_get_lower_bound(unsigned int *lo, const char *key) ++static int applesmc_get_lower_bound(struct applesmc_device *smc, ++ unsigned int *lo, const char *key) + { +- int begin = 0, end = smcreg.key_count; ++ int begin = 0, end = smc->reg.key_count; + const struct applesmc_entry *entry; + + while (begin != end) { + int middle = begin + (end - begin) / 2; +- entry = applesmc_get_entry_by_index(middle); ++ entry = applesmc_get_entry_by_index(smc, middle); + if (IS_ERR(entry)) { + *lo = 0; + return PTR_ERR(entry); +@@ -422,16 +641,17 @@ static int applesmc_get_lower_bound(unsigned int *lo, const char *key) + return 0; + } + +-static int applesmc_get_upper_bound(unsigned int *hi, const char *key) ++static int applesmc_get_upper_bound(struct applesmc_device *smc, ++ unsigned int *hi, const char *key) + { +- int begin = 0, end = smcreg.key_count; ++ int begin = 0, end = smc->reg.key_count; + const struct applesmc_entry *entry; + + while (begin != end) { + int middle = begin + (end - begin) / 2; +- entry = applesmc_get_entry_by_index(middle); ++ entry = applesmc_get_entry_by_index(smc, middle); + if (IS_ERR(entry)) { +- *hi = smcreg.key_count; ++ *hi = smc->reg.key_count; + return PTR_ERR(entry); + } + if (strcmp(key, entry->key) < 0) +@@ -444,50 +664,54 @@ static int applesmc_get_upper_bound(unsigned int *hi, const char *key) + return 0; + } + +-static const struct applesmc_entry *applesmc_get_entry_by_key(const char *key) ++static const struct applesmc_entry *applesmc_get_entry_by_key( ++ struct applesmc_device *smc, const char *key) + { + int begin, end; + int ret; + +- ret = applesmc_get_lower_bound(&begin, key); ++ ret = applesmc_get_lower_bound(smc, &begin, key); + if (ret) + return ERR_PTR(ret); +- ret = applesmc_get_upper_bound(&end, key); ++ ret = 
applesmc_get_upper_bound(smc, &end, key); + if (ret) + return ERR_PTR(ret); + if (end - begin != 1) + return ERR_PTR(-EINVAL); + +- return applesmc_get_entry_by_index(begin); ++ return applesmc_get_entry_by_index(smc, begin); + } + +-static int applesmc_read_key(const char *key, u8 *buffer, u8 len) ++static int applesmc_read_key(struct applesmc_device *smc, ++ const char *key, u8 *buffer, u8 len) + { + const struct applesmc_entry *entry; + +- entry = applesmc_get_entry_by_key(key); ++ entry = applesmc_get_entry_by_key(smc, key); + if (IS_ERR(entry)) + return PTR_ERR(entry); + +- return applesmc_read_entry(entry, buffer, len); ++ return applesmc_read_entry(smc, entry, buffer, len); + } + +-static int applesmc_write_key(const char *key, const u8 *buffer, u8 len) ++static int applesmc_write_key(struct applesmc_device *smc, ++ const char *key, const u8 *buffer, u8 len) + { + const struct applesmc_entry *entry; + +- entry = applesmc_get_entry_by_key(key); ++ entry = applesmc_get_entry_by_key(smc, key); + if (IS_ERR(entry)) + return PTR_ERR(entry); + +- return applesmc_write_entry(entry, buffer, len); ++ return applesmc_write_entry(smc, entry, buffer, len); + } + +-static int applesmc_has_key(const char *key, bool *value) ++static int applesmc_has_key(struct applesmc_device *smc, ++ const char *key, bool *value) + { + const struct applesmc_entry *entry; + +- entry = applesmc_get_entry_by_key(key); ++ entry = applesmc_get_entry_by_key(smc, key); + if (IS_ERR(entry) && PTR_ERR(entry) != -EINVAL) + return PTR_ERR(entry); + +@@ -498,12 +722,13 @@ static int applesmc_has_key(const char *key, bool *value) + /* + * applesmc_read_s16 - Read 16-bit signed big endian register + */ +-static int applesmc_read_s16(const char *key, s16 *value) ++static int applesmc_read_s16(struct applesmc_device *smc, ++ const char *key, s16 *value) + { + u8 buffer[2]; + int ret; + +- ret = applesmc_read_key(key, buffer, 2); ++ ret = applesmc_read_key(smc, key, buffer, 2); + if (ret) + return ret; + 
+@@ -511,31 +736,68 @@ static int applesmc_read_s16(const char *key, s16 *value) + return 0; + } + ++/** ++ * applesmc_float_to_u32 - Retrieve the integral part of a float. ++ * This is needed because Apple made fans use float values in the T2. ++ * The fractional point is not significantly useful though, and the integral ++ * part can be easily extracted. ++ */ ++static inline u32 applesmc_float_to_u32(u32 d) ++{ ++ u8 sign = (u8) ((d >> 31) & 1); ++ s32 exp = (s32) ((d >> 23) & 0xff) - 0x7f; ++ u32 fr = d & ((1u << 23) - 1); ++ ++ if (sign || exp < 0) ++ return 0; ++ ++ return (u32) ((1u << exp) + (fr >> (23 - exp))); ++} ++ ++/** ++ * applesmc_u32_to_float - Convert an u32 into a float. ++ * See applesmc_float_to_u32 for a rationale. ++ */ ++static inline u32 applesmc_u32_to_float(u32 d) ++{ ++ u32 dc = d, bc = 0, exp; ++ ++ if (!d) ++ return 0; ++ ++ while (dc >>= 1) ++ ++bc; ++ exp = 0x7f + bc; ++ ++ return (u32) ((exp << 23) | ++ ((d << (23 - (exp - 0x7f))) & ((1u << 23) - 1))); ++} + /* + * applesmc_device_init - initialize the accelerometer. Can sleep. 
+ */ +-static void applesmc_device_init(void) ++static void applesmc_device_init(struct applesmc_device *smc) + { + int total; + u8 buffer[2]; + +- if (!smcreg.has_accelerometer) ++ if (!smc->reg.has_accelerometer) + return; + + for (total = INIT_TIMEOUT_MSECS; total > 0; total -= INIT_WAIT_MSECS) { +- if (!applesmc_read_key(MOTION_SENSOR_KEY, buffer, 2) && ++ if (!applesmc_read_key(smc, MOTION_SENSOR_KEY, buffer, 2) && + (buffer[0] != 0x00 || buffer[1] != 0x00)) + return; + buffer[0] = 0xe0; + buffer[1] = 0x00; +- applesmc_write_key(MOTION_SENSOR_KEY, buffer, 2); ++ applesmc_write_key(smc, MOTION_SENSOR_KEY, buffer, 2); + msleep(INIT_WAIT_MSECS); + } + + pr_warn("failed to init the device\n"); + } + +-static int applesmc_init_index(struct applesmc_registers *s) ++static int applesmc_init_index(struct applesmc_device *smc, ++ struct applesmc_registers *s) + { + const struct applesmc_entry *entry; + unsigned int i; +@@ -548,7 +810,7 @@ static int applesmc_init_index(struct applesmc_registers *s) + return -ENOMEM; + + for (i = s->temp_begin; i < s->temp_end; i++) { +- entry = applesmc_get_entry_by_index(i); ++ entry = applesmc_get_entry_by_index(smc, i); + if (IS_ERR(entry)) + continue; + if (strcmp(entry->type, TEMP_SENSOR_TYPE)) +@@ -562,9 +824,9 @@ static int applesmc_init_index(struct applesmc_registers *s) + /* + * applesmc_init_smcreg_try - Try to initialize register cache. Idempotent. 
+ */ +-static int applesmc_init_smcreg_try(void) ++static int applesmc_init_smcreg_try(struct applesmc_device *smc) + { +- struct applesmc_registers *s = &smcreg; ++ struct applesmc_registers *s = &smc->reg; + bool left_light_sensor = false, right_light_sensor = false; + unsigned int count; + u8 tmp[1]; +@@ -573,7 +835,7 @@ static int applesmc_init_smcreg_try(void) + if (s->init_complete) + return 0; + +- ret = read_register_count(&count); ++ ret = read_register_count(smc, &count); + if (ret) + return ret; + +@@ -590,35 +852,35 @@ static int applesmc_init_smcreg_try(void) + if (!s->cache) + return -ENOMEM; + +- ret = applesmc_read_key(FANS_COUNT, tmp, 1); ++ ret = applesmc_read_key(smc, FANS_COUNT, tmp, 1); + if (ret) + return ret; + s->fan_count = tmp[0]; + if (s->fan_count > 10) + s->fan_count = 10; + +- ret = applesmc_get_lower_bound(&s->temp_begin, "T"); ++ ret = applesmc_get_lower_bound(smc, &s->temp_begin, "T"); + if (ret) + return ret; +- ret = applesmc_get_lower_bound(&s->temp_end, "U"); ++ ret = applesmc_get_lower_bound(smc, &s->temp_end, "U"); + if (ret) + return ret; + s->temp_count = s->temp_end - s->temp_begin; + +- ret = applesmc_init_index(s); ++ ret = applesmc_init_index(smc, s); + if (ret) + return ret; + +- ret = applesmc_has_key(LIGHT_SENSOR_LEFT_KEY, &left_light_sensor); ++ ret = applesmc_has_key(smc, LIGHT_SENSOR_LEFT_KEY, &left_light_sensor); + if (ret) + return ret; +- ret = applesmc_has_key(LIGHT_SENSOR_RIGHT_KEY, &right_light_sensor); ++ ret = applesmc_has_key(smc, LIGHT_SENSOR_RIGHT_KEY, &right_light_sensor); + if (ret) + return ret; +- ret = applesmc_has_key(MOTION_SENSOR_KEY, &s->has_accelerometer); ++ ret = applesmc_has_key(smc, MOTION_SENSOR_KEY, &s->has_accelerometer); + if (ret) + return ret; +- ret = applesmc_has_key(BACKLIGHT_KEY, &s->has_key_backlight); ++ ret = applesmc_has_key(smc, BACKLIGHT_KEY, &s->has_key_backlight); + if (ret) + return ret; + +@@ -634,13 +896,13 @@ static int applesmc_init_smcreg_try(void) + return 0; + } + 
+-static void applesmc_destroy_smcreg(void) ++static void applesmc_destroy_smcreg(struct applesmc_device *smc) + { +- kfree(smcreg.index); +- smcreg.index = NULL; +- kfree(smcreg.cache); +- smcreg.cache = NULL; +- smcreg.init_complete = false; ++ kfree(smc->reg.index); ++ smc->reg.index = NULL; ++ kfree(smc->reg.cache); ++ smc->reg.cache = NULL; ++ smc->reg.init_complete = false; + } + + /* +@@ -649,12 +911,12 @@ static void applesmc_destroy_smcreg(void) + * Retries until initialization is successful, or the operation times out. + * + */ +-static int applesmc_init_smcreg(void) ++static int applesmc_init_smcreg(struct applesmc_device *smc) + { + int ms, ret; + + for (ms = 0; ms < INIT_TIMEOUT_MSECS; ms += INIT_WAIT_MSECS) { +- ret = applesmc_init_smcreg_try(); ++ ret = applesmc_init_smcreg_try(smc); + if (!ret) { + if (ms) + pr_info("init_smcreg() took %d ms\n", ms); +@@ -663,50 +925,223 @@ static int applesmc_init_smcreg(void) + msleep(INIT_WAIT_MSECS); + } + +- applesmc_destroy_smcreg(); ++ applesmc_destroy_smcreg(smc); + + return ret; + } + + /* Device model stuff */ +-static int applesmc_probe(struct platform_device *dev) ++ ++static int applesmc_init_resources(struct applesmc_device *smc); ++static void applesmc_free_resources(struct applesmc_device *smc); ++static int applesmc_create_modules(struct applesmc_device *smc); ++static void applesmc_destroy_modules(struct applesmc_device *smc); ++ ++static int applesmc_add(struct acpi_device *dev) + { ++ struct applesmc_device *smc; + int ret; + +- ret = applesmc_init_smcreg(); ++ smc = kzalloc(sizeof(struct applesmc_device), GFP_KERNEL); ++ if (!smc) ++ return -ENOMEM; ++ smc->dev = dev; ++ smc->ldev = &dev->dev; ++ mutex_init(&smc->reg.mutex); ++ ++ dev_set_drvdata(&dev->dev, smc); ++ ++ ret = applesmc_init_resources(smc); + if (ret) +- return ret; ++ goto out_mem; ++ ++ ret = applesmc_init_smcreg(smc); ++ if (ret) ++ goto out_res; ++ ++ applesmc_device_init(smc); ++ ++ ret = applesmc_create_modules(smc); ++ if 
(ret) ++ goto out_reg; ++ ++ return 0; ++ ++out_reg: ++ applesmc_destroy_smcreg(smc); ++out_res: ++ applesmc_free_resources(smc); ++out_mem: ++ dev_set_drvdata(&dev->dev, NULL); ++ mutex_destroy(&smc->reg.mutex); ++ kfree(smc); ++ ++ return ret; ++} ++ ++static void applesmc_remove(struct acpi_device *dev) ++{ ++ struct applesmc_device *smc = dev_get_drvdata(&dev->dev); ++ ++ applesmc_destroy_modules(smc); ++ applesmc_destroy_smcreg(smc); ++ applesmc_free_resources(smc); + +- applesmc_device_init(); ++ mutex_destroy(&smc->reg.mutex); ++ kfree(smc); ++ ++ return; ++} ++ ++static acpi_status applesmc_walk_resources(struct acpi_resource *res, ++ void *data) ++{ ++ struct applesmc_device *smc = data; ++ ++ switch (res->type) { ++ case ACPI_RESOURCE_TYPE_IO: ++ if (!smc->port_base_set) { ++ if (res->data.io.address_length < APPLESMC_NR_PORTS) ++ return AE_ERROR; ++ smc->port_base = res->data.io.minimum; ++ smc->port_base_set = true; ++ } ++ return AE_OK; ++ ++ case ACPI_RESOURCE_TYPE_FIXED_MEMORY32: ++ if (!smc->iomem_base_set) { ++ if (res->data.fixed_memory32.address_length < ++ APPLESMC_IOMEM_MIN_SIZE) { ++ dev_warn(smc->ldev, "found iomem but it's too small: %u\n", ++ res->data.fixed_memory32.address_length); ++ return AE_OK; ++ } ++ smc->iomem_base_addr = res->data.fixed_memory32.address; ++ smc->iomem_base_size = res->data.fixed_memory32.address_length; ++ smc->iomem_base_set = true; ++ } ++ return AE_OK; ++ ++ case ACPI_RESOURCE_TYPE_END_TAG: ++ if (smc->port_base_set) ++ return AE_OK; ++ else ++ return AE_NOT_FOUND; ++ ++ default: ++ return AE_OK; ++ } ++} ++ ++static int applesmc_try_enable_iomem(struct applesmc_device *smc); ++ ++static int applesmc_init_resources(struct applesmc_device *smc) ++{ ++ int ret; ++ ++ ret = acpi_walk_resources(smc->dev->handle, METHOD_NAME__CRS, ++ applesmc_walk_resources, smc); ++ if (ACPI_FAILURE(ret)) ++ return -ENXIO; ++ ++ if (!request_region(smc->port_base, APPLESMC_NR_PORTS, "applesmc")) ++ return -ENXIO; ++ ++ if 
(smc->iomem_base_set) { ++ if (applesmc_try_enable_iomem(smc)) ++ smc->iomem_base_set = false; ++ } ++ ++ return 0; ++} ++ ++static int applesmc_try_enable_iomem(struct applesmc_device *smc) ++{ ++ u8 test_val, ldkn_version; ++ ++ dev_dbg(smc->ldev, "Trying to enable iomem based communication\n"); ++ smc->iomem_base = ioremap(smc->iomem_base_addr, smc->iomem_base_size); ++ if (!smc->iomem_base) ++ goto out; ++ ++ /* Apple's driver does this check for some reason */ ++ test_val = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_STATUS); ++ if (test_val == 0xff) { ++ dev_warn(smc->ldev, ++ "iomem enable failed: initial status is 0xff (is %x)\n", ++ test_val); ++ goto out_iomem; ++ } ++ ++ if (read_smc(smc, "LDKN", &ldkn_version, 1)) { ++ dev_warn(smc->ldev, "iomem enable failed: ldkn read failed\n"); ++ goto out_iomem; ++ } ++ ++ if (ldkn_version < 2) { ++ dev_warn(smc->ldev, ++ "iomem enable failed: ldkn version %u is less than minimum (2)\n", ++ ldkn_version); ++ goto out_iomem; ++ } + + return 0; ++ ++out_iomem: ++ iounmap(smc->iomem_base); ++ ++out: ++ return -ENXIO; ++} ++ ++static void applesmc_free_resources(struct applesmc_device *smc) ++{ ++ if (smc->iomem_base_set) ++ iounmap(smc->iomem_base); ++ release_region(smc->port_base, APPLESMC_NR_PORTS); + } + + /* Synchronize device with memorized backlight state */ + static int applesmc_pm_resume(struct device *dev) + { +- if (smcreg.has_key_backlight) +- applesmc_write_key(BACKLIGHT_KEY, backlight_state, 2); ++ struct applesmc_device *smc = dev_get_drvdata(dev); ++ ++ if (smc->reg.has_key_backlight) ++ applesmc_write_key(smc, BACKLIGHT_KEY, smc->backlight_state, 2); ++ + return 0; + } + + /* Reinitialize device on resume from hibernation */ + static int applesmc_pm_restore(struct device *dev) + { +- applesmc_device_init(); ++ struct applesmc_device *smc = dev_get_drvdata(dev); ++ ++ applesmc_device_init(smc); ++ + return applesmc_pm_resume(dev); + } + ++static const struct acpi_device_id applesmc_ids[] = { ++ 
{"APP0001", 0}, ++ {"", 0}, ++}; ++ + static const struct dev_pm_ops applesmc_pm_ops = { + .resume = applesmc_pm_resume, + .restore = applesmc_pm_restore, + }; + +-static struct platform_driver applesmc_driver = { +- .probe = applesmc_probe, +- .driver = { +- .name = "applesmc", +- .pm = &applesmc_pm_ops, ++static struct acpi_driver applesmc_driver = { ++ .name = "applesmc", ++ .class = "applesmc", ++ .ids = applesmc_ids, ++ .ops = { ++ .add = applesmc_add, ++ .remove = applesmc_remove ++ }, ++ .drv = { ++ .pm = &applesmc_pm_ops + }, + }; + +@@ -714,25 +1149,26 @@ static struct platform_driver applesmc_driver = { + * applesmc_calibrate - Set our "resting" values. Callers must + * hold applesmc_lock. + */ +-static void applesmc_calibrate(void) ++static void applesmc_calibrate(struct applesmc_device *smc) + { +- applesmc_read_s16(MOTION_SENSOR_X_KEY, &rest_x); +- applesmc_read_s16(MOTION_SENSOR_Y_KEY, &rest_y); +- rest_x = -rest_x; ++ applesmc_read_s16(smc, MOTION_SENSOR_X_KEY, &smc->rest_x); ++ applesmc_read_s16(smc, MOTION_SENSOR_Y_KEY, &smc->rest_y); ++ smc->rest_x = -smc->rest_x; + } + + static void applesmc_idev_poll(struct input_dev *idev) + { ++ struct applesmc_device *smc = dev_get_drvdata(&idev->dev); + s16 x, y; + +- if (applesmc_read_s16(MOTION_SENSOR_X_KEY, &x)) ++ if (applesmc_read_s16(smc, MOTION_SENSOR_X_KEY, &x)) + return; +- if (applesmc_read_s16(MOTION_SENSOR_Y_KEY, &y)) ++ if (applesmc_read_s16(smc, MOTION_SENSOR_Y_KEY, &y)) + return; + + x = -x; +- input_report_abs(idev, ABS_X, x - rest_x); +- input_report_abs(idev, ABS_Y, y - rest_y); ++ input_report_abs(idev, ABS_X, x - smc->rest_x); ++ input_report_abs(idev, ABS_Y, y - smc->rest_y); + input_sync(idev); + } + +@@ -747,16 +1183,17 @@ static ssize_t applesmc_name_show(struct device *dev, + static ssize_t applesmc_position_show(struct device *dev, + struct device_attribute *attr, char *buf) + { ++ struct applesmc_device *smc = dev_get_drvdata(dev); + int ret; + s16 x, y, z; + +- ret = 
applesmc_read_s16(MOTION_SENSOR_X_KEY, &x); ++ ret = applesmc_read_s16(smc, MOTION_SENSOR_X_KEY, &x); + if (ret) + goto out; +- ret = applesmc_read_s16(MOTION_SENSOR_Y_KEY, &y); ++ ret = applesmc_read_s16(smc, MOTION_SENSOR_Y_KEY, &y); + if (ret) + goto out; +- ret = applesmc_read_s16(MOTION_SENSOR_Z_KEY, &z); ++ ret = applesmc_read_s16(smc, MOTION_SENSOR_Z_KEY, &z); + if (ret) + goto out; + +@@ -770,6 +1207,7 @@ static ssize_t applesmc_position_show(struct device *dev, + static ssize_t applesmc_light_show(struct device *dev, + struct device_attribute *attr, char *sysfsbuf) + { ++ struct applesmc_device *smc = dev_get_drvdata(dev); + const struct applesmc_entry *entry; + static int data_length; + int ret; +@@ -777,7 +1215,7 @@ static ssize_t applesmc_light_show(struct device *dev, + u8 buffer[10]; + + if (!data_length) { +- entry = applesmc_get_entry_by_key(LIGHT_SENSOR_LEFT_KEY); ++ entry = applesmc_get_entry_by_key(smc, LIGHT_SENSOR_LEFT_KEY); + if (IS_ERR(entry)) + return PTR_ERR(entry); + if (entry->len > 10) +@@ -786,7 +1224,7 @@ static ssize_t applesmc_light_show(struct device *dev, + pr_info("light sensor data length set to %d\n", data_length); + } + +- ret = applesmc_read_key(LIGHT_SENSOR_LEFT_KEY, buffer, data_length); ++ ret = applesmc_read_key(smc, LIGHT_SENSOR_LEFT_KEY, buffer, data_length); + if (ret) + goto out; + /* newer macbooks report a single 10-bit bigendian value */ +@@ -796,7 +1234,7 @@ static ssize_t applesmc_light_show(struct device *dev, + } + left = buffer[2]; + +- ret = applesmc_read_key(LIGHT_SENSOR_RIGHT_KEY, buffer, data_length); ++ ret = applesmc_read_key(smc, LIGHT_SENSOR_RIGHT_KEY, buffer, data_length); + if (ret) + goto out; + right = buffer[2]; +@@ -812,7 +1250,8 @@ static ssize_t applesmc_light_show(struct device *dev, + static ssize_t applesmc_show_sensor_label(struct device *dev, + struct device_attribute *devattr, char *sysfsbuf) + { +- const char *key = smcreg.index[to_index(devattr)]; ++ struct applesmc_device *smc = 
dev_get_drvdata(dev); ++ const char *key = smc->reg.index[to_index(devattr)]; + + return sysfs_emit(sysfsbuf, "%s\n", key); + } +@@ -821,12 +1260,13 @@ static ssize_t applesmc_show_sensor_label(struct device *dev, + static ssize_t applesmc_show_temperature(struct device *dev, + struct device_attribute *devattr, char *sysfsbuf) + { +- const char *key = smcreg.index[to_index(devattr)]; ++ struct applesmc_device *smc = dev_get_drvdata(dev); ++ const char *key = smc->reg.index[to_index(devattr)]; + int ret; + s16 value; + int temp; + +- ret = applesmc_read_s16(key, &value); ++ ret = applesmc_read_s16(smc, key, &value); + if (ret) + return ret; + +@@ -838,6 +1278,8 @@ static ssize_t applesmc_show_temperature(struct device *dev, + static ssize_t applesmc_show_fan_speed(struct device *dev, + struct device_attribute *attr, char *sysfsbuf) + { ++ struct applesmc_device *smc = dev_get_drvdata(dev); ++ const struct applesmc_entry *entry; + int ret; + unsigned int speed = 0; + char newkey[5]; +@@ -846,11 +1288,21 @@ static ssize_t applesmc_show_fan_speed(struct device *dev, + scnprintf(newkey, sizeof(newkey), fan_speed_fmt[to_option(attr)], + to_index(attr)); + +- ret = applesmc_read_key(newkey, buffer, 2); ++ entry = applesmc_get_entry_by_key(smc, newkey); ++ if (IS_ERR(entry)) ++ return PTR_ERR(entry); ++ ++ if (!strcmp(entry->type, FLOAT_TYPE)) { ++ ret = applesmc_read_entry(smc, entry, (u8 *) &speed, 4); ++ speed = applesmc_float_to_u32(speed); ++ } else { ++ ret = applesmc_read_entry(smc, entry, buffer, 2); ++ speed = ((buffer[0] << 8 | buffer[1]) >> 2); ++ } ++ + if (ret) + return ret; + +- speed = ((buffer[0] << 8 | buffer[1]) >> 2); + return sysfs_emit(sysfsbuf, "%u\n", speed); + } + +@@ -858,6 +1310,8 @@ static ssize_t applesmc_store_fan_speed(struct device *dev, + struct device_attribute *attr, + const char *sysfsbuf, size_t count) + { ++ struct applesmc_device *smc = dev_get_drvdata(dev); ++ const struct applesmc_entry *entry; + int ret; + unsigned long speed; + 
char newkey[5]; +@@ -869,9 +1323,18 @@ static ssize_t applesmc_store_fan_speed(struct device *dev, + scnprintf(newkey, sizeof(newkey), fan_speed_fmt[to_option(attr)], + to_index(attr)); + +- buffer[0] = (speed >> 6) & 0xff; +- buffer[1] = (speed << 2) & 0xff; +- ret = applesmc_write_key(newkey, buffer, 2); ++ entry = applesmc_get_entry_by_key(smc, newkey); ++ if (IS_ERR(entry)) ++ return PTR_ERR(entry); ++ ++ if (!strcmp(entry->type, FLOAT_TYPE)) { ++ speed = applesmc_u32_to_float(speed); ++ ret = applesmc_write_entry(smc, entry, (u8 *) &speed, 4); ++ } else { ++ buffer[0] = (speed >> 6) & 0xff; ++ buffer[1] = (speed << 2) & 0xff; ++ ret = applesmc_write_key(smc, newkey, buffer, 2); ++ } + + if (ret) + return ret; +@@ -882,15 +1345,30 @@ static ssize_t applesmc_store_fan_speed(struct device *dev, + static ssize_t applesmc_show_fan_manual(struct device *dev, + struct device_attribute *attr, char *sysfsbuf) + { ++ struct applesmc_device *smc = dev_get_drvdata(dev); + int ret; + u16 manual = 0; + u8 buffer[2]; ++ char newkey[5]; ++ bool has_newkey = false; ++ ++ scnprintf(newkey, sizeof(newkey), FAN_MANUAL_FMT, to_index(attr)); ++ ++ ret = applesmc_has_key(smc, newkey, &has_newkey); ++ if (ret) ++ return ret; ++ ++ if (has_newkey) { ++ ret = applesmc_read_key(smc, newkey, buffer, 1); ++ manual = buffer[0]; ++ } else { ++ ret = applesmc_read_key(smc, FANS_MANUAL, buffer, 2); ++ manual = ((buffer[0] << 8 | buffer[1]) >> to_index(attr)) & 0x01; ++ } + +- ret = applesmc_read_key(FANS_MANUAL, buffer, 2); + if (ret) + return ret; + +- manual = ((buffer[0] << 8 | buffer[1]) >> to_index(attr)) & 0x01; + return sysfs_emit(sysfsbuf, "%d\n", manual); + } + +@@ -898,29 +1376,42 @@ static ssize_t applesmc_store_fan_manual(struct device *dev, + struct device_attribute *attr, + const char *sysfsbuf, size_t count) + { ++ struct applesmc_device *smc = dev_get_drvdata(dev); + int ret; + u8 buffer[2]; ++ char newkey[5]; ++ bool has_newkey = false; + unsigned long input; + u16 val; + + 
if (kstrtoul(sysfsbuf, 10, &input) < 0) + return -EINVAL; + +- ret = applesmc_read_key(FANS_MANUAL, buffer, 2); ++ scnprintf(newkey, sizeof(newkey), FAN_MANUAL_FMT, to_index(attr)); ++ ++ ret = applesmc_has_key(smc, newkey, &has_newkey); + if (ret) +- goto out; ++ return ret; + +- val = (buffer[0] << 8 | buffer[1]); ++ if (has_newkey) { ++ buffer[0] = input & 1; ++ ret = applesmc_write_key(smc, newkey, buffer, 1); ++ } else { ++ ret = applesmc_read_key(smc, FANS_MANUAL, buffer, 2); ++ val = (buffer[0] << 8 | buffer[1]); ++ if (ret) ++ goto out; + +- if (input) +- val = val | (0x01 << to_index(attr)); +- else +- val = val & ~(0x01 << to_index(attr)); ++ if (input) ++ val = val | (0x01 << to_index(attr)); ++ else ++ val = val & ~(0x01 << to_index(attr)); + +- buffer[0] = (val >> 8) & 0xFF; +- buffer[1] = val & 0xFF; ++ buffer[0] = (val >> 8) & 0xFF; ++ buffer[1] = val & 0xFF; + +- ret = applesmc_write_key(FANS_MANUAL, buffer, 2); ++ ret = applesmc_write_key(smc, FANS_MANUAL, buffer, 2); ++ } + + out: + if (ret) +@@ -932,13 +1423,14 @@ static ssize_t applesmc_store_fan_manual(struct device *dev, + static ssize_t applesmc_show_fan_position(struct device *dev, + struct device_attribute *attr, char *sysfsbuf) + { ++ struct applesmc_device *smc = dev_get_drvdata(dev); + int ret; + char newkey[5]; + u8 buffer[17]; + + scnprintf(newkey, sizeof(newkey), FAN_ID_FMT, to_index(attr)); + +- ret = applesmc_read_key(newkey, buffer, 16); ++ ret = applesmc_read_key(smc, newkey, buffer, 16); + buffer[16] = 0; + + if (ret) +@@ -950,43 +1442,79 @@ static ssize_t applesmc_show_fan_position(struct device *dev, + static ssize_t applesmc_calibrate_show(struct device *dev, + struct device_attribute *attr, char *sysfsbuf) + { +- return sysfs_emit(sysfsbuf, "(%d,%d)\n", rest_x, rest_y); ++ struct applesmc_device *smc = dev_get_drvdata(dev); ++ ++ return sysfs_emit(sysfsbuf, "(%d,%d)\n", smc->rest_x, smc->rest_y); + } + + static ssize_t applesmc_calibrate_store(struct device *dev, + struct 
device_attribute *attr, const char *sysfsbuf, size_t count) + { +- applesmc_calibrate(); ++ struct applesmc_device *smc = dev_get_drvdata(dev); ++ ++ applesmc_calibrate(smc); + + return count; + } + + static void applesmc_backlight_set(struct work_struct *work) + { +- applesmc_write_key(BACKLIGHT_KEY, backlight_state, 2); ++ struct applesmc_device *smc = container_of(work, struct applesmc_device, backlight_work); ++ ++ applesmc_write_key(smc, BACKLIGHT_KEY, smc->backlight_state, 2); + } +-static DECLARE_WORK(backlight_work, &applesmc_backlight_set); + + static void applesmc_brightness_set(struct led_classdev *led_cdev, + enum led_brightness value) + { ++ struct applesmc_device *smc = container_of(led_cdev, struct applesmc_device, backlight_dev); /* led_cdev is embedded in smc; the LED device's drvdata is led_cdev, not smc */ + int ret; + +- backlight_state[0] = value; +- ret = queue_work(applesmc_led_wq, &backlight_work); ++ smc->backlight_state[0] = value; ++ ret = queue_work(smc->backlight_wq, &smc->backlight_work); + + if (debug && (!ret)) + dev_dbg(led_cdev->dev, "work was already on the queue.\n"); + } + ++static ssize_t applesmc_BCLM_store(struct device *dev, ++ struct device_attribute *attr, const char *sysfsbuf, size_t count) ++{ ++ struct applesmc_device *smc = dev_get_drvdata(dev); ++ u8 val; ++ ++ if (kstrtou8(sysfsbuf, 10, &val) < 0) ++ return -EINVAL; ++ ++ if (val > 100) ++ return -EINVAL; ++ ++ if (applesmc_write_key(smc, "BCLM", &val, 1)) ++ return -ENODEV; ++ return count; ++} ++ ++static ssize_t applesmc_BCLM_show(struct device *dev, ++ struct device_attribute *attr, char *sysfsbuf) ++{ ++ struct applesmc_device *smc = dev_get_drvdata(dev); ++ u8 val; ++ ++ if (applesmc_read_key(smc, "BCLM", &val, 1)) ++ return -ENODEV; ++ ++ return sysfs_emit(sysfsbuf, "%d\n", val); ++} ++ + static ssize_t applesmc_key_count_show(struct device *dev, + struct device_attribute *attr, char *sysfsbuf) + { ++ struct applesmc_device *smc = dev_get_drvdata(dev); + int ret; + u8 buffer[4]; + u32 count; + +- ret = applesmc_read_key(KEY_COUNT_KEY, buffer, 4); ++ ret =
applesmc_read_key(smc, KEY_COUNT_KEY, buffer, 4); + if (ret) + return ret; + +@@ -998,13 +1526,14 @@ static ssize_t applesmc_key_count_show(struct device *dev, + static ssize_t applesmc_key_at_index_read_show(struct device *dev, + struct device_attribute *attr, char *sysfsbuf) + { ++ struct applesmc_device *smc = dev_get_drvdata(dev); + const struct applesmc_entry *entry; + int ret; + +- entry = applesmc_get_entry_by_index(key_at_index); ++ entry = applesmc_get_entry_by_index(smc, smc->key_at_index); + if (IS_ERR(entry)) + return PTR_ERR(entry); +- ret = applesmc_read_entry(entry, sysfsbuf, entry->len); ++ ret = applesmc_read_entry(smc, entry, sysfsbuf, entry->len); + if (ret) + return ret; + +@@ -1014,9 +1543,10 @@ static ssize_t applesmc_key_at_index_read_show(struct device *dev, + static ssize_t applesmc_key_at_index_data_length_show(struct device *dev, + struct device_attribute *attr, char *sysfsbuf) + { ++ struct applesmc_device *smc = dev_get_drvdata(dev); + const struct applesmc_entry *entry; + +- entry = applesmc_get_entry_by_index(key_at_index); ++ entry = applesmc_get_entry_by_index(smc, smc->key_at_index); + if (IS_ERR(entry)) + return PTR_ERR(entry); + +@@ -1026,9 +1556,10 @@ static ssize_t applesmc_key_at_index_data_length_show(struct device *dev, + static ssize_t applesmc_key_at_index_type_show(struct device *dev, + struct device_attribute *attr, char *sysfsbuf) + { ++ struct applesmc_device *smc = dev_get_drvdata(dev); + const struct applesmc_entry *entry; + +- entry = applesmc_get_entry_by_index(key_at_index); ++ entry = applesmc_get_entry_by_index(smc, smc->key_at_index); + if (IS_ERR(entry)) + return PTR_ERR(entry); + +@@ -1038,9 +1569,10 @@ static ssize_t applesmc_key_at_index_type_show(struct device *dev, + static ssize_t applesmc_key_at_index_name_show(struct device *dev, + struct device_attribute *attr, char *sysfsbuf) + { ++ struct applesmc_device *smc = dev_get_drvdata(dev); + const struct applesmc_entry *entry; + +- entry = 
applesmc_get_entry_by_index(key_at_index); ++ entry = applesmc_get_entry_by_index(smc, smc->key_at_index); + if (IS_ERR(entry)) + return PTR_ERR(entry); + +@@ -1050,28 +1582,25 @@ static ssize_t applesmc_key_at_index_name_show(struct device *dev, + static ssize_t applesmc_key_at_index_show(struct device *dev, + struct device_attribute *attr, char *sysfsbuf) + { +- return sysfs_emit(sysfsbuf, "%d\n", key_at_index); ++ struct applesmc_device *smc = dev_get_drvdata(dev); ++ ++ return sysfs_emit(sysfsbuf, "%d\n", smc->key_at_index); + } + + static ssize_t applesmc_key_at_index_store(struct device *dev, + struct device_attribute *attr, const char *sysfsbuf, size_t count) + { ++ struct applesmc_device *smc = dev_get_drvdata(dev); + unsigned long newkey; + + if (kstrtoul(sysfsbuf, 10, &newkey) < 0 +- || newkey >= smcreg.key_count) ++ || newkey >= smc->reg.key_count) + return -EINVAL; + +- key_at_index = newkey; ++ smc->key_at_index = newkey; + return count; + } + +-static struct led_classdev applesmc_backlight = { +- .name = "smc::kbd_backlight", +- .default_trigger = "nand-disk", +- .brightness_set = applesmc_brightness_set, +-}; +- + static struct applesmc_node_group info_group[] = { + { "name", applesmc_name_show }, + { "key_count", applesmc_key_count_show }, +@@ -1111,19 +1640,25 @@ static struct applesmc_node_group temp_group[] = { + { } + }; + ++static struct applesmc_node_group BCLM_group[] = { ++ { "battery_charge_limit", applesmc_BCLM_show, applesmc_BCLM_store }, ++ { } ++}; ++ + /* Module stuff */ + + /* + * applesmc_destroy_nodes - remove files and free associated memory + */ +-static void applesmc_destroy_nodes(struct applesmc_node_group *groups) ++static void applesmc_destroy_nodes(struct applesmc_device *smc, ++ struct applesmc_node_group *groups) + { + struct applesmc_node_group *grp; + struct applesmc_dev_attr *node; + + for (grp = groups; grp->nodes; grp++) { + for (node = grp->nodes; node->sda.dev_attr.attr.name; node++) +- 
sysfs_remove_file(&pdev->dev.kobj, ++ sysfs_remove_file(&smc->dev->dev.kobj, + &node->sda.dev_attr.attr); + kfree(grp->nodes); + grp->nodes = NULL; +@@ -1133,7 +1668,8 @@ static void applesmc_destroy_nodes(struct applesmc_node_group *groups) + /* + * applesmc_create_nodes - create a two-dimensional group of sysfs files + */ +-static int applesmc_create_nodes(struct applesmc_node_group *groups, int num) ++static int applesmc_create_nodes(struct applesmc_device *smc, ++ struct applesmc_node_group *groups, int num) + { + struct applesmc_node_group *grp; + struct applesmc_dev_attr *node; +@@ -1157,7 +1693,7 @@ static int applesmc_create_nodes(struct applesmc_node_group *groups, int num) + sysfs_attr_init(attr); + attr->name = node->name; + attr->mode = 0444 | (grp->store ? 0200 : 0); +- ret = sysfs_create_file(&pdev->dev.kobj, attr); ++ ret = sysfs_create_file(&smc->dev->dev.kobj, attr); + if (ret) { + attr->name = NULL; + goto out; +@@ -1167,57 +1703,56 @@ static int applesmc_create_nodes(struct applesmc_node_group *groups, int num) + + return 0; + out: +- applesmc_destroy_nodes(groups); ++ applesmc_destroy_nodes(smc, groups); + return ret; + } + + /* Create accelerometer resources */ +-static int applesmc_create_accelerometer(void) ++static int applesmc_create_accelerometer(struct applesmc_device *smc) + { + int ret; +- +- if (!smcreg.has_accelerometer) ++ if (!smc->reg.has_accelerometer) + return 0; + +- ret = applesmc_create_nodes(accelerometer_group, 1); ++ ret = applesmc_create_nodes(smc, accelerometer_group, 1); + if (ret) + goto out; + +- applesmc_idev = input_allocate_device(); +- if (!applesmc_idev) { ++ smc->idev = input_allocate_device(); ++ if (!smc->idev) { + ret = -ENOMEM; + goto out_sysfs; + } + + /* initial calibrate for the input device */ +- applesmc_calibrate(); ++ applesmc_calibrate(smc); + + /* initialize the input device */ +- applesmc_idev->name = "applesmc"; +- applesmc_idev->id.bustype = BUS_HOST; +- applesmc_idev->dev.parent = &pdev->dev; +- 
input_set_abs_params(applesmc_idev, ABS_X, ++ smc->idev->name = "applesmc"; ++ smc->idev->id.bustype = BUS_HOST; ++ smc->idev->dev.parent = &smc->dev->dev; ++ input_set_abs_params(smc->idev, ABS_X, + -256, 256, APPLESMC_INPUT_FUZZ, APPLESMC_INPUT_FLAT); +- input_set_abs_params(applesmc_idev, ABS_Y, ++ input_set_abs_params(smc->idev, ABS_Y, + -256, 256, APPLESMC_INPUT_FUZZ, APPLESMC_INPUT_FLAT); + +- ret = input_setup_polling(applesmc_idev, applesmc_idev_poll); ++ ret = input_setup_polling(smc->idev, applesmc_idev_poll); + if (ret) + goto out_idev; + +- input_set_poll_interval(applesmc_idev, APPLESMC_POLL_INTERVAL); ++ input_set_poll_interval(smc->idev, APPLESMC_POLL_INTERVAL); + +- ret = input_register_device(applesmc_idev); ++ ret = input_register_device(smc->idev); + if (ret) + goto out_idev; + + return 0; + + out_idev: +- input_free_device(applesmc_idev); ++ input_free_device(smc->idev); + + out_sysfs: +- applesmc_destroy_nodes(accelerometer_group); ++ applesmc_destroy_nodes(smc, accelerometer_group); + + out: + pr_warn("driver init failed (ret=%d)!\n", ret); +@@ -1225,44 +1760,55 @@ static int applesmc_create_accelerometer(void) + } + + /* Release all resources used by the accelerometer */ +-static void applesmc_release_accelerometer(void) ++static void applesmc_release_accelerometer(struct applesmc_device *smc) + { +- if (!smcreg.has_accelerometer) ++ if (!smc->reg.has_accelerometer) + return; +- input_unregister_device(applesmc_idev); +- applesmc_destroy_nodes(accelerometer_group); ++ input_unregister_device(smc->idev); ++ applesmc_destroy_nodes(smc, accelerometer_group); + } + +-static int applesmc_create_light_sensor(void) ++static int applesmc_create_light_sensor(struct applesmc_device *smc) + { +- if (!smcreg.num_light_sensors) ++ if (!smc->reg.num_light_sensors) + return 0; +- return applesmc_create_nodes(light_sensor_group, 1); ++ return applesmc_create_nodes(smc, light_sensor_group, 1); + } + +-static void applesmc_release_light_sensor(void) ++static 
void applesmc_release_light_sensor(struct applesmc_device *smc) + { +- if (!smcreg.num_light_sensors) ++ if (!smc->reg.num_light_sensors) + return; +- applesmc_destroy_nodes(light_sensor_group); ++ applesmc_destroy_nodes(smc, light_sensor_group); + } + +-static int applesmc_create_key_backlight(void) ++static int applesmc_create_key_backlight(struct applesmc_device *smc) + { +- if (!smcreg.has_key_backlight) ++ int ret; ++ ++ if (!smc->reg.has_key_backlight) + return 0; +- applesmc_led_wq = create_singlethread_workqueue("applesmc-led"); +- if (!applesmc_led_wq) ++ smc->backlight_wq = create_singlethread_workqueue("applesmc-led"); ++ if (!smc->backlight_wq) + return -ENOMEM; +- return led_classdev_register(&pdev->dev, &applesmc_backlight); ++ ++ INIT_WORK(&smc->backlight_work, applesmc_backlight_set); ++ smc->backlight_dev.name = "smc::kbd_backlight"; ++ smc->backlight_dev.default_trigger = "nand-disk"; ++ smc->backlight_dev.brightness_set = applesmc_brightness_set; ++ ret = led_classdev_register(&smc->dev->dev, &smc->backlight_dev); ++ if (ret) ++ destroy_workqueue(smc->backlight_wq); ++ ++ return ret; + } + +-static void applesmc_release_key_backlight(void) ++static void applesmc_release_key_backlight(struct applesmc_device *smc) + { +- if (!smcreg.has_key_backlight) ++ if (!smc->reg.has_key_backlight) + return; +- led_classdev_unregister(&applesmc_backlight); +- destroy_workqueue(applesmc_led_wq); ++ led_classdev_unregister(&smc->backlight_dev); ++ destroy_workqueue(smc->backlight_wq); + } + + static int applesmc_dmi_match(const struct dmi_system_id *id) +@@ -1291,6 +1837,10 @@ static const struct dmi_system_id applesmc_whitelist[] __initconst = { + DMI_MATCH(DMI_BOARD_VENDOR, "Apple"), + DMI_MATCH(DMI_PRODUCT_NAME, "Macmini") }, + }, ++ { applesmc_dmi_match, "Apple iMacPro", { ++ DMI_MATCH(DMI_BOARD_VENDOR, "Apple"), ++ DMI_MATCH(DMI_PRODUCT_NAME, "iMacPro") }, ++ }, + { applesmc_dmi_match, "Apple MacPro", { + DMI_MATCH(DMI_BOARD_VENDOR, "Apple"), + 
DMI_MATCH(DMI_PRODUCT_NAME, "MacPro") }, +@@ -1306,90 +1856,91 @@ static const struct dmi_system_id applesmc_whitelist[] __initconst = { + { .ident = NULL } + }; + +-static int __init applesmc_init(void) ++static int applesmc_create_modules(struct applesmc_device *smc) + { + int ret; + +- if (!dmi_check_system(applesmc_whitelist)) { +- pr_warn("supported laptop not found!\n"); +- ret = -ENODEV; +- goto out; +- } +- +- if (!request_region(APPLESMC_DATA_PORT, APPLESMC_NR_PORTS, +- "applesmc")) { +- ret = -ENXIO; +- goto out; +- } +- +- ret = platform_driver_register(&applesmc_driver); +- if (ret) +- goto out_region; +- +- pdev = platform_device_register_simple("applesmc", APPLESMC_DATA_PORT, +- NULL, 0); +- if (IS_ERR(pdev)) { +- ret = PTR_ERR(pdev); +- goto out_driver; +- } +- +- /* create register cache */ +- ret = applesmc_init_smcreg(); ++ ret = applesmc_create_nodes(smc, info_group, 1); + if (ret) +- goto out_device; +- +- ret = applesmc_create_nodes(info_group, 1); ++ goto out; ++ ret = applesmc_create_nodes(smc, BCLM_group, 1); + if (ret) +- goto out_smcreg; ++ goto out_info; + +- ret = applesmc_create_nodes(fan_group, smcreg.fan_count); ++ ret = applesmc_create_nodes(smc, fan_group, smc->reg.fan_count); + if (ret) +- goto out_info; ++ goto out_bclm; + +- ret = applesmc_create_nodes(temp_group, smcreg.index_count); ++ ret = applesmc_create_nodes(smc, temp_group, smc->reg.index_count); + if (ret) + goto out_fans; + +- ret = applesmc_create_accelerometer(); ++ ret = applesmc_create_accelerometer(smc); + if (ret) + goto out_temperature; + +- ret = applesmc_create_light_sensor(); ++ ret = applesmc_create_light_sensor(smc); + if (ret) + goto out_accelerometer; + +- ret = applesmc_create_key_backlight(); ++ ret = applesmc_create_key_backlight(smc); + if (ret) + goto out_light_sysfs; + +- hwmon_dev = hwmon_device_register(&pdev->dev); +- if (IS_ERR(hwmon_dev)) { +- ret = PTR_ERR(hwmon_dev); ++ smc->hwmon_dev = hwmon_device_register(&smc->dev->dev); ++ if 
(IS_ERR(smc->hwmon_dev)) { ++ ret = PTR_ERR(smc->hwmon_dev); + goto out_light_ledclass; + } + + return 0; + + out_light_ledclass: +- applesmc_release_key_backlight(); ++ applesmc_release_key_backlight(smc); + out_light_sysfs: +- applesmc_release_light_sensor(); ++ applesmc_release_light_sensor(smc); + out_accelerometer: +- applesmc_release_accelerometer(); ++ applesmc_release_accelerometer(smc); + out_temperature: +- applesmc_destroy_nodes(temp_group); ++ applesmc_destroy_nodes(smc, temp_group); + out_fans: +- applesmc_destroy_nodes(fan_group); ++ applesmc_destroy_nodes(smc, fan_group); ++out_bclm: ++ applesmc_destroy_nodes(smc, BCLM_group); + out_info: +- applesmc_destroy_nodes(info_group); +-out_smcreg: +- applesmc_destroy_smcreg(); +-out_device: +- platform_device_unregister(pdev); +-out_driver: +- platform_driver_unregister(&applesmc_driver); +-out_region: +- release_region(APPLESMC_DATA_PORT, APPLESMC_NR_PORTS); ++ applesmc_destroy_nodes(smc, info_group); ++out: ++ return ret; ++} ++ ++static void applesmc_destroy_modules(struct applesmc_device *smc) ++{ ++ hwmon_device_unregister(smc->hwmon_dev); ++ applesmc_release_key_backlight(smc); ++ applesmc_release_light_sensor(smc); ++ applesmc_release_accelerometer(smc); ++ applesmc_destroy_nodes(smc, temp_group); ++ applesmc_destroy_nodes(smc, fan_group); ++ applesmc_destroy_nodes(smc, BCLM_group); ++ applesmc_destroy_nodes(smc, info_group); ++} ++ ++static int __init applesmc_init(void) ++{ ++ int ret; ++ ++ if (!dmi_check_system(applesmc_whitelist)) { ++ pr_warn("supported laptop not found!\n"); ++ ret = -ENODEV; ++ goto out; ++ } ++ ++ ret = acpi_bus_register_driver(&applesmc_driver); ++ if (ret) ++ goto out; ++ ++ return 0; ++ + out: + pr_warn("driver init failed (ret=%d)!\n", ret); + return ret; +@@ -1397,23 +1948,14 @@ static int __init applesmc_init(void) + + static void __exit applesmc_exit(void) + { +- hwmon_device_unregister(hwmon_dev); +- applesmc_release_key_backlight(); +- 
applesmc_release_light_sensor(); +- applesmc_release_accelerometer(); +- applesmc_destroy_nodes(temp_group); +- applesmc_destroy_nodes(fan_group); +- applesmc_destroy_nodes(info_group); +- applesmc_destroy_smcreg(); +- platform_device_unregister(pdev); +- platform_driver_unregister(&applesmc_driver); +- release_region(APPLESMC_DATA_PORT, APPLESMC_NR_PORTS); ++ acpi_bus_unregister_driver(&applesmc_driver); + } + + module_init(applesmc_init); + module_exit(applesmc_exit); + + MODULE_AUTHOR("Nicolas Boichat"); ++MODULE_AUTHOR("Paul Pawlowski"); + MODULE_DESCRIPTION("Apple SMC"); + MODULE_LICENSE("GPL v2"); + MODULE_DEVICE_TABLE(dmi, applesmc_whitelist); +diff --git a/drivers/input/mouse/bcm5974.c b/drivers/input/mouse/bcm5974.c +index ca150618d32f..4e692b272ae9 100644 +--- a/drivers/input/mouse/bcm5974.c ++++ b/drivers/input/mouse/bcm5974.c +@@ -83,6 +83,24 @@ + #define USB_DEVICE_ID_APPLE_WELLSPRING9_ISO 0x0273 + #define USB_DEVICE_ID_APPLE_WELLSPRING9_JIS 0x0274 + ++/* T2-Attached Devices */ ++/* MacbookAir8,1 (2018) */ ++#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K 0x027a ++/* MacbookPro15,2 (2018) */ ++#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132 0x027b ++/* MacbookPro15,1 (2018) */ ++#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680 0x027c ++/* MacbookPro15,4 (2019) */ ++#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213 0x027d ++/* MacbookPro16,2 (2020) */ ++#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K 0x027e ++/* MacbookPro16,3 (2020) */ ++#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223 0x027f ++/* MacbookAir9,1 (2020) */ ++#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K 0x0280 ++/* MacbookPro16,1 (2019)*/ ++#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F 0x0340 ++ + #define BCM5974_DEVICE(prod) { \ + .match_flags = (USB_DEVICE_ID_MATCH_DEVICE | \ + USB_DEVICE_ID_MATCH_INT_CLASS | \ +@@ -147,6 +165,22 @@ static const struct usb_device_id bcm5974_table[] = { + BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRING9_ANSI), + BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRING9_ISO), + 
BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRING9_JIS), ++ /* MacbookAir8,1 */ ++ BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K), ++ /* MacbookPro15,2 */ ++ BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132), ++ /* MacbookPro15,1 */ ++ BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680), ++ /* MacbookPro15,4 */ ++ BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213), ++ /* MacbookPro16,2 */ ++ BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K), ++ /* MacbookPro16,3 */ ++ BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223), ++ /* MacbookAir9,1 */ ++ BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K), ++ /* MacbookPro16,1 */ ++ BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F), + /* Terminating entry */ + {} + }; +@@ -483,6 +517,110 @@ static const struct bcm5974_config bcm5974_config_table[] = { + { SN_COORD, -203, 6803 }, + { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } + }, ++ { ++ USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K, ++ 0, ++ 0, ++ HAS_INTEGRATED_BUTTON, ++ 0, sizeof(struct bt_data), ++ 0x83, DATAFORMAT(TYPE4), ++ { SN_PRESSURE, 0, 300 }, ++ { SN_WIDTH, 0, 2048 }, ++ { SN_COORD, -6243, 6749 }, ++ { SN_COORD, -170, 7685 }, ++ { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } ++ }, ++ { ++ USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132, ++ 0, ++ 0, ++ HAS_INTEGRATED_BUTTON, ++ 0, sizeof(struct bt_data), ++ 0x83, DATAFORMAT(TYPE4), ++ { SN_PRESSURE, 0, 300 }, ++ { SN_WIDTH, 0, 2048 }, ++ { SN_COORD, -6243, 6749 }, ++ { SN_COORD, -170, 7685 }, ++ { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } ++ }, ++ { ++ USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680, ++ 0, ++ 0, ++ HAS_INTEGRATED_BUTTON, ++ 0, sizeof(struct bt_data), ++ 0x83, DATAFORMAT(TYPE4), ++ { SN_PRESSURE, 0, 300 }, ++ { SN_WIDTH, 0, 2048 }, ++ { SN_COORD, -7456, 7976 }, ++ { SN_COORD, -1768, 7685 }, ++ { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } ++ }, ++ { ++ USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213, ++ 0, ++ 0, ++ 
HAS_INTEGRATED_BUTTON, ++ 0, sizeof(struct bt_data), ++ 0x83, DATAFORMAT(TYPE4), ++ { SN_PRESSURE, 0, 300 }, ++ { SN_WIDTH, 0, 2048 }, ++ { SN_COORD, -6243, 6749 }, ++ { SN_COORD, -170, 7685 }, ++ { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } ++ }, ++ { ++ USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K, ++ 0, ++ 0, ++ HAS_INTEGRATED_BUTTON, ++ 0, sizeof(struct bt_data), ++ 0x83, DATAFORMAT(TYPE4), ++ { SN_PRESSURE, 0, 300 }, ++ { SN_WIDTH, 0, 2048 }, ++ { SN_COORD, -7823, 8329 }, ++ { SN_COORD, -370, 7925 }, ++ { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } ++ }, ++ { ++ USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223, ++ 0, ++ 0, ++ HAS_INTEGRATED_BUTTON, ++ 0, sizeof(struct bt_data), ++ 0x83, DATAFORMAT(TYPE4), ++ { SN_PRESSURE, 0, 300 }, ++ { SN_WIDTH, 0, 2048 }, ++ { SN_COORD, -6243, 6749 }, ++ { SN_COORD, -170, 7685 }, ++ { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } ++ }, ++ { ++ USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K, ++ 0, ++ 0, ++ HAS_INTEGRATED_BUTTON, ++ 0, sizeof(struct bt_data), ++ 0x83, DATAFORMAT(TYPE4), ++ { SN_PRESSURE, 0, 300 }, ++ { SN_WIDTH, 0, 2048 }, ++ { SN_COORD, -6243, 6749 }, ++ { SN_COORD, -170, 7685 }, ++ { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } ++ }, ++ { ++ USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F, ++ 0, ++ 0, ++ HAS_INTEGRATED_BUTTON, ++ 0, sizeof(struct bt_data), ++ 0x83, DATAFORMAT(TYPE4), ++ { SN_PRESSURE, 0, 300 }, ++ { SN_WIDTH, 0, 2048 }, ++ { SN_COORD, -8916, 9918 }, ++ { SN_COORD, -1934, 9835 }, ++ { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } ++ }, + {} + }; + +diff --git a/drivers/pci/vgaarb.c b/drivers/pci/vgaarb.c +index 78748e8d2dba..2b2b558cebe6 100644 +--- a/drivers/pci/vgaarb.c ++++ b/drivers/pci/vgaarb.c +@@ -143,6 +143,7 @@ void vga_set_default_device(struct pci_dev *pdev) + pci_dev_put(vga_default); + vga_default = pci_dev_get(pdev); + } ++EXPORT_SYMBOL_GPL(vga_set_default_device); + + /** + * vga_remove_vgacon - deactivate VGA console +diff --git 
a/drivers/platform/x86/apple-gmux.c b/drivers/platform/x86/apple-gmux.c +index 1417e230edbd..e69785af8e1d 100644 +--- a/drivers/platform/x86/apple-gmux.c ++++ b/drivers/platform/x86/apple-gmux.c +@@ -21,6 +21,7 @@ + #include + #include + #include ++#include + #include + #include + #include +@@ -107,6 +108,10 @@ struct apple_gmux_config { + + # define MMIO_GMUX_MAX_BRIGHTNESS 0xffff + ++static bool force_igd; ++module_param(force_igd, bool, 0); ++MODULE_PARM_DESC(force_igd, "Switch gpu to igd on module load. Make sure that you have apple-set-os set up and the iGPU is in `lspci -s 00:02.0`. (default: false) (bool)"); ++ + static u8 gmux_pio_read8(struct apple_gmux_data *gmux_data, int port) + { + return inb(gmux_data->iostart + port); +@@ -945,6 +950,19 @@ static int gmux_probe(struct pnp_dev *pnp, const struct pnp_device_id *id) + gmux_enable_interrupts(gmux_data); + gmux_read_switch_state(gmux_data); + ++ if (force_igd) { ++ struct pci_dev *pdev; ++ ++ pdev = pci_get_domain_bus_and_slot(0, 0, PCI_DEVFN(2, 0)); ++ if (pdev) { ++ pr_info("Switching to IGD"); ++ gmux_switchto(VGA_SWITCHEROO_IGD); ++ vga_set_default_device(pdev); ++ } else { ++ pr_err("force_igd is true, but couldn't find iGPU at 00:02.0! Is apple-set-os working?"); ++ } ++ } ++ + /* + * Retina MacBook Pros cannot switch the panel's AUX separately + * and need eDP pre-calibration.
They are distinguishable from +diff --git a/drivers/staging/Kconfig b/drivers/staging/Kconfig +index db4a392841b1..580df4ce4f9f 100644 +--- a/drivers/staging/Kconfig ++++ b/drivers/staging/Kconfig +@@ -66,4 +66,6 @@ source "drivers/staging/fieldbus/Kconfig" + + source "drivers/staging/vme_user/Kconfig" + ++source "drivers/staging/apple-bce/Kconfig" ++ + endif # STAGING +diff --git a/drivers/staging/Makefile b/drivers/staging/Makefile +index 5390879b5d1b..528be2d3b546 100644 +--- a/drivers/staging/Makefile ++++ b/drivers/staging/Makefile +@@ -22,3 +22,4 @@ obj-$(CONFIG_GREYBUS) += greybus/ + obj-$(CONFIG_BCM2835_VCHIQ) += vc04_services/ + obj-$(CONFIG_XIL_AXIS_FIFO) += axis-fifo/ + obj-$(CONFIG_FIELDBUS_DEV) += fieldbus/ ++obj-$(CONFIG_APPLE_BCE) += apple-bce/ +diff --git a/drivers/staging/apple-bce/Kconfig b/drivers/staging/apple-bce/Kconfig +new file mode 100644 +index 000000000000..fe92bc441e89 +--- /dev/null ++++ b/drivers/staging/apple-bce/Kconfig +@@ -0,0 +1,18 @@ ++config APPLE_BCE ++ tristate "Apple BCE driver (VHCI and Audio support)" ++ default m ++ depends on X86 ++ select SOUND ++ select SND ++ select SND_PCM ++ select SND_JACK ++ help ++ VHCI and audio support on Apple MacBooks with the T2 Chip. ++ This driver is divided into three components: ++ - BCE (Buffer Copy Engine): establishes a basic communication ++ channel with the T2 chip. This component is required by the other two. ++ - VHCI (Virtual Host Controller Interface): access to keyboard, mouse ++ and other system devices depends on this virtual USB host controller. ++ - Audio: a driver for the T2 audio interface. ++ ++ If "M" is selected, the module will be called apple-bce.
+diff --git a/drivers/staging/apple-bce/Makefile b/drivers/staging/apple-bce/Makefile +new file mode 100644 +index 000000000000..8cfbd3f64af6 +--- /dev/null ++++ b/drivers/staging/apple-bce/Makefile +@@ -0,0 +1,28 @@ ++modname := apple-bce ++obj-$(CONFIG_APPLE_BCE) += $(modname).o ++ ++apple-bce-objs := apple_bce.o mailbox.o queue.o queue_dma.o vhci/vhci.o vhci/queue.o vhci/transfer.o audio/audio.o audio/protocol.o audio/protocol_bce.o audio/pcm.o ++ ++MY_CFLAGS += -DWITHOUT_NVME_PATCH ++#MY_CFLAGS += -g -DDEBUG ++ccflags-y += ${MY_CFLAGS} ++CC += ${MY_CFLAGS} ++ ++KVERSION := $(KERNELRELEASE) ++ifeq ($(origin KERNELRELEASE), undefined) ++KVERSION := $(shell uname -r) ++endif ++ ++KDIR := /lib/modules/$(KVERSION)/build ++PWD := $(shell pwd) ++ ++.PHONY: all ++ ++all: ++ $(MAKE) -C $(KDIR) M=$(PWD) modules ++ ++clean: ++ $(MAKE) -C $(KDIR) M=$(PWD) clean ++ ++install: ++ $(MAKE) -C $(KDIR) M=$(PWD) modules_install +diff --git a/drivers/staging/apple-bce/apple_bce.c b/drivers/staging/apple-bce/apple_bce.c +new file mode 100644 +index 000000000000..4fd2415d7028 +--- /dev/null ++++ b/drivers/staging/apple-bce/apple_bce.c +@@ -0,0 +1,445 @@ ++#include "apple_bce.h" ++#include ++#include ++#include "audio/audio.h" ++#include ++ ++static dev_t bce_chrdev; ++static struct class *bce_class; ++ ++struct apple_bce_device *global_bce; ++ ++static int bce_create_command_queues(struct apple_bce_device *bce); ++static void bce_free_command_queues(struct apple_bce_device *bce); ++static irqreturn_t bce_handle_mb_irq(int irq, void *dev); ++static irqreturn_t bce_handle_dma_irq(int irq, void *dev); ++static int bce_fw_version_handshake(struct apple_bce_device *bce); ++static int bce_register_command_queue(struct apple_bce_device *bce, struct bce_queue_memcfg *cfg, int is_sq); ++ ++static int apple_bce_probe(struct pci_dev *dev, const struct pci_device_id *id) ++{ ++ struct apple_bce_device *bce = NULL; ++ int status = 0; ++ int nvec; ++ ++ pr_info("apple-bce: capturing our 
device\n"); ++ ++ if (pci_enable_device(dev)) ++ return -ENODEV; ++ if (pci_request_regions(dev, "apple-bce")) { ++ status = -ENODEV; ++ goto fail; ++ } ++ pci_set_master(dev); ++ nvec = pci_alloc_irq_vectors(dev, 1, 8, PCI_IRQ_MSI); ++ if (nvec < 5) { ++ status = -EINVAL; ++ goto fail; ++ } ++ ++ bce = kzalloc(sizeof(struct apple_bce_device), GFP_KERNEL); ++ if (!bce) { ++ status = -ENOMEM; ++ goto fail; ++ } ++ ++ bce->pci = dev; ++ pci_set_drvdata(dev, bce); ++ ++ bce->devt = bce_chrdev; ++ bce->dev = device_create(bce_class, &dev->dev, bce->devt, NULL, "apple-bce"); ++ if (IS_ERR_OR_NULL(bce->dev)) { ++ status = PTR_ERR(bce->dev); /* the errno is carried by the failed device pointer, not by the class */ ++ goto fail; ++ } ++ ++ bce->reg_mem_mb = pci_iomap(dev, 4, 0); ++ bce->reg_mem_dma = pci_iomap(dev, 2, 0); ++ ++ if (IS_ERR_OR_NULL(bce->reg_mem_mb) || IS_ERR_OR_NULL(bce->reg_mem_dma)) { ++ dev_warn(&dev->dev, "apple-bce: Failed to pci_iomap required regions\n"); ++ goto fail; ++ } ++ ++ bce_mailbox_init(&bce->mbox, bce->reg_mem_mb); ++ bce_timestamp_init(&bce->timestamp, bce->reg_mem_mb); ++ ++ spin_lock_init(&bce->queues_lock); ++ ida_init(&bce->queue_ida); ++ ++ if ((status = pci_request_irq(dev, 0, bce_handle_mb_irq, NULL, dev, "bce_mbox"))) ++ goto fail; ++ if ((status = pci_request_irq(dev, 4, NULL, bce_handle_dma_irq, dev, "bce_dma"))) ++ goto fail_interrupt_0; ++ ++ if ((status = dma_set_mask_and_coherent(&dev->dev, DMA_BIT_MASK(37)))) { ++ dev_warn(&dev->dev, "dma: Setting mask failed\n"); ++ goto fail_interrupt; ++ } ++ ++ /* Gets the function 0's interface. This is needed because Apple only accepts DMA on our function if function 0 ++ is a bus master, so we need to work around this. 
*/ ++ bce->pci0 = pci_get_slot(dev->bus, PCI_DEVFN(PCI_SLOT(dev->devfn), 0)); ++#ifndef WITHOUT_NVME_PATCH ++ if ((status = pci_enable_device_mem(bce->pci0))) { ++ dev_warn(&dev->dev, "apple-bce: failed to enable function 0\n"); ++ goto fail_dev0; ++ } ++#endif ++ pci_set_master(bce->pci0); ++ ++ bce_timestamp_start(&bce->timestamp, true); ++ ++ if ((status = bce_fw_version_handshake(bce))) ++ goto fail_ts; ++ pr_info("apple-bce: handshake done\n"); ++ ++ if ((status = bce_create_command_queues(bce))) { ++ pr_info("apple-bce: Creating command queues failed\n"); ++ goto fail_ts; ++ } ++ ++ global_bce = bce; ++ ++ bce_vhci_create(bce, &bce->vhci); ++ ++ return 0; ++ ++fail_ts: ++ bce_timestamp_stop(&bce->timestamp); ++#ifndef WITHOUT_NVME_PATCH ++ pci_disable_device(bce->pci0); ++fail_dev0: ++#endif ++ pci_dev_put(bce->pci0); ++fail_interrupt: ++ pci_free_irq(dev, 4, dev); ++fail_interrupt_0: ++ pci_free_irq(dev, 0, dev); ++fail: ++ if (bce && bce->dev) { ++ device_destroy(bce_class, bce->devt); ++ ++ if (!IS_ERR_OR_NULL(bce->reg_mem_mb)) ++ pci_iounmap(dev, bce->reg_mem_mb); ++ if (!IS_ERR_OR_NULL(bce->reg_mem_dma)) ++ pci_iounmap(dev, bce->reg_mem_dma); ++ ++ kfree(bce); ++ } ++ ++ pci_free_irq_vectors(dev); ++ pci_release_regions(dev); ++ pci_disable_device(dev); ++ ++ if (!status) ++ status = -EINVAL; ++ return status; ++} ++ ++static int bce_create_command_queues(struct apple_bce_device *bce) ++{ ++ int status; ++ struct bce_queue_memcfg *cfg; ++ ++ bce->cmd_cq = bce_alloc_cq(bce, 0, 0x20); ++ bce->cmd_cmdq = bce_alloc_cmdq(bce, 1, 0x20); ++ if (bce->cmd_cq == NULL || bce->cmd_cmdq == NULL) { ++ status = -ENOMEM; ++ goto err; ++ } ++ bce->queues[0] = (struct bce_queue *) bce->cmd_cq; ++ bce->queues[1] = (struct bce_queue *) bce->cmd_cmdq->sq; ++ ++ cfg = kzalloc(sizeof(struct bce_queue_memcfg), GFP_KERNEL); ++ if (!cfg) { ++ status = -ENOMEM; ++ goto err; ++ } ++ bce_get_cq_memcfg(bce->cmd_cq, cfg); ++ if ((status = bce_register_command_queue(bce, cfg, false))) 
++ goto err; ++ bce_get_sq_memcfg(bce->cmd_cmdq->sq, bce->cmd_cq, cfg); ++ if ((status = bce_register_command_queue(bce, cfg, true))) ++ goto err; ++ kfree(cfg); ++ ++ return 0; ++ ++err: ++ if (bce->cmd_cq) ++ bce_free_cq(bce, bce->cmd_cq); ++ if (bce->cmd_cmdq) ++ bce_free_cmdq(bce, bce->cmd_cmdq); ++ return status; ++} ++ ++static void bce_free_command_queues(struct apple_bce_device *bce) ++{ ++ bce_free_cq(bce, bce->cmd_cq); ++ bce_free_cmdq(bce, bce->cmd_cmdq); ++ bce->cmd_cq = NULL; ++ bce->queues[0] = NULL; ++} ++ ++static irqreturn_t bce_handle_mb_irq(int irq, void *dev) ++{ ++ struct apple_bce_device *bce = pci_get_drvdata(dev); ++ bce_mailbox_handle_interrupt(&bce->mbox); ++ return IRQ_HANDLED; ++} ++ ++static irqreturn_t bce_handle_dma_irq(int irq, void *dev) ++{ ++ int i; ++ struct apple_bce_device *bce = pci_get_drvdata(dev); ++ spin_lock(&bce->queues_lock); ++ for (i = 0; i < BCE_MAX_QUEUE_COUNT; i++) ++ if (bce->queues[i] && bce->queues[i]->type == BCE_QUEUE_CQ) ++ bce_handle_cq_completions(bce, (struct bce_queue_cq *) bce->queues[i]); ++ spin_unlock(&bce->queues_lock); ++ return IRQ_HANDLED; ++} ++ ++static int bce_fw_version_handshake(struct apple_bce_device *bce) ++{ ++ u64 result; ++ int status; ++ ++ if ((status = bce_mailbox_send(&bce->mbox, BCE_MB_MSG(BCE_MB_SET_FW_PROTOCOL_VERSION, BC_PROTOCOL_VERSION), ++ &result))) ++ return status; ++ if (BCE_MB_TYPE(result) != BCE_MB_SET_FW_PROTOCOL_VERSION || ++ BCE_MB_VALUE(result) != BC_PROTOCOL_VERSION) { ++ pr_err("apple-bce: FW version handshake failed %x:%llx\n", BCE_MB_TYPE(result), BCE_MB_VALUE(result)); ++ return -EINVAL; ++ } ++ return 0; ++} ++ ++static int bce_register_command_queue(struct apple_bce_device *bce, struct bce_queue_memcfg *cfg, int is_sq) ++{ ++ int status; ++ int cmd_type; ++ u64 result; ++ // OS X uses an bidirectional direction, but that's not really needed ++ dma_addr_t a = dma_map_single(&bce->pci->dev, cfg, sizeof(struct bce_queue_memcfg), DMA_TO_DEVICE); ++ if 
(dma_mapping_error(&bce->pci->dev, a)) ++ return -ENOMEM; ++ cmd_type = is_sq ? BCE_MB_REGISTER_COMMAND_SQ : BCE_MB_REGISTER_COMMAND_CQ; ++ status = bce_mailbox_send(&bce->mbox, BCE_MB_MSG(cmd_type, a), &result); ++ dma_unmap_single(&bce->pci->dev, a, sizeof(struct bce_queue_memcfg), DMA_TO_DEVICE); ++ if (status) ++ return status; ++ if (BCE_MB_TYPE(result) != BCE_MB_REGISTER_COMMAND_QUEUE_REPLY) ++ return -EINVAL; ++ return 0; ++} ++ ++static void apple_bce_remove(struct pci_dev *dev) ++{ ++ struct apple_bce_device *bce = pci_get_drvdata(dev); ++ bce->is_being_removed = true; ++ ++ bce_vhci_destroy(&bce->vhci); ++ ++ bce_timestamp_stop(&bce->timestamp); ++#ifndef WITHOUT_NVME_PATCH ++ pci_disable_device(bce->pci0); ++#endif ++ pci_dev_put(bce->pci0); ++ pci_free_irq(dev, 0, dev); ++ pci_free_irq(dev, 4, dev); ++ bce_free_command_queues(bce); ++ pci_iounmap(dev, bce->reg_mem_mb); ++ pci_iounmap(dev, bce->reg_mem_dma); ++ device_destroy(bce_class, bce->devt); ++ pci_free_irq_vectors(dev); ++ pci_release_regions(dev); ++ pci_disable_device(dev); ++ kfree(bce); ++} ++ ++static int bce_save_state_and_sleep(struct apple_bce_device *bce) ++{ ++ int attempt, status = 0; ++ u64 resp; ++ dma_addr_t dma_addr; ++ void *dma_ptr = NULL; ++ size_t size = max(PAGE_SIZE, 4096UL); ++ ++ for (attempt = 0; attempt < 5; ++attempt) { ++ pr_debug("apple-bce: suspend: attempt %i, buffer size %li\n", attempt, size); ++ dma_ptr = dma_alloc_coherent(&bce->pci->dev, size, &dma_addr, GFP_KERNEL); ++ if (!dma_ptr) { ++ pr_err("apple-bce: suspend failed (data alloc failed)\n"); ++ break; ++ } ++ BUG_ON((dma_addr % 4096) != 0); ++ status = bce_mailbox_send(&bce->mbox, ++ BCE_MB_MSG(BCE_MB_SAVE_STATE_AND_SLEEP, (dma_addr & ~(4096LLU - 1)) | (size / 4096)), &resp); ++ if (status) { ++ pr_err("apple-bce: suspend failed (mailbox send)\n"); ++ break; ++ } ++ if (BCE_MB_TYPE(resp) == BCE_MB_SAVE_RESTORE_STATE_COMPLETE) { ++ bce->saved_data_dma_addr = dma_addr; ++ bce->saved_data_dma_ptr = dma_ptr; ++ 
bce->saved_data_dma_size = size; ++ return 0; ++ } else if (BCE_MB_TYPE(resp) == BCE_MB_SAVE_STATE_AND_SLEEP_FAILURE) { ++ dma_free_coherent(&bce->pci->dev, size, dma_ptr, dma_addr); ++ /* The 0x10ff magic value was extracted from Apple's driver */ ++ size = (BCE_MB_VALUE(resp) + 0x10ff) & ~(4096LLU - 1); ++ pr_debug("apple-bce: suspend: device requested a larger buffer (%li)\n", size); ++ continue; ++ } else { ++ pr_err("apple-bce: suspend failed (invalid device response)\n"); ++ status = -EINVAL; ++ break; ++ } ++ } ++ if (dma_ptr) ++ dma_free_coherent(&bce->pci->dev, size, dma_ptr, dma_addr); ++ if (!status) ++ return bce_mailbox_send(&bce->mbox, BCE_MB_MSG(BCE_MB_SLEEP_NO_STATE, 0), &resp); ++ return status; ++} ++ ++static int bce_restore_state_and_wake(struct apple_bce_device *bce) ++{ ++ int status; ++ u64 resp; ++ if (!bce->saved_data_dma_ptr) { ++ if ((status = bce_mailbox_send(&bce->mbox, BCE_MB_MSG(BCE_MB_RESTORE_NO_STATE, 0), &resp))) { ++ pr_err("apple-bce: resume with no state failed (mailbox send)\n"); ++ return status; ++ } ++ if (BCE_MB_TYPE(resp) != BCE_MB_RESTORE_NO_STATE) { ++ pr_err("apple-bce: resume with no state failed (invalid device response)\n"); ++ return -EINVAL; ++ } ++ return 0; ++ } ++ ++ if ((status = bce_mailbox_send(&bce->mbox, BCE_MB_MSG(BCE_MB_RESTORE_STATE_AND_WAKE, ++ (bce->saved_data_dma_addr & ~(4096LLU - 1)) | (bce->saved_data_dma_size / 4096)), &resp))) { ++ pr_err("apple-bce: resume with state failed (mailbox send)\n"); ++ goto finish_with_state; ++ } ++ if (BCE_MB_TYPE(resp) != BCE_MB_SAVE_RESTORE_STATE_COMPLETE) { ++ pr_err("apple-bce: resume with state failed (invalid device response)\n"); ++ status = -EINVAL; ++ goto finish_with_state; ++ } ++ ++finish_with_state: ++ dma_free_coherent(&bce->pci->dev, bce->saved_data_dma_size, bce->saved_data_dma_ptr, bce->saved_data_dma_addr); ++ bce->saved_data_dma_ptr = NULL; ++ return status; ++} ++ ++static int apple_bce_suspend(struct device *dev) ++{ ++ struct apple_bce_device 
*bce = pci_get_drvdata(to_pci_dev(dev)); ++ int status; ++ ++ bce_timestamp_stop(&bce->timestamp); ++ ++ if ((status = bce_save_state_and_sleep(bce))) ++ return status; ++ ++ return 0; ++} ++ ++static int apple_bce_resume(struct device *dev) ++{ ++ struct apple_bce_device *bce = pci_get_drvdata(to_pci_dev(dev)); ++ int status; ++ ++ pci_set_master(bce->pci); ++ pci_set_master(bce->pci0); ++ ++ if ((status = bce_restore_state_and_wake(bce))) ++ return status; ++ ++ bce_timestamp_start(&bce->timestamp, false); ++ ++ return 0; ++} ++ ++static struct pci_device_id apple_bce_ids[ ] = { ++ { PCI_DEVICE(PCI_VENDOR_ID_APPLE, 0x1801) }, ++ { 0, }, ++}; ++ ++MODULE_DEVICE_TABLE(pci, apple_bce_ids); ++ ++struct dev_pm_ops apple_bce_pci_driver_pm = { ++ .suspend = apple_bce_suspend, ++ .resume = apple_bce_resume ++}; ++struct pci_driver apple_bce_pci_driver = { ++ .name = "apple-bce", ++ .id_table = apple_bce_ids, ++ .probe = apple_bce_probe, ++ .remove = apple_bce_remove, ++ .driver = { ++ .pm = &apple_bce_pci_driver_pm ++ } ++}; ++ ++ ++static int __init apple_bce_module_init(void) ++{ ++ int result; ++ if ((result = alloc_chrdev_region(&bce_chrdev, 0, 1, "apple-bce"))) ++ goto fail_chrdev; ++#if LINUX_VERSION_CODE < KERNEL_VERSION(6,4,0) ++ bce_class = class_create(THIS_MODULE, "apple-bce"); ++#else ++ bce_class = class_create("apple-bce"); ++#endif ++ if (IS_ERR(bce_class)) { ++ result = PTR_ERR(bce_class); ++ goto fail_class; ++ } ++ if ((result = bce_vhci_module_init())) { ++ pr_err("apple-bce: bce-vhci init failed"); ++ goto fail_class; ++ } ++ ++ result = pci_register_driver(&apple_bce_pci_driver); ++ if (result) ++ goto fail_drv; ++ ++ aaudio_module_init(); ++ ++ return 0; ++ ++fail_drv: ++ pci_unregister_driver(&apple_bce_pci_driver); ++fail_class: ++ class_destroy(bce_class); ++fail_chrdev: ++ unregister_chrdev_region(bce_chrdev, 1); ++ if (!result) ++ result = -EINVAL; ++ return result; ++} ++static void __exit apple_bce_module_exit(void) ++{ ++ 
pci_unregister_driver(&apple_bce_pci_driver); ++ ++ aaudio_module_exit(); ++ bce_vhci_module_exit(); ++ class_destroy(bce_class); ++ unregister_chrdev_region(bce_chrdev, 1); ++} ++ ++MODULE_LICENSE("GPL"); ++MODULE_AUTHOR("MrARM"); ++MODULE_DESCRIPTION("Apple BCE Driver"); ++MODULE_VERSION("0.01"); ++module_init(apple_bce_module_init); ++module_exit(apple_bce_module_exit); +diff --git a/drivers/staging/apple-bce/apple_bce.h b/drivers/staging/apple-bce/apple_bce.h +new file mode 100644 +index 000000000000..f13ab8d5742e +--- /dev/null ++++ b/drivers/staging/apple-bce/apple_bce.h +@@ -0,0 +1,38 @@ ++#pragma once ++ ++#include ++#include ++#include "mailbox.h" ++#include "queue.h" ++#include "vhci/vhci.h" ++ ++#define BC_PROTOCOL_VERSION 0x20001 ++#define BCE_MAX_QUEUE_COUNT 0x100 ++ ++#define BCE_QUEUE_USER_MIN 2 ++#define BCE_QUEUE_USER_MAX (BCE_MAX_QUEUE_COUNT - 1) ++ ++struct apple_bce_device { ++ struct pci_dev *pci, *pci0; ++ dev_t devt; ++ struct device *dev; ++ void __iomem *reg_mem_mb; ++ void __iomem *reg_mem_dma; ++ struct bce_mailbox mbox; ++ struct bce_timestamp timestamp; ++ struct bce_queue *queues[BCE_MAX_QUEUE_COUNT]; ++ struct spinlock queues_lock; ++ struct ida queue_ida; ++ struct bce_queue_cq *cmd_cq; ++ struct bce_queue_cmdq *cmd_cmdq; ++ struct bce_queue_sq *int_sq_list[BCE_MAX_QUEUE_COUNT]; ++ bool is_being_removed; ++ ++ dma_addr_t saved_data_dma_addr; ++ void *saved_data_dma_ptr; ++ size_t saved_data_dma_size; ++ ++ struct bce_vhci vhci; ++}; ++ ++extern struct apple_bce_device *global_bce; +\ No newline at end of file +diff --git a/drivers/staging/apple-bce/audio/audio.c b/drivers/staging/apple-bce/audio/audio.c +new file mode 100644 +index 000000000000..bd16ddd16c1d +--- /dev/null ++++ b/drivers/staging/apple-bce/audio/audio.c +@@ -0,0 +1,711 @@ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include "audio.h" ++#include "pcm.h" ++#include ++ ++static int aaudio_alsa_index = SNDRV_DEFAULT_IDX1; 
++static char *aaudio_alsa_id = SNDRV_DEFAULT_STR1; ++ ++static dev_t aaudio_chrdev; ++static struct class *aaudio_class; ++ ++static int aaudio_init_cmd(struct aaudio_device *a); ++static int aaudio_init_bs(struct aaudio_device *a); ++static void aaudio_init_dev(struct aaudio_device *a, aaudio_device_id_t dev_id); ++static void aaudio_free_dev(struct aaudio_subdevice *sdev); ++ ++static int aaudio_probe(struct pci_dev *dev, const struct pci_device_id *id) ++{ ++ struct aaudio_device *aaudio = NULL; ++ struct aaudio_subdevice *sdev = NULL; ++ int status = 0; ++ u32 cfg; ++ ++ pr_info("aaudio: capturing our device\n"); ++ ++ if (pci_enable_device(dev)) ++ return -ENODEV; ++ if (pci_request_regions(dev, "aaudio")) { ++ status = -ENODEV; ++ goto fail; ++ } ++ pci_set_master(dev); ++ ++ aaudio = kzalloc(sizeof(struct aaudio_device), GFP_KERNEL); ++ if (!aaudio) { ++ status = -ENOMEM; ++ goto fail; ++ } ++ ++ aaudio->bce = global_bce; ++ if (!aaudio->bce) { ++ dev_warn(&dev->dev, "aaudio: No BCE available\n"); ++ status = -EINVAL; ++ goto fail; ++ } ++ ++ aaudio->pci = dev; ++ pci_set_drvdata(dev, aaudio); ++ ++ aaudio->devt = aaudio_chrdev; ++ aaudio->dev = device_create(aaudio_class, &dev->dev, aaudio->devt, NULL, "aaudio"); ++ if (IS_ERR_OR_NULL(aaudio->dev)) { ++ status = PTR_ERR(aaudio->dev); /* the errno is carried by the failed device pointer, not by the class */ ++ goto fail; ++ } ++ device_link_add(aaudio->dev, aaudio->bce->dev, DL_FLAG_PM_RUNTIME | DL_FLAG_AUTOREMOVE_CONSUMER); ++ ++ init_completion(&aaudio->remote_alive); ++ INIT_LIST_HEAD(&aaudio->subdevice_list); ++ ++ /* Init: set an unknown flag in the bitset */ ++ if (pci_read_config_dword(dev, 4, &cfg)) ++ dev_warn(&dev->dev, "aaudio: pci_read_config_dword fail\n"); ++ if (pci_write_config_dword(dev, 4, cfg | 6u)) ++ dev_warn(&dev->dev, "aaudio: pci_write_config_dword fail\n"); ++ ++ dev_info(aaudio->dev, "aaudio: bs len = %llx\n", pci_resource_len(dev, 0)); ++ aaudio->reg_mem_bs_dma = pci_resource_start(dev, 0); ++ aaudio->reg_mem_bs = pci_iomap(dev, 0, 0); ++ 
aaudio->reg_mem_cfg = pci_iomap(dev, 4, 0); ++ ++ aaudio->reg_mem_gpr = (u32 __iomem *) ((u8 __iomem *) aaudio->reg_mem_cfg + 0xC000); ++ ++ if (IS_ERR_OR_NULL(aaudio->reg_mem_bs) || IS_ERR_OR_NULL(aaudio->reg_mem_cfg)) { ++ dev_warn(&dev->dev, "aaudio: Failed to pci_iomap required regions\n"); ++ goto fail; ++ } ++ ++ if (aaudio_bce_init(aaudio)) { ++ dev_warn(&dev->dev, "aaudio: Failed to init BCE command transport\n"); ++ goto fail; ++ } ++ ++ if (snd_card_new(aaudio->dev, aaudio_alsa_index, aaudio_alsa_id, THIS_MODULE, 0, &aaudio->card)) { ++ dev_err(&dev->dev, "aaudio: Failed to create ALSA card\n"); ++ goto fail; ++ } ++ ++ strcpy(aaudio->card->shortname, "Apple T2 Audio"); ++ strcpy(aaudio->card->longname, "Apple T2 Audio"); ++ strcpy(aaudio->card->mixername, "Apple T2 Audio"); ++ /* Dynamic alsa ids start at 100 */ ++ aaudio->next_alsa_id = 100; ++ ++ if (aaudio_init_cmd(aaudio)) { ++ dev_err(&dev->dev, "aaudio: Failed to initialize over BCE\n"); ++ goto fail_snd; ++ } ++ ++ if (aaudio_init_bs(aaudio)) { ++ dev_err(&dev->dev, "aaudio: Failed to initialize BufferStruct\n"); ++ goto fail_snd; ++ } ++ ++ if ((status = aaudio_cmd_set_remote_access(aaudio, AAUDIO_REMOTE_ACCESS_ON))) { ++ dev_err(&dev->dev, "Failed to set remote access\n"); ++ goto fail_snd; /* unwind the card, device node and PCI resources instead of leaking them */ ++ } ++ ++ if (snd_card_register(aaudio->card)) { ++ dev_err(&dev->dev, "aaudio: Failed to register ALSA sound device\n"); ++ goto fail_snd; ++ } ++ ++ list_for_each_entry(sdev, &aaudio->subdevice_list, list) { ++ struct aaudio_buffer_struct_device *dev = &aaudio->bs->devices[sdev->buf_id]; ++ ++ if (sdev->out_stream_cnt == 1 && !strcmp(dev->name, "Speaker")) { ++ struct snd_pcm_hardware *hw = sdev->out_streams[0].alsa_hw_desc; ++ ++ snprintf(aaudio->card->driver, sizeof(aaudio->card->driver) / sizeof(char), "AppleT2x%d", hw->channels_min); ++ } ++ } ++ ++ return 0; ++ ++fail_snd: ++ snd_card_free(aaudio->card); ++fail: ++ if (aaudio && aaudio->dev) ++ device_destroy(aaudio_class, aaudio->devt); ++ 
if (aaudio && !IS_ERR_OR_NULL(aaudio->reg_mem_bs)) ++ pci_iounmap(dev, aaudio->reg_mem_bs); ++ if (aaudio && !IS_ERR_OR_NULL(aaudio->reg_mem_cfg)) ++ pci_iounmap(dev, aaudio->reg_mem_cfg); ++ ++ /* Free only after the last read of aaudio's fields; aaudio may be NULL here and kfree(NULL) is a no-op */ ++ kfree(aaudio); ++ pci_release_regions(dev); ++ pci_disable_device(dev); ++ ++ if (!status) ++ status = -EINVAL; ++ return status; ++} ++ ++ ++ ++static void aaudio_remove(struct pci_dev *dev) ++{ ++ struct aaudio_subdevice *sdev; ++ struct aaudio_device *aaudio = pci_get_drvdata(dev); ++ ++ snd_card_free(aaudio->card); ++ while (!list_empty(&aaudio->subdevice_list)) { ++ sdev = list_first_entry(&aaudio->subdevice_list, struct aaudio_subdevice, list); ++ list_del(&sdev->list); ++ aaudio_free_dev(sdev); ++ } ++ pci_iounmap(dev, aaudio->reg_mem_bs); ++ pci_iounmap(dev, aaudio->reg_mem_cfg); ++ device_destroy(aaudio_class, aaudio->devt); ++ pci_free_irq_vectors(dev); ++ pci_release_regions(dev); ++ pci_disable_device(dev); ++ kfree(aaudio); ++} ++ ++static int aaudio_suspend(struct device *dev) ++{ ++ struct aaudio_device *aaudio = pci_get_drvdata(to_pci_dev(dev)); ++ ++ if (aaudio_cmd_set_remote_access(aaudio, AAUDIO_REMOTE_ACCESS_OFF)) ++ dev_warn(aaudio->dev, "Failed to reset remote access\n"); ++ ++ pci_disable_device(aaudio->pci); ++ return 0; ++} ++ ++static int aaudio_resume(struct device *dev) ++{ ++ int status; ++ struct aaudio_device *aaudio = pci_get_drvdata(to_pci_dev(dev)); ++ ++ if ((status = pci_enable_device(aaudio->pci))) ++ return status; ++ pci_set_master(aaudio->pci); ++ ++ if ((status = aaudio_cmd_set_remote_access(aaudio, AAUDIO_REMOTE_ACCESS_ON))) { ++ dev_err(aaudio->dev, "Failed to set remote access\n"); ++ return status; ++ } ++ ++ return 0; ++} ++ ++static int aaudio_init_cmd(struct aaudio_device *a) ++{ ++ int status; ++ struct aaudio_send_ctx sctx; ++ struct aaudio_msg buf; ++ u64 dev_cnt, dev_i; ++ aaudio_device_id_t *dev_l; ++ ++ if ((status = aaudio_send(a, &sctx, 500, ++ aaudio_msg_write_alive_notification, 1, 3))) { ++ dev_err(a->dev, "Sending alive
notification failed\n"); ++ return status; ++ } ++ ++ if (wait_for_completion_timeout(&a->remote_alive, msecs_to_jiffies(500)) == 0) { ++ dev_err(a->dev, "Timed out waiting for remote\n"); ++ return -ETIMEDOUT; ++ } ++ dev_info(a->dev, "Continuing init\n"); ++ ++ buf = aaudio_reply_alloc(); ++ if ((status = aaudio_cmd_get_device_list(a, &buf, &dev_l, &dev_cnt))) { ++ dev_err(a->dev, "Failed to get device list\n"); ++ aaudio_reply_free(&buf); ++ return status; ++ } ++ for (dev_i = 0; dev_i < dev_cnt; ++dev_i) ++ aaudio_init_dev(a, dev_l[dev_i]); ++ aaudio_reply_free(&buf); ++ ++ return 0; ++} ++ ++static void aaudio_init_stream_info(struct aaudio_subdevice *sdev, struct aaudio_stream *strm); ++static void aaudio_handle_jack_connection_change(struct aaudio_subdevice *sdev); ++ ++static void aaudio_init_dev(struct aaudio_device *a, aaudio_device_id_t dev_id) ++{ ++ struct aaudio_subdevice *sdev; ++ struct aaudio_msg buf = aaudio_reply_alloc(); ++ u64 uid_len, stream_cnt, i; ++ aaudio_object_id_t *stream_list; ++ char *uid; ++ ++ sdev = kzalloc(sizeof(struct aaudio_subdevice), GFP_KERNEL); ++ ++ if (aaudio_cmd_get_property(a, &buf, dev_id, dev_id, AAUDIO_PROP(AAUDIO_PROP_SCOPE_GLOBAL, AAUDIO_PROP_UID, 0), ++ NULL, 0, (void **) &uid, &uid_len) || uid_len > AAUDIO_DEVICE_MAX_UID_LEN) { ++ dev_err(a->dev, "Failed to get device uid for device %llx\n", dev_id); ++ goto fail; ++ } ++ dev_info(a->dev, "Remote device %llx %.*s\n", dev_id, (int) uid_len, uid); ++ ++ sdev->a = a; ++ INIT_LIST_HEAD(&sdev->list); ++ sdev->dev_id = dev_id; ++ sdev->buf_id = AAUDIO_BUFFER_ID_NONE; ++ strncpy(sdev->uid, uid, uid_len); ++ sdev->uid[uid_len] = '\0'; /* terminator goes directly after the uid_len bytes copied above */ ++ ++ if (aaudio_cmd_get_primitive_property(a, dev_id, dev_id, ++ AAUDIO_PROP(AAUDIO_PROP_SCOPE_INPUT, AAUDIO_PROP_LATENCY, 0), NULL, 0, &sdev->in_latency, sizeof(u32))) ++ dev_warn(a->dev, "Failed to query device input latency\n"); ++ if (aaudio_cmd_get_primitive_property(a, dev_id, dev_id, ++ AAUDIO_PROP(AAUDIO_PROP_SCOPE_OUTPUT,
AAUDIO_PROP_LATENCY, 0), NULL, 0, &sdev->out_latency, sizeof(u32))) ++ dev_warn(a->dev, "Failed to query device output latency\n"); ++ ++ if (aaudio_cmd_get_input_stream_list(a, &buf, dev_id, &stream_list, &stream_cnt)) { ++ dev_err(a->dev, "Failed to get input stream list for device %llx\n", dev_id); ++ goto fail; ++ } ++ if (stream_cnt > AAUDIO_DEIVCE_MAX_INPUT_STREAMS) { ++ dev_warn(a->dev, "Device %s input stream count %llu is larger than the supported count of %u\n", ++ sdev->uid, stream_cnt, AAUDIO_DEIVCE_MAX_INPUT_STREAMS); ++ stream_cnt = AAUDIO_DEIVCE_MAX_INPUT_STREAMS; ++ } ++ sdev->in_stream_cnt = stream_cnt; ++ for (i = 0; i < stream_cnt; i++) { ++ sdev->in_streams[i].id = stream_list[i]; ++ sdev->in_streams[i].buffer_cnt = 0; ++ aaudio_init_stream_info(sdev, &sdev->in_streams[i]); ++ sdev->in_streams[i].latency += sdev->in_latency; ++ } ++ ++ if (aaudio_cmd_get_output_stream_list(a, &buf, dev_id, &stream_list, &stream_cnt)) { ++ dev_err(a->dev, "Failed to get output stream list for device %llx\n", dev_id); ++ goto fail; ++ } ++ if (stream_cnt > AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS) { ++ dev_warn(a->dev, "Device %s output stream count %llu is larger than the supported count of %u\n", ++ sdev->uid, stream_cnt, AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS); ++ stream_cnt = AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS; ++ } ++ sdev->out_stream_cnt = stream_cnt; ++ for (i = 0; i < stream_cnt; i++) { ++ sdev->out_streams[i].id = stream_list[i]; ++ sdev->out_streams[i].buffer_cnt = 0; ++ aaudio_init_stream_info(sdev, &sdev->out_streams[i]); ++ sdev->out_streams[i].latency += sdev->out_latency; /* output streams add the device's output-scope latency */ ++ } ++ ++ if (sdev->is_pcm) ++ aaudio_create_pcm(sdev); ++ /* Headphone Jack status */ ++ if (!strcmp(sdev->uid, "Codec Output")) { ++ if (snd_jack_new(a->card, sdev->uid, SND_JACK_HEADPHONE, &sdev->jack, true, false)) ++ dev_warn(a->dev, "Failed to create an attached jack for %s\n", sdev->uid); ++ aaudio_cmd_property_listener(a, sdev->dev_id, sdev->dev_id, ++ AAUDIO_PROP(AAUDIO_PROP_SCOPE_OUTPUT,
AAUDIO_PROP_JACK_PLUGGED, 0)); ++ aaudio_handle_jack_connection_change(sdev); ++ } ++ ++ aaudio_reply_free(&buf); ++ ++ list_add_tail(&sdev->list, &a->subdevice_list); ++ return; ++ ++fail: ++ aaudio_reply_free(&buf); ++ kfree(sdev); ++} ++ ++static void aaudio_init_stream_info(struct aaudio_subdevice *sdev, struct aaudio_stream *strm) ++{ ++ if (aaudio_cmd_get_primitive_property(sdev->a, sdev->dev_id, strm->id, ++ AAUDIO_PROP(AAUDIO_PROP_SCOPE_GLOBAL, AAUDIO_PROP_PHYS_FORMAT, 0), NULL, 0, ++ &strm->desc, sizeof(strm->desc))) ++ dev_warn(sdev->a->dev, "Failed to query stream descriptor\n"); ++ if (aaudio_cmd_get_primitive_property(sdev->a, sdev->dev_id, strm->id, ++ AAUDIO_PROP(AAUDIO_PROP_SCOPE_GLOBAL, AAUDIO_PROP_LATENCY, 0), NULL, 0, &strm->latency, sizeof(u32))) ++ dev_warn(sdev->a->dev, "Failed to query stream latency\n"); ++ if (strm->desc.format_id == AAUDIO_FORMAT_LPCM) ++ sdev->is_pcm = true; ++} ++ ++static void aaudio_free_dev(struct aaudio_subdevice *sdev) ++{ ++ size_t i; ++ for (i = 0; i < sdev->in_stream_cnt; i++) { ++ if (sdev->in_streams[i].alsa_hw_desc) ++ kfree(sdev->in_streams[i].alsa_hw_desc); ++ if (sdev->in_streams[i].buffers) ++ kfree(sdev->in_streams[i].buffers); ++ } ++ for (i = 0; i < sdev->out_stream_cnt; i++) { ++ if (sdev->out_streams[i].alsa_hw_desc) ++ kfree(sdev->out_streams[i].alsa_hw_desc); ++ if (sdev->out_streams[i].buffers) ++ kfree(sdev->out_streams[i].buffers); ++ } ++ kfree(sdev); ++} ++ ++static struct aaudio_subdevice *aaudio_find_dev_by_dev_id(struct aaudio_device *a, aaudio_device_id_t dev_id) ++{ ++ struct aaudio_subdevice *sdev; ++ list_for_each_entry(sdev, &a->subdevice_list, list) { ++ if (dev_id == sdev->dev_id) ++ return sdev; ++ } ++ return NULL; ++} ++ ++static struct aaudio_subdevice *aaudio_find_dev_by_uid(struct aaudio_device *a, const char *uid) ++{ ++ struct aaudio_subdevice *sdev; ++ list_for_each_entry(sdev, &a->subdevice_list, list) { ++ if (!strcmp(uid, sdev->uid)) ++ return sdev; ++ } ++ return NULL; 
++} ++ ++static void aaudio_init_bs_stream(struct aaudio_device *a, struct aaudio_stream *strm, ++ struct aaudio_buffer_struct_stream *bs_strm); ++static void aaudio_init_bs_stream_host(struct aaudio_device *a, struct aaudio_stream *strm, ++ struct aaudio_buffer_struct_stream *bs_strm); ++ ++static int aaudio_init_bs(struct aaudio_device *a) ++{ ++ int i, j; ++ struct aaudio_buffer_struct_device *dev; ++ struct aaudio_subdevice *sdev; ++ u32 ver, sig, bs_base; ++ ++ ver = ioread32(&a->reg_mem_gpr[0]); ++ if (ver < 3) { ++ dev_err(a->dev, "aaudio: Bad GPR version (%u)", ver); ++ return -EINVAL; ++ } ++ sig = ioread32(&a->reg_mem_gpr[1]); ++ if (sig != AAUDIO_SIG) { ++ dev_err(a->dev, "aaudio: Bad GPR sig (%x)", sig); ++ return -EINVAL; ++ } ++ bs_base = ioread32(&a->reg_mem_gpr[2]); ++ a->bs = (struct aaudio_buffer_struct *) ((u8 *) a->reg_mem_bs + bs_base); ++ if (a->bs->signature != AAUDIO_SIG) { ++ dev_err(a->dev, "aaudio: Bad BufferStruct sig (%x)", a->bs->signature); ++ return -EINVAL; ++ } ++ dev_info(a->dev, "aaudio: BufferStruct ver = %i\n", a->bs->version); ++ dev_info(a->dev, "aaudio: Num devices = %i\n", a->bs->num_devices); ++ for (i = 0; i < a->bs->num_devices; i++) { ++ dev = &a->bs->devices[i]; ++ dev_info(a->dev, "aaudio: Device %i %s\n", i, dev->name); ++ ++ sdev = aaudio_find_dev_by_uid(a, dev->name); ++ if (!sdev) { ++ dev_err(a->dev, "aaudio: Subdevice not found for BufferStruct device %s\n", dev->name); ++ continue; ++ } ++ sdev->buf_id = (u8) i; ++ dev->num_input_streams = 0; ++ for (j = 0; j < dev->num_output_streams; j++) { ++ dev_info(a->dev, "aaudio: Device %i Stream %i: Output; Buffer Count = %i\n", i, j, ++ dev->output_streams[j].num_buffers); ++ if (j < sdev->out_stream_cnt) ++ aaudio_init_bs_stream(a, &sdev->out_streams[j], &dev->output_streams[j]); ++ } ++ } ++ ++ list_for_each_entry(sdev, &a->subdevice_list, list) { ++ if (sdev->buf_id != AAUDIO_BUFFER_ID_NONE) ++ continue; ++ sdev->buf_id = i; ++ dev_info(a->dev, "aaudio: Created 
device %i %s\n", i, sdev->uid); ++ strcpy(a->bs->devices[i].name, sdev->uid); ++ a->bs->devices[i].num_input_streams = 0; ++ a->bs->devices[i].num_output_streams = 0; ++ a->bs->num_devices = ++i; ++ } ++ list_for_each_entry(sdev, &a->subdevice_list, list) { ++ if (sdev->in_stream_cnt == 1) { ++ dev_info(a->dev, "aaudio: Device %i Host Stream; Input\n", sdev->buf_id); ++ aaudio_init_bs_stream_host(a, &sdev->in_streams[0], &a->bs->devices[sdev->buf_id].input_streams[0]); ++ a->bs->devices[sdev->buf_id].num_input_streams = 1; ++ wmb(); ++ ++ if (aaudio_cmd_set_input_stream_address_ranges(a, sdev->dev_id)) { ++ dev_err(a->dev, "aaudio: Failed to set input stream address ranges\n"); ++ } ++ } ++ } ++ ++ return 0; ++} ++ ++static void aaudio_init_bs_stream(struct aaudio_device *a, struct aaudio_stream *strm, ++ struct aaudio_buffer_struct_stream *bs_strm) ++{ ++ size_t i; ++ strm->buffer_cnt = bs_strm->num_buffers; ++ if (bs_strm->num_buffers > AAUDIO_DEIVCE_MAX_BUFFER_COUNT) { ++ dev_warn(a->dev, "BufferStruct buffer count %u exceeds driver limit of %u\n", bs_strm->num_buffers, ++ AAUDIO_DEIVCE_MAX_BUFFER_COUNT); ++ strm->buffer_cnt = AAUDIO_DEIVCE_MAX_BUFFER_COUNT; ++ } ++ if (!strm->buffer_cnt) ++ return; ++ strm->buffers = kmalloc_array(strm->buffer_cnt, sizeof(struct aaudio_dma_buf), GFP_KERNEL); ++ if (!strm->buffers) { ++ dev_err(a->dev, "Buffer list allocation failed\n"); ++ return; ++ } ++ for (i = 0; i < strm->buffer_cnt; i++) { ++ strm->buffers[i].dma_addr = a->reg_mem_bs_dma + (dma_addr_t) bs_strm->buffers[i].address; ++ strm->buffers[i].ptr = a->reg_mem_bs + bs_strm->buffers[i].address; ++ strm->buffers[i].size = bs_strm->buffers[i].size; ++ } ++ ++ if (strm->buffer_cnt == 1) { ++ strm->alsa_hw_desc = kmalloc(sizeof(struct snd_pcm_hardware), GFP_KERNEL); ++ if (aaudio_create_hw_info(&strm->desc, strm->alsa_hw_desc, strm->buffers[0].size)) { ++ kfree(strm->alsa_hw_desc); ++ strm->alsa_hw_desc = NULL; ++ } ++ } ++} ++ ++static void 
aaudio_init_bs_stream_host(struct aaudio_device *a, struct aaudio_stream *strm, ++ struct aaudio_buffer_struct_stream *bs_strm) ++{ ++ size_t size; ++ dma_addr_t dma_addr; ++ void *dma_ptr; ++ size = strm->desc.bytes_per_packet * 16640; ++ dma_ptr = dma_alloc_coherent(&a->pci->dev, size, &dma_addr, GFP_KERNEL); ++ if (!dma_ptr) { ++ dev_err(a->dev, "dma_alloc_coherent failed\n"); ++ return; ++ } ++ bs_strm->buffers[0].address = dma_addr; ++ bs_strm->buffers[0].size = size; ++ bs_strm->num_buffers = 1; ++ ++ memset(dma_ptr, 0, size); ++ ++ strm->buffer_cnt = 1; ++ strm->buffers = kmalloc_array(strm->buffer_cnt, sizeof(struct aaudio_dma_buf), GFP_KERNEL); ++ if (!strm->buffers) { ++ dev_err(a->dev, "Buffer list allocation failed\n"); ++ return; ++ } ++ strm->buffers[0].dma_addr = dma_addr; ++ strm->buffers[0].ptr = dma_ptr; ++ strm->buffers[0].size = size; ++ ++ strm->alsa_hw_desc = kmalloc(sizeof(struct snd_pcm_hardware), GFP_KERNEL); ++ if (aaudio_create_hw_info(&strm->desc, strm->alsa_hw_desc, strm->buffers[0].size)) { ++ kfree(strm->alsa_hw_desc); ++ strm->alsa_hw_desc = NULL; ++ } ++} ++ ++static void aaudio_handle_prop_change(struct aaudio_device *a, struct aaudio_msg *msg); ++ ++void aaudio_handle_notification(struct aaudio_device *a, struct aaudio_msg *msg) ++{ ++ struct aaudio_send_ctx sctx; ++ struct aaudio_msg_base base; ++ if (aaudio_msg_read_base(msg, &base)) ++ return; ++ switch (base.msg) { ++ case AAUDIO_MSG_NOTIFICATION_BOOT: ++ dev_info(a->dev, "Received boot notification from remote\n"); ++ ++ /* Resend the alive notify */ ++ if (aaudio_send(a, &sctx, 500, ++ aaudio_msg_write_alive_notification, 1, 3)) { ++ pr_err("Sending alive notification failed\n"); ++ } ++ break; ++ case AAUDIO_MSG_NOTIFICATION_ALIVE: ++ dev_info(a->dev, "Received alive notification from remote\n"); ++ complete_all(&a->remote_alive); ++ break; ++ case AAUDIO_MSG_PROPERTY_CHANGED: ++ aaudio_handle_prop_change(a, msg); ++ break; ++ default: ++ dev_info(a->dev, "Unhandled 
notification %i", base.msg); ++ break; ++ } ++} ++ ++struct aaudio_prop_change_work_struct { ++ struct work_struct ws; ++ struct aaudio_device *a; ++ aaudio_device_id_t dev; ++ aaudio_object_id_t obj; ++ struct aaudio_prop_addr prop; ++}; ++ ++static void aaudio_handle_jack_connection_change(struct aaudio_subdevice *sdev) ++{ ++ u32 plugged; ++ if (!sdev->jack) ++ return; ++ /* NOTE: Apple made the plug status scoped to the input and output streams. This makes no sense for us, so I just ++ * always pick the OUTPUT status. */ ++ if (aaudio_cmd_get_primitive_property(sdev->a, sdev->dev_id, sdev->dev_id, ++ AAUDIO_PROP(AAUDIO_PROP_SCOPE_OUTPUT, AAUDIO_PROP_JACK_PLUGGED, 0), NULL, 0, &plugged, sizeof(plugged))) { ++ dev_err(sdev->a->dev, "Failed to get jack enable status\n"); ++ return; ++ } ++ dev_dbg(sdev->a->dev, "Jack is now %s\n", plugged ? "plugged" : "unplugged"); ++ snd_jack_report(sdev->jack, plugged ? sdev->jack->type : 0); ++} ++ ++void aaudio_handle_prop_change_work(struct work_struct *ws) ++{ ++ struct aaudio_prop_change_work_struct *work = container_of(ws, struct aaudio_prop_change_work_struct, ws); ++ struct aaudio_subdevice *sdev; ++ ++ sdev = aaudio_find_dev_by_dev_id(work->a, work->dev); ++ if (!sdev) { ++ dev_err(work->a->dev, "Property notification change: device not found\n"); ++ goto done; ++ } ++ dev_dbg(work->a->dev, "Property changed for device: %s\n", sdev->uid); ++ ++ if (work->prop.scope == AAUDIO_PROP_SCOPE_OUTPUT && work->prop.selector == AAUDIO_PROP_JACK_PLUGGED) { ++ aaudio_handle_jack_connection_change(sdev); ++ } ++ ++done: ++ kfree(work); ++} ++ ++void aaudio_handle_prop_change(struct aaudio_device *a, struct aaudio_msg *msg) ++{ ++ /* NOTE: This is a scheduled work because this callback will generally need to query device information and this ++ * is not possible when we are in the reply parsing code's context. 
*/ ++ struct aaudio_prop_change_work_struct *work; ++ work = kmalloc(sizeof(struct aaudio_prop_change_work_struct), GFP_KERNEL); ++ work->a = a; ++ INIT_WORK(&work->ws, aaudio_handle_prop_change_work); ++ aaudio_msg_read_property_changed(msg, &work->dev, &work->obj, &work->prop); ++ schedule_work(&work->ws); ++} ++ ++#define aaudio_send_cmd_response(a, sctx, msg, fn, ...) \ ++ if (aaudio_send_with_tag(a, sctx, ((struct aaudio_msg_header *) msg->data)->tag, 500, fn, ##__VA_ARGS__)) \ ++ pr_err("aaudio: Failed to reply to a command\n"); ++ ++void aaudio_handle_cmd_timestamp(struct aaudio_device *a, struct aaudio_msg *msg) ++{ ++ ktime_t time_os = ktime_get_boottime(); ++ struct aaudio_send_ctx sctx; ++ struct aaudio_subdevice *sdev; ++ u64 devid, timestamp, update_seed; ++ aaudio_msg_read_update_timestamp(msg, &devid, &timestamp, &update_seed); ++ dev_dbg(a->dev, "Received timestamp update for dev=%llx ts=%llx seed=%llx\n", devid, timestamp, update_seed); ++ ++ sdev = aaudio_find_dev_by_dev_id(a, devid); ++ aaudio_handle_timestamp(sdev, time_os, timestamp); ++ ++ aaudio_send_cmd_response(a, &sctx, msg, ++ aaudio_msg_write_update_timestamp_response); ++} ++ ++void aaudio_handle_command(struct aaudio_device *a, struct aaudio_msg *msg) ++{ ++ struct aaudio_msg_base base; ++ if (aaudio_msg_read_base(msg, &base)) ++ return; ++ switch (base.msg) { ++ case AAUDIO_MSG_UPDATE_TIMESTAMP: ++ aaudio_handle_cmd_timestamp(a, msg); ++ break; ++ default: ++ dev_info(a->dev, "Unhandled device command %i", base.msg); ++ break; ++ } ++} ++ ++static struct pci_device_id aaudio_ids[ ] = { ++ { PCI_DEVICE(PCI_VENDOR_ID_APPLE, 0x1803) }, ++ { 0, }, ++}; ++ ++struct dev_pm_ops aaudio_pci_driver_pm = { ++ .suspend = aaudio_suspend, ++ .resume = aaudio_resume ++}; ++struct pci_driver aaudio_pci_driver = { ++ .name = "aaudio", ++ .id_table = aaudio_ids, ++ .probe = aaudio_probe, ++ .remove = aaudio_remove, ++ .driver = { ++ .pm = &aaudio_pci_driver_pm ++ } ++}; ++ ++ ++int aaudio_module_init(void) 
++{ ++ int result; ++ if ((result = alloc_chrdev_region(&aaudio_chrdev, 0, 1, "aaudio"))) ++ goto fail_chrdev; ++#if LINUX_VERSION_CODE < KERNEL_VERSION(6,4,0) ++ aaudio_class = class_create(THIS_MODULE, "aaudio"); ++#else ++ aaudio_class = class_create("aaudio"); ++#endif ++ if (IS_ERR(aaudio_class)) { ++ result = PTR_ERR(aaudio_class); ++ goto fail_class; ++ } ++ ++ result = pci_register_driver(&aaudio_pci_driver); ++ if (result) ++ goto fail_drv; ++ return 0; ++ ++fail_drv: ++ pci_unregister_driver(&aaudio_pci_driver); ++fail_class: ++ class_destroy(aaudio_class); ++fail_chrdev: ++ unregister_chrdev_region(aaudio_chrdev, 1); ++ if (!result) ++ result = -EINVAL; ++ return result; ++} ++ ++void aaudio_module_exit(void) ++{ ++ pci_unregister_driver(&aaudio_pci_driver); ++ class_destroy(aaudio_class); ++ unregister_chrdev_region(aaudio_chrdev, 1); ++} ++ ++struct aaudio_alsa_pcm_id_mapping aaudio_alsa_id_mappings[] = { ++ {"Speaker", 0}, ++ {"Digital Mic", 1}, ++ {"Codec Output", 2}, ++ {"Codec Input", 3}, ++ {"Bridge Loopback", 4}, ++ {} ++}; ++ ++module_param_named(index, aaudio_alsa_index, int, 0444); ++MODULE_PARM_DESC(index, "Index value for Apple Internal Audio soundcard."); ++module_param_named(id, aaudio_alsa_id, charp, 0444); ++MODULE_PARM_DESC(id, "ID string for Apple Internal Audio soundcard."); +diff --git a/drivers/staging/apple-bce/audio/audio.h b/drivers/staging/apple-bce/audio/audio.h +new file mode 100644 +index 000000000000..004bc1e22ea4 +--- /dev/null ++++ b/drivers/staging/apple-bce/audio/audio.h +@@ -0,0 +1,125 @@ ++#ifndef AAUDIO_H ++#define AAUDIO_H ++ ++#include ++#include ++#include "../apple_bce.h" ++#include "protocol_bce.h" ++#include "description.h" ++ ++#define AAUDIO_SIG 0x19870423 ++ ++#define AAUDIO_DEVICE_MAX_UID_LEN 128 ++#define AAUDIO_DEIVCE_MAX_INPUT_STREAMS 1 ++#define AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS 1 ++#define AAUDIO_DEIVCE_MAX_BUFFER_COUNT 1 ++ ++#define AAUDIO_BUFFER_ID_NONE 0xffu ++ ++struct snd_card; ++struct snd_pcm; 
++struct snd_pcm_hardware; ++struct snd_jack; ++ ++struct __attribute__((packed)) __attribute__((aligned(4))) aaudio_buffer_struct_buffer { ++ size_t address; ++ size_t size; ++ size_t pad[4]; ++}; ++struct aaudio_buffer_struct_stream { ++ u8 num_buffers; ++ struct aaudio_buffer_struct_buffer buffers[100]; ++ char filler[32]; ++}; ++struct aaudio_buffer_struct_device { ++ char name[128]; ++ u8 num_input_streams; ++ u8 num_output_streams; ++ struct aaudio_buffer_struct_stream input_streams[5]; ++ struct aaudio_buffer_struct_stream output_streams[5]; ++ char filler[128]; ++}; ++struct aaudio_buffer_struct { ++ u32 version; ++ u32 signature; ++ u32 flags; ++ u8 num_devices; ++ struct aaudio_buffer_struct_device devices[20]; ++}; ++ ++struct aaudio_device; ++struct aaudio_dma_buf { ++ dma_addr_t dma_addr; ++ void *ptr; ++ size_t size; ++}; ++struct aaudio_stream { ++ aaudio_object_id_t id; ++ size_t buffer_cnt; ++ struct aaudio_dma_buf *buffers; ++ ++ struct aaudio_apple_description desc; ++ struct snd_pcm_hardware *alsa_hw_desc; ++ u32 latency; ++ ++ bool waiting_for_first_ts; ++ ++ ktime_t remote_timestamp; ++ snd_pcm_sframes_t frame_min; ++ int started; ++}; ++struct aaudio_subdevice { ++ struct aaudio_device *a; ++ struct list_head list; ++ aaudio_device_id_t dev_id; ++ u32 in_latency, out_latency; ++ u8 buf_id; ++ int alsa_id; ++ char uid[AAUDIO_DEVICE_MAX_UID_LEN + 1]; ++ size_t in_stream_cnt; ++ struct aaudio_stream in_streams[AAUDIO_DEIVCE_MAX_INPUT_STREAMS]; ++ size_t out_stream_cnt; ++ struct aaudio_stream out_streams[AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS]; ++ bool is_pcm; ++ struct snd_pcm *pcm; ++ struct snd_jack *jack; ++}; ++struct aaudio_alsa_pcm_id_mapping { ++ const char *name; ++ int alsa_id; ++}; ++ ++struct aaudio_device { ++ struct pci_dev *pci; ++ dev_t devt; ++ struct device *dev; ++ void __iomem *reg_mem_bs; ++ dma_addr_t reg_mem_bs_dma; ++ void __iomem *reg_mem_cfg; ++ ++ u32 __iomem *reg_mem_gpr; ++ ++ struct aaudio_buffer_struct *bs; ++ ++ struct 
apple_bce_device *bce; ++ struct aaudio_bce bcem; ++ ++ struct snd_card *card; ++ ++ struct list_head subdevice_list; ++ int next_alsa_id; ++ ++ struct completion remote_alive; ++}; ++ ++void aaudio_handle_notification(struct aaudio_device *a, struct aaudio_msg *msg); ++void aaudio_handle_prop_change_work(struct work_struct *ws); ++void aaudio_handle_cmd_timestamp(struct aaudio_device *a, struct aaudio_msg *msg); ++void aaudio_handle_command(struct aaudio_device *a, struct aaudio_msg *msg); ++ ++int aaudio_module_init(void); ++void aaudio_module_exit(void); ++ ++extern struct aaudio_alsa_pcm_id_mapping aaudio_alsa_id_mappings[]; ++ ++#endif //AAUDIO_H +diff --git a/drivers/staging/apple-bce/audio/description.h b/drivers/staging/apple-bce/audio/description.h +new file mode 100644 +index 000000000000..dfef3ab68f27 +--- /dev/null ++++ b/drivers/staging/apple-bce/audio/description.h +@@ -0,0 +1,42 @@ ++#ifndef AAUDIO_DESCRIPTION_H ++#define AAUDIO_DESCRIPTION_H ++ ++#include ++ ++struct aaudio_apple_description { ++ u64 sample_rate_double; ++ u32 format_id; ++ u32 format_flags; ++ u32 bytes_per_packet; ++ u32 frames_per_packet; ++ u32 bytes_per_frame; ++ u32 channels_per_frame; ++ u32 bits_per_channel; ++ u32 reserved; ++}; ++ ++enum { ++ AAUDIO_FORMAT_LPCM = 0x6c70636d // 'lpcm' ++}; ++ ++enum { ++ AAUDIO_FORMAT_FLAG_FLOAT = 1, ++ AAUDIO_FORMAT_FLAG_BIG_ENDIAN = 2, ++ AAUDIO_FORMAT_FLAG_SIGNED = 4, ++ AAUDIO_FORMAT_FLAG_PACKED = 8, ++ AAUDIO_FORMAT_FLAG_ALIGNED_HIGH = 16, ++ AAUDIO_FORMAT_FLAG_NON_INTERLEAVED = 32, ++ AAUDIO_FORMAT_FLAG_NON_MIXABLE = 64 ++}; ++ ++static inline u64 aaudio_double_to_u64(u64 d) ++{ ++ u8 sign = (u8) ((d >> 63) & 1); ++ s32 exp = (s32) ((d >> 52) & 0x7ff) - 1023; ++ u64 fr = d & ((1LL << 52) - 1); ++ if (sign || exp < 0) ++ return 0; ++ return (u64) ((1LL << exp) + (fr >> (52 - exp))); ++} ++ ++#endif //AAUDIO_DESCRIPTION_H +diff --git a/drivers/staging/apple-bce/audio/pcm.c b/drivers/staging/apple-bce/audio/pcm.c +new file mode 100644 
+index 000000000000..1026e10a9ac5 +--- /dev/null ++++ b/drivers/staging/apple-bce/audio/pcm.c +@@ -0,0 +1,308 @@ ++#include "pcm.h" ++#include "audio.h" ++ ++static u64 aaudio_get_alsa_fmtbit(struct aaudio_apple_description *desc) ++{ ++ if (desc->format_flags & AAUDIO_FORMAT_FLAG_FLOAT) { ++ if (desc->bits_per_channel == 32) { ++ if (desc->format_flags & AAUDIO_FORMAT_FLAG_BIG_ENDIAN) ++ return SNDRV_PCM_FMTBIT_FLOAT_BE; ++ else ++ return SNDRV_PCM_FMTBIT_FLOAT_LE; ++ } else if (desc->bits_per_channel == 64) { ++ if (desc->format_flags & AAUDIO_FORMAT_FLAG_BIG_ENDIAN) ++ return SNDRV_PCM_FMTBIT_FLOAT64_BE; ++ else ++ return SNDRV_PCM_FMTBIT_FLOAT64_LE; ++ } else { ++ pr_err("aaudio: unsupported bits per channel for float format: %u\n", desc->bits_per_channel); ++ return 0; ++ } ++ } ++#define DEFINE_BPC_OPTION(val, b) \ ++ case val: \ ++ if (desc->format_flags & AAUDIO_FORMAT_FLAG_BIG_ENDIAN) { \ ++ if (desc->format_flags & AAUDIO_FORMAT_FLAG_SIGNED) \ ++ return SNDRV_PCM_FMTBIT_S ## b ## BE; \ ++ else \ ++ return SNDRV_PCM_FMTBIT_U ## b ## BE; \ ++ } else { \ ++ if (desc->format_flags & AAUDIO_FORMAT_FLAG_SIGNED) \ ++ return SNDRV_PCM_FMTBIT_S ## b ## LE; \ ++ else \ ++ return SNDRV_PCM_FMTBIT_U ## b ## LE; \ ++ } ++ if (desc->format_flags & AAUDIO_FORMAT_FLAG_PACKED) { ++ switch (desc->bits_per_channel) { ++ case 8: ++ case 16: ++ case 32: ++ break; ++ DEFINE_BPC_OPTION(24, 24_3) ++ default: ++ pr_err("aaudio: unsupported bits per channel for packed format: %u\n", desc->bits_per_channel); ++ return 0; ++ } ++ } ++ if (desc->format_flags & AAUDIO_FORMAT_FLAG_ALIGNED_HIGH) { ++ switch (desc->bits_per_channel) { ++ DEFINE_BPC_OPTION(24, 32_) ++ default: ++ pr_err("aaudio: unsupported bits per channel for high-aligned format: %u\n", desc->bits_per_channel); ++ return 0; ++ } ++ } ++ switch (desc->bits_per_channel) { ++ case 8: ++ if (desc->format_flags & AAUDIO_FORMAT_FLAG_SIGNED) ++ return SNDRV_PCM_FMTBIT_S8; ++ else ++ return SNDRV_PCM_FMTBIT_U8; ++ 
DEFINE_BPC_OPTION(16, 16_) ++ DEFINE_BPC_OPTION(24, 24_) ++ DEFINE_BPC_OPTION(32, 32_) ++ default: ++ pr_err("aaudio: unsupported bits per channel: %u\n", desc->bits_per_channel); ++ return 0; ++ } ++} ++int aaudio_create_hw_info(struct aaudio_apple_description *desc, struct snd_pcm_hardware *alsa_hw, ++ size_t buf_size) ++{ ++ uint rate; ++ alsa_hw->info = (SNDRV_PCM_INFO_MMAP | ++ SNDRV_PCM_INFO_BLOCK_TRANSFER | ++ SNDRV_PCM_INFO_MMAP_VALID | ++ SNDRV_PCM_INFO_DOUBLE); ++ if (desc->format_flags & AAUDIO_FORMAT_FLAG_NON_MIXABLE) ++ pr_warn("aaudio: unsupported hw flag: NON_MIXABLE\n"); ++ if (!(desc->format_flags & AAUDIO_FORMAT_FLAG_NON_INTERLEAVED)) ++ alsa_hw->info |= SNDRV_PCM_INFO_INTERLEAVED; ++ alsa_hw->formats = aaudio_get_alsa_fmtbit(desc); ++ if (!alsa_hw->formats) ++ return -EINVAL; ++ rate = (uint) aaudio_double_to_u64(desc->sample_rate_double); ++ alsa_hw->rates = snd_pcm_rate_to_rate_bit(rate); ++ alsa_hw->rate_min = rate; ++ alsa_hw->rate_max = rate; ++ alsa_hw->channels_min = desc->channels_per_frame; ++ alsa_hw->channels_max = desc->channels_per_frame; ++ alsa_hw->buffer_bytes_max = buf_size; ++ alsa_hw->period_bytes_min = desc->bytes_per_packet; ++ alsa_hw->period_bytes_max = desc->bytes_per_packet; ++ alsa_hw->periods_min = (uint) (buf_size / desc->bytes_per_packet); ++ alsa_hw->periods_max = (uint) (buf_size / desc->bytes_per_packet); ++ pr_debug("aaudio_create_hw_info: format = %llu, rate = %u/%u. 
channels = %u, periods = %u, period size = %lu\n", ++ alsa_hw->formats, alsa_hw->rate_min, alsa_hw->rates, alsa_hw->channels_min, alsa_hw->periods_min, ++ alsa_hw->period_bytes_min); ++ return 0; ++} ++ ++static struct aaudio_stream *aaudio_pcm_stream(struct snd_pcm_substream *substream) ++{ ++ struct aaudio_subdevice *sdev = snd_pcm_substream_chip(substream); ++ if (substream->stream == SNDRV_PCM_STREAM_PLAYBACK) ++ return &sdev->out_streams[substream->number]; ++ else ++ return &sdev->in_streams[substream->number]; ++} ++ ++static int aaudio_pcm_open(struct snd_pcm_substream *substream) ++{ ++ pr_debug("aaudio_pcm_open\n"); ++ substream->runtime->hw = *aaudio_pcm_stream(substream)->alsa_hw_desc; ++ ++ return 0; ++} ++ ++static int aaudio_pcm_close(struct snd_pcm_substream *substream) ++{ ++ pr_debug("aaudio_pcm_close\n"); ++ return 0; ++} ++ ++static int aaudio_pcm_prepare(struct snd_pcm_substream *substream) ++{ ++ return 0; ++} ++ ++static int aaudio_pcm_hw_params(struct snd_pcm_substream *substream, struct snd_pcm_hw_params *hw_params) ++{ ++ struct aaudio_stream *astream = aaudio_pcm_stream(substream); ++ pr_debug("aaudio_pcm_hw_params\n"); ++ ++ if (!astream->buffer_cnt || !astream->buffers) ++ return -EINVAL; ++ ++ substream->runtime->dma_area = astream->buffers[0].ptr; ++ substream->runtime->dma_addr = astream->buffers[0].dma_addr; ++ substream->runtime->dma_bytes = astream->buffers[0].size; ++ return 0; ++} ++ ++static int aaudio_pcm_hw_free(struct snd_pcm_substream *substream) ++{ ++ pr_debug("aaudio_pcm_hw_free\n"); ++ return 0; ++} ++ ++static void aaudio_pcm_start(struct snd_pcm_substream *substream) ++{ ++ struct aaudio_subdevice *sdev = snd_pcm_substream_chip(substream); ++ struct aaudio_stream *stream = aaudio_pcm_stream(substream); ++ void *buf; ++ size_t s; ++ ktime_t time_start, time_end; ++ bool back_buffer; ++ time_start = ktime_get(); ++ ++ back_buffer = (substream->stream == SNDRV_PCM_STREAM_PLAYBACK); ++ ++ if (back_buffer) { ++ s = 
frames_to_bytes(substream->runtime, substream->runtime->control->appl_ptr); ++ buf = kmalloc(s, GFP_KERNEL); ++ memcpy_fromio(buf, substream->runtime->dma_area, s); ++ time_end = ktime_get(); ++ pr_debug("aaudio: Backed up the buffer in %lluns [%li]\n", ktime_to_ns(time_end - time_start), ++ substream->runtime->control->appl_ptr); ++ } ++ ++ stream->waiting_for_first_ts = true; ++ stream->frame_min = stream->latency; ++ ++ aaudio_cmd_start_io(sdev->a, sdev->dev_id); ++ if (back_buffer) ++ memcpy_toio(substream->runtime->dma_area, buf, s); ++ ++ time_end = ktime_get(); ++ pr_debug("aaudio: Started the audio device in %lluns\n", ktime_to_ns(time_end - time_start)); ++} ++ ++static int aaudio_pcm_trigger(struct snd_pcm_substream *substream, int cmd) ++{ ++ struct aaudio_subdevice *sdev = snd_pcm_substream_chip(substream); ++ struct aaudio_stream *stream = aaudio_pcm_stream(substream); ++ pr_debug("aaudio_pcm_trigger %x\n", cmd); ++ ++ /* We only supports triggers on the #0 buffer */ ++ if (substream->number != 0) ++ return 0; ++ switch (cmd) { ++ case SNDRV_PCM_TRIGGER_START: ++ aaudio_pcm_start(substream); ++ stream->started = 1; ++ break; ++ case SNDRV_PCM_TRIGGER_STOP: ++ aaudio_cmd_stop_io(sdev->a, sdev->dev_id); ++ stream->started = 0; ++ break; ++ default: ++ return -EINVAL; ++ } ++ return 0; ++} ++ ++static snd_pcm_uframes_t aaudio_pcm_pointer(struct snd_pcm_substream *substream) ++{ ++ struct aaudio_stream *stream = aaudio_pcm_stream(substream); ++ ktime_t time_from_start; ++ snd_pcm_sframes_t frames; ++ snd_pcm_sframes_t buffer_time_length; ++ ++ if (!stream->started || stream->waiting_for_first_ts) { ++ pr_warn("aaudio_pcm_pointer while not started\n"); ++ return 0; ++ } ++ ++ /* Approximate the pointer based on the last received timestamp */ ++ time_from_start = ktime_get_boottime() - stream->remote_timestamp; ++ buffer_time_length = NSEC_PER_SEC * substream->runtime->buffer_size / substream->runtime->rate; ++ frames = (ktime_to_ns(time_from_start) % 
buffer_time_length) * substream->runtime->buffer_size / buffer_time_length; ++ if (ktime_to_ns(time_from_start) < buffer_time_length) { ++ if (frames < stream->frame_min) ++ frames = stream->frame_min; ++ else ++ stream->frame_min = 0; ++ } else { ++ if (ktime_to_ns(time_from_start) < 2 * buffer_time_length) ++ stream->frame_min = frames; ++ else ++ stream->frame_min = 0; /* Heavy desync */ ++ } ++ frames -= stream->latency; ++ if (frames < 0) ++ frames += ((-frames - 1) / substream->runtime->buffer_size + 1) * substream->runtime->buffer_size; ++ return (snd_pcm_uframes_t) frames; ++} ++ ++static struct snd_pcm_ops aaudio_pcm_ops = { ++ .open = aaudio_pcm_open, ++ .close = aaudio_pcm_close, ++ .ioctl = snd_pcm_lib_ioctl, ++ .hw_params = aaudio_pcm_hw_params, ++ .hw_free = aaudio_pcm_hw_free, ++ .prepare = aaudio_pcm_prepare, ++ .trigger = aaudio_pcm_trigger, ++ .pointer = aaudio_pcm_pointer, ++ .mmap = snd_pcm_lib_mmap_iomem ++}; ++ ++int aaudio_create_pcm(struct aaudio_subdevice *sdev) ++{ ++ struct snd_pcm *pcm; ++ struct aaudio_alsa_pcm_id_mapping *id_mapping; ++ int err; ++ ++ if (!sdev->is_pcm || (sdev->in_stream_cnt == 0 && sdev->out_stream_cnt == 0)) { ++ return -EINVAL; ++ } ++ ++ for (id_mapping = aaudio_alsa_id_mappings; id_mapping->name; id_mapping++) { ++ if (!strcmp(sdev->uid, id_mapping->name)) { ++ sdev->alsa_id = id_mapping->alsa_id; ++ break; ++ } ++ } ++ if (!id_mapping->name) ++ sdev->alsa_id = sdev->a->next_alsa_id++; ++ err = snd_pcm_new(sdev->a->card, sdev->uid, sdev->alsa_id, ++ (int) sdev->out_stream_cnt, (int) sdev->in_stream_cnt, &pcm); ++ if (err < 0) ++ return err; ++ pcm->private_data = sdev; ++ pcm->nonatomic = 1; ++ sdev->pcm = pcm; ++ strcpy(pcm->name, sdev->uid); ++ snd_pcm_set_ops(pcm, SNDRV_PCM_STREAM_PLAYBACK, &aaudio_pcm_ops); ++ snd_pcm_set_ops(pcm, SNDRV_PCM_STREAM_CAPTURE, &aaudio_pcm_ops); ++ return 0; ++} ++ ++static void aaudio_handle_stream_timestamp(struct snd_pcm_substream *substream, ktime_t timestamp) ++{ ++ unsigned 
long flags; ++ struct aaudio_stream *stream; ++ ++ stream = aaudio_pcm_stream(substream); ++ snd_pcm_stream_lock_irqsave(substream, flags); ++ stream->remote_timestamp = timestamp; ++ if (stream->waiting_for_first_ts) { ++ stream->waiting_for_first_ts = false; ++ snd_pcm_stream_unlock_irqrestore(substream, flags); ++ return; ++ } ++ snd_pcm_stream_unlock_irqrestore(substream, flags); ++ snd_pcm_period_elapsed(substream); ++} ++ ++void aaudio_handle_timestamp(struct aaudio_subdevice *sdev, ktime_t os_timestamp, u64 dev_timestamp) ++{ ++ struct snd_pcm_substream *substream; ++ ++ substream = sdev->pcm->streams[SNDRV_PCM_STREAM_PLAYBACK].substream; ++ if (substream) ++ aaudio_handle_stream_timestamp(substream, dev_timestamp); ++ substream = sdev->pcm->streams[SNDRV_PCM_STREAM_CAPTURE].substream; ++ if (substream) ++ aaudio_handle_stream_timestamp(substream, os_timestamp); ++} +diff --git a/drivers/staging/apple-bce/audio/pcm.h b/drivers/staging/apple-bce/audio/pcm.h +new file mode 100644 +index 000000000000..ea5f35fbe408 +--- /dev/null ++++ b/drivers/staging/apple-bce/audio/pcm.h +@@ -0,0 +1,16 @@ ++#ifndef AAUDIO_PCM_H ++#define AAUDIO_PCM_H ++ ++#include ++#include ++ ++struct aaudio_subdevice; ++struct aaudio_apple_description; ++struct snd_pcm_hardware; ++ ++int aaudio_create_hw_info(struct aaudio_apple_description *desc, struct snd_pcm_hardware *alsa_hw, size_t buf_size); ++int aaudio_create_pcm(struct aaudio_subdevice *sdev); ++ ++void aaudio_handle_timestamp(struct aaudio_subdevice *sdev, ktime_t os_timestamp, u64 dev_timestamp); ++ ++#endif //AAUDIO_PCM_H +diff --git a/drivers/staging/apple-bce/audio/protocol.c b/drivers/staging/apple-bce/audio/protocol.c +new file mode 100644 +index 000000000000..2314813aeead +--- /dev/null ++++ b/drivers/staging/apple-bce/audio/protocol.c +@@ -0,0 +1,347 @@ ++#include "protocol.h" ++#include "protocol_bce.h" ++#include "audio.h" ++ ++int aaudio_msg_read_base(struct aaudio_msg *msg, struct aaudio_msg_base *base) ++{ ++ if 
(msg->size < sizeof(struct aaudio_msg_header) + sizeof(struct aaudio_msg_base) * 2) ++ return -EINVAL; ++ *base = *((struct aaudio_msg_base *) ((struct aaudio_msg_header *) msg->data + 1)); ++ return 0; ++} ++ ++#define READ_START(type) \ ++ size_t offset = sizeof(struct aaudio_msg_header) + sizeof(struct aaudio_msg_base); (void)offset; \ ++ if (((struct aaudio_msg_base *) ((struct aaudio_msg_header *) msg->data + 1))->msg != type) \ ++ return -EINVAL; ++#define READ_DEVID_VAR(devid) *devid = ((struct aaudio_msg_header *) msg->data)->device_id ++#define READ_VAL(type) ({ offset += sizeof(type); *((type *) ((u8 *) msg->data + offset - sizeof(type))); }) ++#define READ_VAR(type, var) *var = READ_VAL(type) ++ ++int aaudio_msg_read_start_io_response(struct aaudio_msg *msg) ++{ ++ READ_START(AAUDIO_MSG_START_IO_RESPONSE); ++ return 0; ++} ++ ++int aaudio_msg_read_stop_io_response(struct aaudio_msg *msg) ++{ ++ READ_START(AAUDIO_MSG_STOP_IO_RESPONSE); ++ return 0; ++} ++ ++int aaudio_msg_read_update_timestamp(struct aaudio_msg *msg, aaudio_device_id_t *devid, ++ u64 *timestamp, u64 *update_seed) ++{ ++ READ_START(AAUDIO_MSG_UPDATE_TIMESTAMP); ++ READ_DEVID_VAR(devid); ++ READ_VAR(u64, timestamp); ++ READ_VAR(u64, update_seed); ++ return 0; ++} ++ ++int aaudio_msg_read_get_property_response(struct aaudio_msg *msg, aaudio_object_id_t *obj, ++ struct aaudio_prop_addr *prop, void **data, u64 *data_size) ++{ ++ READ_START(AAUDIO_MSG_GET_PROPERTY_RESPONSE); ++ READ_VAR(aaudio_object_id_t, obj); ++ READ_VAR(u32, &prop->element); ++ READ_VAR(u32, &prop->scope); ++ READ_VAR(u32, &prop->selector); ++ READ_VAR(u64, data_size); ++ *data = ((u8 *) msg->data + offset); ++ /* offset += data_size; */ ++ return 0; ++} ++ ++int aaudio_msg_read_set_property_response(struct aaudio_msg *msg, aaudio_object_id_t *obj) ++{ ++ READ_START(AAUDIO_MSG_SET_PROPERTY_RESPONSE); ++ READ_VAR(aaudio_object_id_t, obj); ++ return 0; ++} ++ ++int aaudio_msg_read_property_listener_response(struct aaudio_msg 
*msg, aaudio_object_id_t *obj, ++ struct aaudio_prop_addr *prop) ++{ ++ READ_START(AAUDIO_MSG_PROPERTY_LISTENER_RESPONSE); ++ READ_VAR(aaudio_object_id_t, obj); ++ READ_VAR(u32, &prop->element); ++ READ_VAR(u32, &prop->scope); ++ READ_VAR(u32, &prop->selector); ++ return 0; ++} ++ ++int aaudio_msg_read_property_changed(struct aaudio_msg *msg, aaudio_device_id_t *devid, aaudio_object_id_t *obj, ++ struct aaudio_prop_addr *prop) ++{ ++ READ_START(AAUDIO_MSG_PROPERTY_CHANGED); ++ READ_DEVID_VAR(devid); ++ READ_VAR(aaudio_object_id_t, obj); ++ READ_VAR(u32, &prop->element); ++ READ_VAR(u32, &prop->scope); ++ READ_VAR(u32, &prop->selector); ++ return 0; ++} ++ ++int aaudio_msg_read_set_input_stream_address_ranges_response(struct aaudio_msg *msg) ++{ ++ READ_START(AAUDIO_MSG_SET_INPUT_STREAM_ADDRESS_RANGES_RESPONSE); ++ return 0; ++} ++ ++int aaudio_msg_read_get_input_stream_list_response(struct aaudio_msg *msg, aaudio_object_id_t **str_l, u64 *str_cnt) ++{ ++ READ_START(AAUDIO_MSG_GET_INPUT_STREAM_LIST_RESPONSE); ++ READ_VAR(u64, str_cnt); ++ *str_l = (aaudio_device_id_t *) ((u8 *) msg->data + offset); ++ /* offset += str_cnt * sizeof(aaudio_object_id_t); */ ++ return 0; ++} ++ ++int aaudio_msg_read_get_output_stream_list_response(struct aaudio_msg *msg, aaudio_object_id_t **str_l, u64 *str_cnt) ++{ ++ READ_START(AAUDIO_MSG_GET_OUTPUT_STREAM_LIST_RESPONSE); ++ READ_VAR(u64, str_cnt); ++ *str_l = (aaudio_device_id_t *) ((u8 *) msg->data + offset); ++ /* offset += str_cnt * sizeof(aaudio_object_id_t); */ ++ return 0; ++} ++ ++int aaudio_msg_read_set_remote_access_response(struct aaudio_msg *msg) ++{ ++ READ_START(AAUDIO_MSG_SET_REMOTE_ACCESS_RESPONSE); ++ return 0; ++} ++ ++int aaudio_msg_read_get_device_list_response(struct aaudio_msg *msg, aaudio_device_id_t **dev_l, u64 *dev_cnt) ++{ ++ READ_START(AAUDIO_MSG_GET_DEVICE_LIST_RESPONSE); ++ READ_VAR(u64, dev_cnt); ++ *dev_l = (aaudio_device_id_t *) ((u8 *) msg->data + offset); ++ /* offset += dev_cnt * 
sizeof(aaudio_device_id_t); */ ++ return 0; ++} ++ ++#define WRITE_START_OF_TYPE(typev, devid) \ ++ size_t offset = sizeof(struct aaudio_msg_header); (void) offset; \ ++ ((struct aaudio_msg_header *) msg->data)->type = (typev); \ ++ ((struct aaudio_msg_header *) msg->data)->device_id = (devid); ++#define WRITE_START_COMMAND(devid) WRITE_START_OF_TYPE(AAUDIO_MSG_TYPE_COMMAND, devid) ++#define WRITE_START_RESPONSE() WRITE_START_OF_TYPE(AAUDIO_MSG_TYPE_RESPONSE, 0) ++#define WRITE_START_NOTIFICATION() WRITE_START_OF_TYPE(AAUDIO_MSG_TYPE_NOTIFICATION, 0) ++#define WRITE_VAL(type, value) { *((type *) ((u8 *) msg->data + offset)) = value; offset += sizeof(value); } ++#define WRITE_BIN(value, size) { memcpy((u8 *) msg->data + offset, value, size); offset += size; } ++#define WRITE_BASE(type) WRITE_VAL(u32, type) WRITE_VAL(u32, 0) ++#define WRITE_END() { msg->size = offset; } ++ ++void aaudio_msg_write_start_io(struct aaudio_msg *msg, aaudio_device_id_t dev) ++{ ++ WRITE_START_COMMAND(dev); ++ WRITE_BASE(AAUDIO_MSG_START_IO); ++ WRITE_END(); ++} ++ ++void aaudio_msg_write_stop_io(struct aaudio_msg *msg, aaudio_device_id_t dev) ++{ ++ WRITE_START_COMMAND(dev); ++ WRITE_BASE(AAUDIO_MSG_STOP_IO); ++ WRITE_END(); ++} ++ ++void aaudio_msg_write_get_property(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size) ++{ ++ WRITE_START_COMMAND(dev); ++ WRITE_BASE(AAUDIO_MSG_GET_PROPERTY); ++ WRITE_VAL(aaudio_object_id_t, obj); ++ WRITE_VAL(u32, prop.element); ++ WRITE_VAL(u32, prop.scope); ++ WRITE_VAL(u32, prop.selector); ++ WRITE_VAL(u64, qualifier_size); ++ WRITE_BIN(qualifier, qualifier_size); ++ WRITE_END(); ++} ++ ++void aaudio_msg_write_set_property(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop, void *data, u64 data_size, void *qualifier, u64 qualifier_size) ++{ ++ WRITE_START_COMMAND(dev); ++ WRITE_BASE(AAUDIO_MSG_SET_PROPERTY); ++ 
WRITE_VAL(aaudio_object_id_t, obj); ++ WRITE_VAL(u32, prop.element); ++ WRITE_VAL(u32, prop.scope); ++ WRITE_VAL(u32, prop.selector); ++ WRITE_VAL(u64, data_size); ++ WRITE_BIN(data, data_size); ++ WRITE_VAL(u64, qualifier_size); ++ WRITE_BIN(qualifier, qualifier_size); ++ WRITE_END(); ++} ++ ++void aaudio_msg_write_property_listener(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop) ++{ ++ WRITE_START_COMMAND(dev); ++ WRITE_BASE(AAUDIO_MSG_PROPERTY_LISTENER); ++ WRITE_VAL(aaudio_object_id_t, obj); ++ WRITE_VAL(u32, prop.element); ++ WRITE_VAL(u32, prop.scope); ++ WRITE_VAL(u32, prop.selector); ++ WRITE_END(); ++} ++ ++void aaudio_msg_write_set_input_stream_address_ranges(struct aaudio_msg *msg, aaudio_device_id_t devid) ++{ ++ WRITE_START_COMMAND(devid); ++ WRITE_BASE(AAUDIO_MSG_SET_INPUT_STREAM_ADDRESS_RANGES); ++ WRITE_END(); ++} ++ ++void aaudio_msg_write_get_input_stream_list(struct aaudio_msg *msg, aaudio_device_id_t devid) ++{ ++ WRITE_START_COMMAND(devid); ++ WRITE_BASE(AAUDIO_MSG_GET_INPUT_STREAM_LIST); ++ WRITE_END(); ++} ++ ++void aaudio_msg_write_get_output_stream_list(struct aaudio_msg *msg, aaudio_device_id_t devid) ++{ ++ WRITE_START_COMMAND(devid); ++ WRITE_BASE(AAUDIO_MSG_GET_OUTPUT_STREAM_LIST); ++ WRITE_END(); ++} ++ ++void aaudio_msg_write_set_remote_access(struct aaudio_msg *msg, u64 mode) ++{ ++ WRITE_START_COMMAND(0); ++ WRITE_BASE(AAUDIO_MSG_SET_REMOTE_ACCESS); ++ WRITE_VAL(u64, mode); ++ WRITE_END(); ++} ++ ++void aaudio_msg_write_alive_notification(struct aaudio_msg *msg, u32 proto_ver, u32 msg_ver) ++{ ++ WRITE_START_NOTIFICATION(); ++ WRITE_BASE(AAUDIO_MSG_NOTIFICATION_ALIVE); ++ WRITE_VAL(u32, proto_ver); ++ WRITE_VAL(u32, msg_ver); ++ WRITE_END(); ++} ++ ++void aaudio_msg_write_update_timestamp_response(struct aaudio_msg *msg) ++{ ++ WRITE_START_RESPONSE(); ++ WRITE_BASE(AAUDIO_MSG_UPDATE_TIMESTAMP_RESPONSE); ++ WRITE_END(); ++} ++ ++void aaudio_msg_write_get_device_list(struct 
aaudio_msg *msg) ++{ ++ WRITE_START_COMMAND(0); ++ WRITE_BASE(AAUDIO_MSG_GET_DEVICE_LIST); ++ WRITE_END(); ++} ++ ++#define CMD_SHARED_VARS_NO_REPLY \ ++ int status = 0; \ ++ struct aaudio_send_ctx sctx; ++#define CMD_SHARED_VARS \ ++ CMD_SHARED_VARS_NO_REPLY \ ++ struct aaudio_msg reply = aaudio_reply_alloc(); \ ++ struct aaudio_msg *buf = &reply; ++#define CMD_SEND_REQUEST(fn, ...) \ ++ if ((status = aaudio_send_cmd_sync(a, &sctx, buf, 500, fn, ##__VA_ARGS__))) \ ++ return status; ++#define CMD_DEF_SHARED_AND_SEND(fn, ...) \ ++ CMD_SHARED_VARS \ ++ CMD_SEND_REQUEST(fn, ##__VA_ARGS__); ++#define CMD_DEF_SHARED_NO_REPLY_AND_SEND(fn, ...) \ ++ CMD_SHARED_VARS_NO_REPLY \ ++ CMD_SEND_REQUEST(fn, ##__VA_ARGS__); ++#define CMD_HNDL_REPLY_NO_FREE(fn, ...) \ ++ status = fn(buf, ##__VA_ARGS__); \ ++ return status; ++#define CMD_HNDL_REPLY_AND_FREE(fn, ...) \ ++ status = fn(buf, ##__VA_ARGS__); \ ++ aaudio_reply_free(&reply); \ ++ return status; ++ ++int aaudio_cmd_start_io(struct aaudio_device *a, aaudio_device_id_t devid) ++{ ++ CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_start_io, devid); ++ CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_start_io_response); ++} ++int aaudio_cmd_stop_io(struct aaudio_device *a, aaudio_device_id_t devid) ++{ ++ CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_stop_io, devid); ++ CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_stop_io_response); ++} ++int aaudio_cmd_get_property(struct aaudio_device *a, struct aaudio_msg *buf, ++ aaudio_device_id_t devid, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void **data, u64 *data_size) ++{ ++ CMD_DEF_SHARED_NO_REPLY_AND_SEND(aaudio_msg_write_get_property, devid, obj, prop, qualifier, qualifier_size); ++ CMD_HNDL_REPLY_NO_FREE(aaudio_msg_read_get_property_response, &obj, &prop, data, data_size); ++} ++int aaudio_cmd_get_primitive_property(struct aaudio_device *a, ++ aaudio_device_id_t devid, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop, void *qualifier, u64 
qualifier_size, void *data, u64 data_size) ++{ ++ int status; ++ struct aaudio_msg reply = aaudio_reply_alloc(); ++ void *r_data; ++ u64 r_data_size; ++ if ((status = aaudio_cmd_get_property(a, &reply, devid, obj, prop, qualifier, qualifier_size, ++ &r_data, &r_data_size))) ++ goto finish; ++ if (r_data_size != data_size) { ++ status = -EINVAL; ++ goto finish; ++ } ++ memcpy(data, r_data, data_size); ++finish: ++ aaudio_reply_free(&reply); ++ return status; ++} ++int aaudio_cmd_set_property(struct aaudio_device *a, aaudio_device_id_t devid, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void *data, u64 data_size) ++{ ++ CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_set_property, devid, obj, prop, data, data_size, ++ qualifier, qualifier_size); ++ CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_set_property_response, &obj); ++} ++int aaudio_cmd_property_listener(struct aaudio_device *a, aaudio_device_id_t devid, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop) ++{ ++ CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_property_listener, devid, obj, prop); ++ CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_property_listener_response, &obj, &prop); ++} ++int aaudio_cmd_set_input_stream_address_ranges(struct aaudio_device *a, aaudio_device_id_t devid) ++{ ++ CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_set_input_stream_address_ranges, devid); ++ CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_set_input_stream_address_ranges_response); ++} ++int aaudio_cmd_get_input_stream_list(struct aaudio_device *a, struct aaudio_msg *buf, aaudio_device_id_t devid, ++ aaudio_object_id_t **str_l, u64 *str_cnt) ++{ ++ CMD_DEF_SHARED_NO_REPLY_AND_SEND(aaudio_msg_write_get_input_stream_list, devid); ++ CMD_HNDL_REPLY_NO_FREE(aaudio_msg_read_get_input_stream_list_response, str_l, str_cnt); ++} ++int aaudio_cmd_get_output_stream_list(struct aaudio_device *a, struct aaudio_msg *buf, aaudio_device_id_t devid, ++ aaudio_object_id_t **str_l, u64 *str_cnt) ++{ ++ 
CMD_DEF_SHARED_NO_REPLY_AND_SEND(aaudio_msg_write_get_output_stream_list, devid); ++ CMD_HNDL_REPLY_NO_FREE(aaudio_msg_read_get_output_stream_list_response, str_l, str_cnt); ++} ++int aaudio_cmd_set_remote_access(struct aaudio_device *a, u64 mode) ++{ ++ CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_set_remote_access, mode); ++ CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_set_remote_access_response); ++} ++int aaudio_cmd_get_device_list(struct aaudio_device *a, struct aaudio_msg *buf, ++ aaudio_device_id_t **dev_l, u64 *dev_cnt) ++{ ++ CMD_DEF_SHARED_NO_REPLY_AND_SEND(aaudio_msg_write_get_device_list); ++ CMD_HNDL_REPLY_NO_FREE(aaudio_msg_read_get_device_list_response, dev_l, dev_cnt); ++} +\ No newline at end of file +diff --git a/drivers/staging/apple-bce/audio/protocol.h b/drivers/staging/apple-bce/audio/protocol.h +new file mode 100644 +index 000000000000..3427486f3f57 +--- /dev/null ++++ b/drivers/staging/apple-bce/audio/protocol.h +@@ -0,0 +1,147 @@ ++#ifndef AAUDIO_PROTOCOL_H ++#define AAUDIO_PROTOCOL_H ++ ++#include ++ ++struct aaudio_device; ++ ++typedef u64 aaudio_device_id_t; ++typedef u64 aaudio_object_id_t; ++ ++struct aaudio_msg { ++ void *data; ++ size_t size; ++}; ++ ++struct __attribute__((packed)) aaudio_msg_header { ++ char tag[4]; ++ u8 type; ++ aaudio_device_id_t device_id; // Idk, use zero for commands? 
++}; ++struct __attribute__((packed)) aaudio_msg_base { ++ u32 msg; ++ u32 status; ++}; ++ ++struct aaudio_prop_addr { ++ u32 scope; ++ u32 selector; ++ u32 element; ++}; ++#define AAUDIO_PROP(scope, sel, el) (struct aaudio_prop_addr) { scope, sel, el } ++ ++enum { ++ AAUDIO_MSG_TYPE_COMMAND = 1, ++ AAUDIO_MSG_TYPE_RESPONSE = 2, ++ AAUDIO_MSG_TYPE_NOTIFICATION = 3 ++}; ++ ++enum { ++ AAUDIO_MSG_START_IO = 0, ++ AAUDIO_MSG_START_IO_RESPONSE = 1, ++ AAUDIO_MSG_STOP_IO = 2, ++ AAUDIO_MSG_STOP_IO_RESPONSE = 3, ++ AAUDIO_MSG_UPDATE_TIMESTAMP = 4, ++ AAUDIO_MSG_GET_PROPERTY = 7, ++ AAUDIO_MSG_GET_PROPERTY_RESPONSE = 8, ++ AAUDIO_MSG_SET_PROPERTY = 9, ++ AAUDIO_MSG_SET_PROPERTY_RESPONSE = 10, ++ AAUDIO_MSG_PROPERTY_LISTENER = 11, ++ AAUDIO_MSG_PROPERTY_LISTENER_RESPONSE = 12, ++ AAUDIO_MSG_PROPERTY_CHANGED = 13, ++ AAUDIO_MSG_SET_INPUT_STREAM_ADDRESS_RANGES = 18, ++ AAUDIO_MSG_SET_INPUT_STREAM_ADDRESS_RANGES_RESPONSE = 19, ++ AAUDIO_MSG_GET_INPUT_STREAM_LIST = 24, ++ AAUDIO_MSG_GET_INPUT_STREAM_LIST_RESPONSE = 25, ++ AAUDIO_MSG_GET_OUTPUT_STREAM_LIST = 26, ++ AAUDIO_MSG_GET_OUTPUT_STREAM_LIST_RESPONSE = 27, ++ AAUDIO_MSG_SET_REMOTE_ACCESS = 32, ++ AAUDIO_MSG_SET_REMOTE_ACCESS_RESPONSE = 33, ++ AAUDIO_MSG_UPDATE_TIMESTAMP_RESPONSE = 34, ++ ++ AAUDIO_MSG_NOTIFICATION_ALIVE = 100, ++ AAUDIO_MSG_GET_DEVICE_LIST = 101, ++ AAUDIO_MSG_GET_DEVICE_LIST_RESPONSE = 102, ++ AAUDIO_MSG_NOTIFICATION_BOOT = 104 ++}; ++ ++enum { ++ AAUDIO_REMOTE_ACCESS_OFF = 0, ++ AAUDIO_REMOTE_ACCESS_ON = 2 ++}; ++ ++enum { ++ AAUDIO_PROP_SCOPE_GLOBAL = 0x676c6f62, // 'glob' ++ AAUDIO_PROP_SCOPE_INPUT = 0x696e7074, // 'inpt' ++ AAUDIO_PROP_SCOPE_OUTPUT = 0x6f757470 // 'outp' ++}; ++ ++enum { ++ AAUDIO_PROP_UID = 0x75696420, // 'uid ' ++ AAUDIO_PROP_BOOL_VALUE = 0x6263766c, // 'bcvl' ++ AAUDIO_PROP_JACK_PLUGGED = 0x6a61636b, // 'jack' ++ AAUDIO_PROP_SEL_VOLUME = 0x64656176, // 'deav' ++ AAUDIO_PROP_LATENCY = 0x6c746e63, // 'ltnc' ++ AAUDIO_PROP_PHYS_FORMAT = 0x70667420 // 'pft ' ++}; ++ ++int 
aaudio_msg_read_base(struct aaudio_msg *msg, struct aaudio_msg_base *base); ++ ++int aaudio_msg_read_start_io_response(struct aaudio_msg *msg); ++int aaudio_msg_read_stop_io_response(struct aaudio_msg *msg); ++int aaudio_msg_read_update_timestamp(struct aaudio_msg *msg, aaudio_device_id_t *devid, ++ u64 *timestamp, u64 *update_seed); ++int aaudio_msg_read_get_property_response(struct aaudio_msg *msg, aaudio_object_id_t *obj, ++ struct aaudio_prop_addr *prop, void **data, u64 *data_size); ++int aaudio_msg_read_set_property_response(struct aaudio_msg *msg, aaudio_object_id_t *obj); ++int aaudio_msg_read_property_listener_response(struct aaudio_msg *msg,aaudio_object_id_t *obj, ++ struct aaudio_prop_addr *prop); ++int aaudio_msg_read_property_changed(struct aaudio_msg *msg, aaudio_device_id_t *devid, aaudio_object_id_t *obj, ++ struct aaudio_prop_addr *prop); ++int aaudio_msg_read_set_input_stream_address_ranges_response(struct aaudio_msg *msg); ++int aaudio_msg_read_get_input_stream_list_response(struct aaudio_msg *msg, aaudio_object_id_t **str_l, u64 *str_cnt); ++int aaudio_msg_read_get_output_stream_list_response(struct aaudio_msg *msg, aaudio_object_id_t **str_l, u64 *str_cnt); ++int aaudio_msg_read_set_remote_access_response(struct aaudio_msg *msg); ++int aaudio_msg_read_get_device_list_response(struct aaudio_msg *msg, aaudio_device_id_t **dev_l, u64 *dev_cnt); ++ ++void aaudio_msg_write_start_io(struct aaudio_msg *msg, aaudio_device_id_t dev); ++void aaudio_msg_write_stop_io(struct aaudio_msg *msg, aaudio_device_id_t dev); ++void aaudio_msg_write_get_property(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size); ++void aaudio_msg_write_set_property(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop, void *data, u64 data_size, void *qualifier, u64 qualifier_size); ++void aaudio_msg_write_property_listener(struct aaudio_msg 
*msg, aaudio_device_id_t dev, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop); ++void aaudio_msg_write_set_input_stream_address_ranges(struct aaudio_msg *msg, aaudio_device_id_t devid); ++void aaudio_msg_write_get_input_stream_list(struct aaudio_msg *msg, aaudio_device_id_t devid); ++void aaudio_msg_write_get_output_stream_list(struct aaudio_msg *msg, aaudio_device_id_t devid); ++void aaudio_msg_write_set_remote_access(struct aaudio_msg *msg, u64 mode); ++void aaudio_msg_write_alive_notification(struct aaudio_msg *msg, u32 proto_ver, u32 msg_ver); ++void aaudio_msg_write_update_timestamp_response(struct aaudio_msg *msg); ++void aaudio_msg_write_get_device_list(struct aaudio_msg *msg); ++ ++ ++int aaudio_cmd_start_io(struct aaudio_device *a, aaudio_device_id_t devid); ++int aaudio_cmd_stop_io(struct aaudio_device *a, aaudio_device_id_t devid); ++int aaudio_cmd_get_property(struct aaudio_device *a, struct aaudio_msg *buf, ++ aaudio_device_id_t devid, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void **data, u64 *data_size); ++int aaudio_cmd_get_primitive_property(struct aaudio_device *a, ++ aaudio_device_id_t devid, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void *data, u64 data_size); ++int aaudio_cmd_set_property(struct aaudio_device *a, aaudio_device_id_t devid, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void *data, u64 data_size); ++int aaudio_cmd_property_listener(struct aaudio_device *a, aaudio_device_id_t devid, aaudio_object_id_t obj, ++ struct aaudio_prop_addr prop); ++int aaudio_cmd_set_input_stream_address_ranges(struct aaudio_device *a, aaudio_device_id_t devid); ++int aaudio_cmd_get_input_stream_list(struct aaudio_device *a, struct aaudio_msg *buf, aaudio_device_id_t devid, ++ aaudio_object_id_t **str_l, u64 *str_cnt); ++int aaudio_cmd_get_output_stream_list(struct aaudio_device *a, struct 
aaudio_msg *buf, aaudio_device_id_t devid, ++ aaudio_object_id_t **str_l, u64 *str_cnt); ++int aaudio_cmd_set_remote_access(struct aaudio_device *a, u64 mode); ++int aaudio_cmd_get_device_list(struct aaudio_device *a, struct aaudio_msg *buf, ++ aaudio_device_id_t **dev_l, u64 *dev_cnt); ++ ++ ++ ++#endif //AAUDIO_PROTOCOL_H +diff --git a/drivers/staging/apple-bce/audio/protocol_bce.c b/drivers/staging/apple-bce/audio/protocol_bce.c +new file mode 100644 +index 000000000000..28f2dfd44d67 +--- /dev/null ++++ b/drivers/staging/apple-bce/audio/protocol_bce.c +@@ -0,0 +1,226 @@ ++#include "protocol_bce.h" ++ ++#include "audio.h" ++ ++static void aaudio_bce_out_queue_completion(struct bce_queue_sq *sq); ++static void aaudio_bce_in_queue_completion(struct bce_queue_sq *sq); ++static int aaudio_bce_queue_init(struct aaudio_device *dev, struct aaudio_bce_queue *q, const char *name, int direction, ++ bce_sq_completion cfn); ++void aaudio_bce_in_queue_submit_pending(struct aaudio_bce_queue *q, size_t count); ++ ++int aaudio_bce_init(struct aaudio_device *dev) ++{ ++ int status; ++ struct aaudio_bce *bce = &dev->bcem; ++ bce->cq = bce_create_cq(dev->bce, 0x80); ++ spin_lock_init(&bce->spinlock); ++ if (!bce->cq) ++ return -EINVAL; ++ if ((status = aaudio_bce_queue_init(dev, &bce->qout, "com.apple.BridgeAudio.IntelToARM", DMA_TO_DEVICE, ++ aaudio_bce_out_queue_completion))) { ++ return status; ++ } ++ if ((status = aaudio_bce_queue_init(dev, &bce->qin, "com.apple.BridgeAudio.ARMToIntel", DMA_FROM_DEVICE, ++ aaudio_bce_in_queue_completion))) { ++ return status; ++ } ++ aaudio_bce_in_queue_submit_pending(&bce->qin, bce->qin.el_count); ++ return 0; ++} ++ ++int aaudio_bce_queue_init(struct aaudio_device *dev, struct aaudio_bce_queue *q, const char *name, int direction, ++ bce_sq_completion cfn) ++{ ++ q->cq = dev->bcem.cq; ++ q->el_size = AAUDIO_BCE_QUEUE_ELEMENT_SIZE; ++ q->el_count = AAUDIO_BCE_QUEUE_ELEMENT_COUNT; ++ /* NOTE: The Apple impl uses 0x80 as the queue size, however 
we use 21 (in fact 20) to simplify the impl */ ++ q->sq = bce_create_sq(dev->bce, q->cq, name, (u32) (q->el_count + 1), direction, cfn, dev); ++ if (!q->sq) ++ return -EINVAL; ++ ++ q->data = dma_alloc_coherent(&dev->bce->pci->dev, q->el_size * q->el_count, &q->dma_addr, GFP_KERNEL); ++ if (!q->data) { ++ bce_destroy_sq(dev->bce, q->sq); ++ return -EINVAL; ++ } ++ return 0; ++} ++ ++static void aaudio_send_create_tag(struct aaudio_bce *b, int *tagn, char tag[4]) ++{ ++ char tag_zero[5]; ++ b->tag_num = (b->tag_num + 1) % AAUDIO_BCE_QUEUE_TAG_COUNT; ++ *tagn = b->tag_num; ++ snprintf(tag_zero, 5, "S%03d", b->tag_num); ++ *((u32 *) tag) = *((u32 *) tag_zero); ++} ++ ++int __aaudio_send_prepare(struct aaudio_bce *b, struct aaudio_send_ctx *ctx, char *tag) ++{ ++ int status; ++ size_t index; ++ void *dptr; ++ struct aaudio_msg_header *header; ++ if ((status = bce_reserve_submission(b->qout.sq, &ctx->timeout))) ++ return status; ++ spin_lock_irqsave(&b->spinlock, ctx->irq_flags); ++ index = b->qout.data_tail; ++ dptr = (u8 *) b->qout.data + index * b->qout.el_size; ++ ctx->msg.data = dptr; ++ header = dptr; ++ if (tag) ++ *((u32 *) header->tag) = *((u32 *) tag); ++ else ++ aaudio_send_create_tag(b, &ctx->tag_n, header->tag); ++ return 0; ++} ++ ++void __aaudio_send(struct aaudio_bce *b, struct aaudio_send_ctx *ctx) ++{ ++ struct bce_qe_submission *s = bce_next_submission(b->qout.sq); ++#ifdef DEBUG ++ pr_debug("aaudio: Sending command data\n"); ++ print_hex_dump(KERN_DEBUG, "aaudio:OUT ", DUMP_PREFIX_NONE, 32, 1, ctx->msg.data, ctx->msg.size, true); ++#endif ++ bce_set_submission_single(s, b->qout.dma_addr + (dma_addr_t) (ctx->msg.data - b->qout.data), ctx->msg.size); ++ bce_submit_to_device(b->qout.sq); ++ b->qout.data_tail = (b->qout.data_tail + 1) % b->qout.el_count; ++ spin_unlock_irqrestore(&b->spinlock, ctx->irq_flags); ++} ++ ++int __aaudio_send_cmd_sync(struct aaudio_bce *b, struct aaudio_send_ctx *ctx, struct aaudio_msg *reply) ++{ ++ struct 
aaudio_bce_queue_entry ent; ++ DECLARE_COMPLETION_ONSTACK(cmpl); ++ ent.msg = reply; ++ ent.cmpl = &cmpl; ++ b->pending_entries[ctx->tag_n] = &ent; ++ __aaudio_send(b, ctx); /* unlocks the spinlock */ ++ ctx->timeout = wait_for_completion_timeout(&cmpl, ctx->timeout); ++ if (ctx->timeout == 0) { ++ /* Remove the pending queue entry; this will be normally handled by the completion route but ++ * during a timeout it won't */ ++ spin_lock_irqsave(&b->spinlock, ctx->irq_flags); ++ if (b->pending_entries[ctx->tag_n] == &ent) ++ b->pending_entries[ctx->tag_n] = NULL; ++ spin_unlock_irqrestore(&b->spinlock, ctx->irq_flags); ++ return -ETIMEDOUT; ++ } ++ return 0; ++} ++ ++static void aaudio_handle_reply(struct aaudio_bce *b, struct aaudio_msg *reply) ++{ ++ const char *tag; ++ int tagn; ++ unsigned long irq_flags; ++ char tag_zero[5]; ++ struct aaudio_bce_queue_entry *entry; ++ ++ tag = ((struct aaudio_msg_header *) reply->data)->tag; ++ if (tag[0] != 'S') { ++ pr_err("aaudio_handle_reply: Unexpected tag: %.4s\n", tag); ++ return; ++ } ++ *((u32 *) tag_zero) = *((u32 *) tag); ++ tag_zero[4] = 0; ++ if (kstrtoint(&tag_zero[1], 10, &tagn)) { ++ pr_err("aaudio_handle_reply: Tag parse failed: %.4s\n", tag); ++ return; ++ } ++ ++ spin_lock_irqsave(&b->spinlock, irq_flags); ++ entry = b->pending_entries[tagn]; ++ if (entry) { ++ if (reply->size < entry->msg->size) ++ entry->msg->size = reply->size; ++ memcpy(entry->msg->data, reply->data, entry->msg->size); ++ complete(entry->cmpl); ++ ++ b->pending_entries[tagn] = NULL; ++ } else { ++ pr_err("aaudio_handle_reply: No queued item found for tag: %.4s\n", tag); ++ } ++ spin_unlock_irqrestore(&b->spinlock, irq_flags); ++} ++ ++static void aaudio_bce_out_queue_completion(struct bce_queue_sq *sq) ++{ ++ while (bce_next_completion(sq)) { ++ //pr_info("aaudio: Send confirmed\n"); ++ bce_notify_submission_complete(sq); ++ } ++} ++ ++static void aaudio_bce_in_queue_handle_msg(struct aaudio_device *a, struct aaudio_msg *msg); ++ ++static 
void aaudio_bce_in_queue_completion(struct bce_queue_sq *sq) ++{ ++ struct aaudio_msg msg; ++ struct aaudio_device *dev = sq->userdata; ++ struct aaudio_bce_queue *q = &dev->bcem.qin; ++ struct bce_sq_completion_data *c; ++ size_t cnt = 0; ++ ++ mb(); ++ while ((c = bce_next_completion(sq))) { ++ msg.data = (u8 *) q->data + q->data_head * q->el_size; ++ msg.size = c->data_size; ++#ifdef DEBUG ++ pr_debug("aaudio: Received command data %llx\n", c->data_size); ++ print_hex_dump(KERN_DEBUG, "aaudio:IN ", DUMP_PREFIX_NONE, 32, 1, msg.data, min(msg.size, 128UL), true); ++#endif ++ aaudio_bce_in_queue_handle_msg(dev, &msg); ++ ++ q->data_head = (q->data_head + 1) % q->el_count; ++ ++ bce_notify_submission_complete(sq); ++ ++cnt; ++ } ++ aaudio_bce_in_queue_submit_pending(q, cnt); ++} ++ ++static void aaudio_bce_in_queue_handle_msg(struct aaudio_device *a, struct aaudio_msg *msg) ++{ ++ struct aaudio_msg_header *header = (struct aaudio_msg_header *) msg->data; ++ if (msg->size < sizeof(struct aaudio_msg_header)) { ++ pr_err("aaudio: Msg size smaller than header (%lx)", msg->size); ++ return; ++ } ++ if (header->type == AAUDIO_MSG_TYPE_RESPONSE) { ++ aaudio_handle_reply(&a->bcem, msg); ++ } else if (header->type == AAUDIO_MSG_TYPE_COMMAND) { ++ aaudio_handle_command(a, msg); ++ } else if (header->type == AAUDIO_MSG_TYPE_NOTIFICATION) { ++ aaudio_handle_notification(a, msg); ++ } ++} ++ ++void aaudio_bce_in_queue_submit_pending(struct aaudio_bce_queue *q, size_t count) ++{ ++ struct bce_qe_submission *s; ++ while (count--) { ++ if (bce_reserve_submission(q->sq, NULL)) { ++ pr_err("aaudio: Failed to reserve an event queue submission\n"); ++ break; ++ } ++ s = bce_next_submission(q->sq); ++ bce_set_submission_single(s, q->dma_addr + (dma_addr_t) (q->data_tail * q->el_size), q->el_size); ++ q->data_tail = (q->data_tail + 1) % q->el_count; ++ } ++ bce_submit_to_device(q->sq); ++} ++ ++struct aaudio_msg aaudio_reply_alloc(void) ++{ ++ struct aaudio_msg ret; ++ ret.size = 
AAUDIO_BCE_QUEUE_ELEMENT_SIZE; ++ ret.data = kmalloc(ret.size, GFP_KERNEL); ++ return ret; ++} ++ ++void aaudio_reply_free(struct aaudio_msg *reply) ++{ ++ kfree(reply->data); ++} +diff --git a/drivers/staging/apple-bce/audio/protocol_bce.h b/drivers/staging/apple-bce/audio/protocol_bce.h +new file mode 100644 +index 000000000000..14d26c05ddf9 +--- /dev/null ++++ b/drivers/staging/apple-bce/audio/protocol_bce.h +@@ -0,0 +1,72 @@ ++#ifndef AAUDIO_PROTOCOL_BCE_H ++#define AAUDIO_PROTOCOL_BCE_H ++ ++#include "protocol.h" ++#include "../queue.h" ++ ++#define AAUDIO_BCE_QUEUE_ELEMENT_SIZE 0x1000 ++#define AAUDIO_BCE_QUEUE_ELEMENT_COUNT 20 ++ ++#define AAUDIO_BCE_QUEUE_TAG_COUNT 1000 ++ ++struct aaudio_device; ++ ++struct aaudio_bce_queue_entry { ++ struct aaudio_msg *msg; ++ struct completion *cmpl; ++}; ++struct aaudio_bce_queue { ++ struct bce_queue_cq *cq; ++ struct bce_queue_sq *sq; ++ void *data; ++ dma_addr_t dma_addr; ++ size_t data_head, data_tail; ++ size_t el_size, el_count; ++}; ++struct aaudio_bce { ++ struct bce_queue_cq *cq; ++ struct aaudio_bce_queue qin; ++ struct aaudio_bce_queue qout; ++ int tag_num; ++ struct aaudio_bce_queue_entry *pending_entries[AAUDIO_BCE_QUEUE_TAG_COUNT]; ++ struct spinlock spinlock; ++}; ++ ++struct aaudio_send_ctx { ++ int status; ++ int tag_n; ++ unsigned long irq_flags; ++ struct aaudio_msg msg; ++ unsigned long timeout; ++}; ++ ++int aaudio_bce_init(struct aaudio_device *dev); ++int __aaudio_send_prepare(struct aaudio_bce *b, struct aaudio_send_ctx *ctx, char *tag); ++void __aaudio_send(struct aaudio_bce *b, struct aaudio_send_ctx *ctx); ++int __aaudio_send_cmd_sync(struct aaudio_bce *b, struct aaudio_send_ctx *ctx, struct aaudio_msg *reply); ++ ++#define aaudio_send_with_tag(a, ctx, tag, tout, fn, ...) 
({ \ ++ (ctx)->timeout = msecs_to_jiffies(tout); \ ++ (ctx)->status = __aaudio_send_prepare(&(a)->bcem, (ctx), (tag)); \ ++ if (!(ctx)->status) { \ ++ fn(&(ctx)->msg, ##__VA_ARGS__); \ ++ __aaudio_send(&(a)->bcem, (ctx)); \ ++ } \ ++ (ctx)->status; \ ++}) ++#define aaudio_send(a, ctx, tout, fn, ...) aaudio_send_with_tag(a, ctx, NULL, tout, fn, ##__VA_ARGS__) ++ ++#define aaudio_send_cmd_sync(a, ctx, reply, tout, fn, ...) ({ \ ++ (ctx)->timeout = msecs_to_jiffies(tout); \ ++ (ctx)->status = __aaudio_send_prepare(&(a)->bcem, (ctx), NULL); \ ++ if (!(ctx)->status) { \ ++ fn(&(ctx)->msg, ##__VA_ARGS__); \ ++ (ctx)->status = __aaudio_send_cmd_sync(&(a)->bcem, (ctx), (reply)); \ ++ } \ ++ (ctx)->status; \ ++}) ++ ++struct aaudio_msg aaudio_reply_alloc(void); ++void aaudio_reply_free(struct aaudio_msg *reply); ++ ++#endif //AAUDIO_PROTOCOL_BCE_H +diff --git a/drivers/staging/apple-bce/mailbox.c b/drivers/staging/apple-bce/mailbox.c +new file mode 100644 +index 000000000000..e24bd35215c0 +--- /dev/null ++++ b/drivers/staging/apple-bce/mailbox.c +@@ -0,0 +1,151 @@ ++#include "mailbox.h" ++#include ++#include "apple_bce.h" ++ ++#define REG_MBOX_OUT_BASE 0x820 ++#define REG_MBOX_REPLY_COUNTER 0x108 ++#define REG_MBOX_REPLY_BASE 0x810 ++#define REG_TIMESTAMP_BASE 0xC000 ++ ++#define BCE_MBOX_TIMEOUT_MS 200 ++ ++void bce_mailbox_init(struct bce_mailbox *mb, void __iomem *reg_mb) ++{ ++ mb->reg_mb = reg_mb; ++ init_completion(&mb->mb_completion); ++} ++ ++int bce_mailbox_send(struct bce_mailbox *mb, u64 msg, u64* recv) ++{ ++ u32 __iomem *regb; ++ ++ if (atomic_cmpxchg(&mb->mb_status, 0, 1) != 0) { ++ return -EEXIST; // We don't support two messages at once ++ } ++ reinit_completion(&mb->mb_completion); ++ ++ pr_debug("bce_mailbox_send: %llx\n", msg); ++ regb = (u32*) ((u8*) mb->reg_mb + REG_MBOX_OUT_BASE); ++ iowrite32((u32) msg, regb); ++ iowrite32((u32) (msg >> 32), regb + 1); ++ iowrite32(0, regb + 2); ++ iowrite32(0, regb + 3); ++ ++ 
wait_for_completion_timeout(&mb->mb_completion, msecs_to_jiffies(BCE_MBOX_TIMEOUT_MS)); ++ if (atomic_read(&mb->mb_status) != 2) { // Didn't get the reply ++ atomic_set(&mb->mb_status, 0); ++ return -ETIMEDOUT; ++ } ++ ++ *recv = mb->mb_result; ++ pr_debug("bce_mailbox_send: reply %llx\n", *recv); ++ ++ atomic_set(&mb->mb_status, 0); ++ return 0; ++} ++ ++static int bce_mailbox_retrive_response(struct bce_mailbox *mb) ++{ ++ u32 __iomem *regb; ++ u32 lo, hi; ++ int count, counter; ++ u32 res = ioread32((u8*) mb->reg_mb + REG_MBOX_REPLY_COUNTER); ++ count = (res >> 20) & 0xf; ++ counter = count; ++ pr_debug("bce_mailbox_retrive_response count=%i\n", count); ++ while (counter--) { ++ regb = (u32*) ((u8*) mb->reg_mb + REG_MBOX_REPLY_BASE); ++ lo = ioread32(regb); ++ hi = ioread32(regb + 1); ++ ioread32(regb + 2); ++ ioread32(regb + 3); ++ pr_debug("bce_mailbox_retrive_response %llx\n", ((u64) hi << 32) | lo); ++ mb->mb_result = ((u64) hi << 32) | lo; ++ } ++ return count > 0 ? 0 : -ENODATA; ++} ++ ++int bce_mailbox_handle_interrupt(struct bce_mailbox *mb) ++{ ++ int status = bce_mailbox_retrive_response(mb); ++ if (!status) { ++ atomic_set(&mb->mb_status, 2); ++ complete(&mb->mb_completion); ++ } ++ return status; ++} ++ ++static void bc_send_timestamp(struct timer_list *tl); ++ ++void bce_timestamp_init(struct bce_timestamp *ts, void __iomem *reg) ++{ ++ u32 __iomem *regb; ++ ++ spin_lock_init(&ts->stop_sl); ++ ts->stopped = false; ++ ++ ts->reg = reg; ++ ++ regb = (u32*) ((u8*) ts->reg + REG_TIMESTAMP_BASE); ++ ++ ioread32(regb); ++ mb(); ++ ++ timer_setup(&ts->timer, bc_send_timestamp, 0); ++} ++ ++void bce_timestamp_start(struct bce_timestamp *ts, bool is_initial) ++{ ++ unsigned long flags; ++ u32 __iomem *regb = (u32*) ((u8*) ts->reg + REG_TIMESTAMP_BASE); ++ ++ if (is_initial) { ++ iowrite32((u32) -4, regb + 2); ++ iowrite32((u32) -1, regb); ++ } else { ++ iowrite32((u32) -3, regb + 2); ++ iowrite32((u32) -1, regb); ++ } ++ ++ spin_lock_irqsave(&ts->stop_sl, 
flags); ++ ts->stopped = false; ++ spin_unlock_irqrestore(&ts->stop_sl, flags); ++ mod_timer(&ts->timer, jiffies + msecs_to_jiffies(150)); ++} ++ ++void bce_timestamp_stop(struct bce_timestamp *ts) ++{ ++ unsigned long flags; ++ u32 __iomem *regb = (u32*) ((u8*) ts->reg + REG_TIMESTAMP_BASE); ++ ++ spin_lock_irqsave(&ts->stop_sl, flags); ++ ts->stopped = true; ++ spin_unlock_irqrestore(&ts->stop_sl, flags); ++ del_timer_sync(&ts->timer); ++ ++ iowrite32((u32) -2, regb + 2); ++ iowrite32((u32) -1, regb); ++} ++ ++static void bc_send_timestamp(struct timer_list *tl) ++{ ++ struct bce_timestamp *ts; ++ unsigned long flags; ++ u32 __iomem *regb; ++ ktime_t bt; ++ ++ ts = container_of(tl, struct bce_timestamp, timer); ++ regb = (u32*) ((u8*) ts->reg + REG_TIMESTAMP_BASE); ++ local_irq_save(flags); ++ ioread32(regb + 2); ++ mb(); ++ bt = ktime_get_boottime(); ++ iowrite32((u32) bt, regb + 2); ++ iowrite32((u32) (bt >> 32), regb); ++ ++ spin_lock(&ts->stop_sl); ++ if (!ts->stopped) ++ mod_timer(&ts->timer, jiffies + msecs_to_jiffies(150)); ++ spin_unlock(&ts->stop_sl); ++ local_irq_restore(flags); ++} +\ No newline at end of file +diff --git a/drivers/staging/apple-bce/mailbox.h b/drivers/staging/apple-bce/mailbox.h +new file mode 100644 +index 000000000000..f3323f95ba51 +--- /dev/null ++++ b/drivers/staging/apple-bce/mailbox.h +@@ -0,0 +1,53 @@ ++#ifndef BCE_MAILBOX_H ++#define BCE_MAILBOX_H ++ ++#include ++#include ++#include ++ ++struct bce_mailbox { ++ void __iomem *reg_mb; ++ ++ atomic_t mb_status; // possible statuses: 0 (no msg), 1 (has active msg), 2 (got reply) ++ struct completion mb_completion; ++ uint64_t mb_result; ++}; ++ ++enum bce_message_type { ++ BCE_MB_REGISTER_COMMAND_SQ = 0x7, // to-device ++ BCE_MB_REGISTER_COMMAND_CQ = 0x8, // to-device ++ BCE_MB_REGISTER_COMMAND_QUEUE_REPLY = 0xB, // to-host ++ BCE_MB_SET_FW_PROTOCOL_VERSION = 0xC, // both ++ BCE_MB_SLEEP_NO_STATE = 0x14, // to-device ++ BCE_MB_RESTORE_NO_STATE = 0x15, // to-device ++ 
BCE_MB_SAVE_STATE_AND_SLEEP = 0x17, // to-device ++ BCE_MB_RESTORE_STATE_AND_WAKE = 0x18, // to-device ++ BCE_MB_SAVE_STATE_AND_SLEEP_FAILURE = 0x19, // from-device ++ BCE_MB_SAVE_RESTORE_STATE_COMPLETE = 0x1A, // from-device ++}; ++ ++#define BCE_MB_MSG(type, value) (((u64) (type) << 58) | ((value) & 0x3FFFFFFFFFFFFFFLL)) ++#define BCE_MB_TYPE(v) ((u32) (v >> 58)) ++#define BCE_MB_VALUE(v) (v & 0x3FFFFFFFFFFFFFFLL) ++ ++void bce_mailbox_init(struct bce_mailbox *mb, void __iomem *reg_mb); ++ ++int bce_mailbox_send(struct bce_mailbox *mb, u64 msg, u64* recv); ++ ++int bce_mailbox_handle_interrupt(struct bce_mailbox *mb); ++ ++ ++struct bce_timestamp { ++ void __iomem *reg; ++ struct timer_list timer; ++ struct spinlock stop_sl; ++ bool stopped; ++}; ++ ++void bce_timestamp_init(struct bce_timestamp *ts, void __iomem *reg); ++ ++void bce_timestamp_start(struct bce_timestamp *ts, bool is_initial); ++ ++void bce_timestamp_stop(struct bce_timestamp *ts); ++ ++#endif //BCEDRIVER_MAILBOX_H +diff --git a/drivers/staging/apple-bce/queue.c b/drivers/staging/apple-bce/queue.c +new file mode 100644 +index 000000000000..bc9cd3bc6f0c +--- /dev/null ++++ b/drivers/staging/apple-bce/queue.c +@@ -0,0 +1,390 @@ ++#include "queue.h" ++#include "apple_bce.h" ++ ++#define REG_DOORBELL_BASE 0x44000 ++ ++struct bce_queue_cq *bce_alloc_cq(struct apple_bce_device *dev, int qid, u32 el_count) ++{ ++ struct bce_queue_cq *q; ++ q = kzalloc(sizeof(struct bce_queue_cq), GFP_KERNEL); ++ q->qid = qid; ++ q->type = BCE_QUEUE_CQ; ++ q->el_count = el_count; ++ q->data = dma_alloc_coherent(&dev->pci->dev, el_count * sizeof(struct bce_qe_completion), ++ &q->dma_handle, GFP_KERNEL); ++ if (!q->data) { ++ pr_err("DMA queue memory alloc failed\n"); ++ kfree(q); ++ return NULL; ++ } ++ return q; ++} ++ ++void bce_get_cq_memcfg(struct bce_queue_cq *cq, struct bce_queue_memcfg *cfg) ++{ ++ cfg->qid = (u16) cq->qid; ++ cfg->el_count = (u16) cq->el_count; ++ cfg->vector_or_cq = 0; ++ cfg->_pad = 0; ++ 
/*
 * Process a single completion queue entry.
 *
 * Looks up the submission queue named by e->qid, copies the entry's status,
 * data_size and result into that queue's completion_data slot and advances
 * its completion_tail. The first time a queue is touched (its
 * has_pending_completions flag is still clear) it is appended to
 * dev->int_sq_list and *ce is incremented, so the caller knows which queues
 * received completions.
 *
 * Entries that name an out-of-range or non-SQ queue, or whose
 * completion_index does not match the queue's completion_tail, are logged and
 * ignored.
 */
static void bce_handle_cq_completion(struct apple_bce_device *dev, struct bce_qe_completion *e, size_t *ce)
{
    struct bce_queue *target;
    struct bce_queue_sq *target_sq;
    struct bce_sq_completion_data *cmpl;
    if (e->qid >= BCE_MAX_QUEUE_COUNT) {
        pr_err("Device sent a response for qid (%u) >= BCE_MAX_QUEUE_COUNT\n", e->qid);
        return;
    }
    target = dev->queues[e->qid];
    if (!target || target->type != BCE_QUEUE_SQ) {
        pr_err("Device sent a response for qid (%u), which does not exist\n", e->qid);
        return;
    }
    target_sq = (struct bce_queue_sq *) target;
    /* Completions must arrive in order; anything else is dropped. */
    if (target_sq->completion_tail != e->completion_index) {
        pr_err("Completion index mismatch; this is likely going to make this driver unusable\n");
        return;
    }
    /* Queue each SQ for notification only once per batch. */
    if (!target_sq->has_pending_completions) {
        target_sq->has_pending_completions = true;
        dev->int_sq_list[(*ce)++] = target_sq;
    }
    cmpl = &target_sq->completion_data[e->completion_index];
    cmpl->status = e->status;
    cmpl->data_size = e->data_size;
    cmpl->result = e->result;
    /*
     * Order the payload stores before the tail update below; presumably the
     * consumer polls completion_tail (reader not visible here - confirm).
     */
    wmb();
    target_sq->completion_tail = (target_sq->completion_tail + 1) % target_sq->el_count;
}
e->flags = 0; ++ cq->index = (cq->index + 1) % cq->el_count; ++ } ++ mb(); ++ iowrite32(cq->index, (u32 *) ((u8 *) dev->reg_mem_dma + REG_DOORBELL_BASE) + cq->qid); ++ while (ce) { ++ --ce; ++ sq = dev->int_sq_list[ce]; ++ sq->completion(sq); ++ sq->has_pending_completions = false; ++ } ++} ++ ++ ++struct bce_queue_sq *bce_alloc_sq(struct apple_bce_device *dev, int qid, u32 el_size, u32 el_count, ++ bce_sq_completion compl, void *userdata) ++{ ++ struct bce_queue_sq *q; ++ q = kzalloc(sizeof(struct bce_queue_sq), GFP_KERNEL); ++ q->qid = qid; ++ q->type = BCE_QUEUE_SQ; ++ q->el_size = el_size; ++ q->el_count = el_count; ++ q->data = dma_alloc_coherent(&dev->pci->dev, el_count * el_size, ++ &q->dma_handle, GFP_KERNEL); ++ q->completion = compl; ++ q->userdata = userdata; ++ q->completion_data = kzalloc(sizeof(struct bce_sq_completion_data) * el_count, GFP_KERNEL); ++ q->reg_mem_dma = dev->reg_mem_dma; ++ atomic_set(&q->available_commands, el_count - 1); ++ init_completion(&q->available_command_completion); ++ atomic_set(&q->available_command_completion_waiting_count, 0); ++ if (!q->data) { ++ pr_err("DMA queue memory alloc failed\n"); ++ kfree(q); ++ return NULL; ++ } ++ return q; ++} ++ ++void bce_get_sq_memcfg(struct bce_queue_sq *sq, struct bce_queue_cq *cq, struct bce_queue_memcfg *cfg) ++{ ++ cfg->qid = (u16) sq->qid; ++ cfg->el_count = (u16) sq->el_count; ++ cfg->vector_or_cq = (u16) cq->qid; ++ cfg->_pad = 0; ++ cfg->addr = sq->dma_handle; ++ cfg->length = sq->el_count * sq->el_size; ++} ++ ++void bce_free_sq(struct apple_bce_device *dev, struct bce_queue_sq *sq) ++{ ++ dma_free_coherent(&dev->pci->dev, sq->el_count * sq->el_size, sq->data, sq->dma_handle); ++ kfree(sq); ++} ++ ++int bce_reserve_submission(struct bce_queue_sq *sq, unsigned long *timeout) ++{ ++ while (atomic_dec_if_positive(&sq->available_commands) < 0) { ++ if (!timeout || !*timeout) ++ return -EAGAIN; ++ atomic_inc(&sq->available_command_completion_waiting_count); ++ *timeout = 
wait_for_completion_timeout(&sq->available_command_completion, *timeout); ++ if (!*timeout) { ++ if (atomic_dec_if_positive(&sq->available_command_completion_waiting_count) < 0) ++ try_wait_for_completion(&sq->available_command_completion); /* consume the pending completion */ ++ } ++ } ++ return 0; ++} ++ ++void bce_cancel_submission_reservation(struct bce_queue_sq *sq) ++{ ++ atomic_inc(&sq->available_commands); ++} ++ ++void *bce_next_submission(struct bce_queue_sq *sq) ++{ ++ void *ret = bce_sq_element(sq, sq->tail); ++ sq->tail = (sq->tail + 1) % sq->el_count; ++ return ret; ++} ++ ++void bce_submit_to_device(struct bce_queue_sq *sq) ++{ ++ mb(); ++ iowrite32(sq->tail, (u32 *) ((u8 *) sq->reg_mem_dma + REG_DOORBELL_BASE) + sq->qid); ++} ++ ++void bce_notify_submission_complete(struct bce_queue_sq *sq) ++{ ++ sq->head = (sq->head + 1) % sq->el_count; ++ atomic_inc(&sq->available_commands); ++ if (atomic_dec_if_positive(&sq->available_command_completion_waiting_count) >= 0) { ++ complete(&sq->available_command_completion); ++ } ++} ++ ++void bce_set_submission_single(struct bce_qe_submission *element, dma_addr_t addr, size_t size) ++{ ++ element->addr = addr; ++ element->length = size; ++ element->segl_addr = element->segl_length = 0; ++} ++ ++static void bce_cmdq_completion(struct bce_queue_sq *q); ++ ++struct bce_queue_cmdq *bce_alloc_cmdq(struct apple_bce_device *dev, int qid, u32 el_count) ++{ ++ struct bce_queue_cmdq *q; ++ q = kzalloc(sizeof(struct bce_queue_cmdq), GFP_KERNEL); ++ q->sq = bce_alloc_sq(dev, qid, BCE_CMD_SIZE, el_count, bce_cmdq_completion, q); ++ if (!q->sq) { ++ kfree(q); ++ return NULL; ++ } ++ spin_lock_init(&q->lck); ++ q->tres = kzalloc(sizeof(struct bce_queue_cmdq_result_el*) * el_count, GFP_KERNEL); ++ if (!q->tres) { ++ kfree(q); ++ return NULL; ++ } ++ return q; ++} ++ ++void bce_free_cmdq(struct apple_bce_device *dev, struct bce_queue_cmdq *cmdq) ++{ ++ bce_free_sq(dev, cmdq->sq); ++ kfree(cmdq->tres); ++ kfree(cmdq); ++} ++ 
/*
 * Begin issuing a command on cmdq.
 *
 * Initialises res->cmpl, reserves a submission slot (waiting up to about five
 * minutes for one to become free), records res in cmdq->tres at the slot
 * about to be used and returns a pointer to that slot's command memory for
 * the caller to fill in.
 *
 * Returns NULL if no slot could be reserved. On success the function returns
 * with cmdq->lck HELD; the lock is released by bce_cmd_finish(), which must be
 * called after the command has been filled in.
 */
static __always_inline void *bce_cmd_start(struct bce_queue_cmdq *cmdq, struct bce_queue_cmdq_result_el *res)
{
    void *ret;
    unsigned long timeout;
    init_completion(&res->cmpl);
    mb();

    timeout = msecs_to_jiffies(1000L * 60 * 5); /* wait for up to ~5 minutes */
    if (bce_reserve_submission(cmdq->sq, &timeout))
        return NULL;

    spin_lock(&cmdq->lck);
    /* Index by the current tail: that is the slot bce_next_submission() hands out. */
    cmdq->tres[cmdq->sq->tail] = res;
    ret = bce_next_submission(cmdq->sq);
    return ret;
}
1 : 0)); ++ cmd->qid = cfg->qid; ++ cmd->el_count = cfg->el_count; ++ cmd->vector_or_cq = cfg->vector_or_cq; ++ memset(cmd->name, 0, sizeof(cmd->name)); ++ if (name) { ++ cmd->name_len = (u16) min(strlen(name), (size_t) sizeof(cmd->name)); ++ memcpy(cmd->name, name, cmd->name_len); ++ } else { ++ cmd->name_len = 0; ++ } ++ cmd->addr = cfg->addr; ++ cmd->length = cfg->length; ++ ++ bce_cmd_finish(cmdq, &res); ++ return res.status; ++} ++ ++u32 bce_cmd_unregister_memory_queue(struct bce_queue_cmdq *cmdq, u16 qid) ++{ ++ struct bce_queue_cmdq_result_el res; ++ struct bce_cmdq_simple_memory_queue_cmd *cmd = bce_cmd_start(cmdq, &res); ++ if (!cmd) ++ return (u32) -1; ++ cmd->cmd = BCE_CMD_UNREGISTER_MEMORY_QUEUE; ++ cmd->flags = 0; ++ cmd->qid = qid; ++ bce_cmd_finish(cmdq, &res); ++ return res.status; ++} ++ ++u32 bce_cmd_flush_memory_queue(struct bce_queue_cmdq *cmdq, u16 qid) ++{ ++ struct bce_queue_cmdq_result_el res; ++ struct bce_cmdq_simple_memory_queue_cmd *cmd = bce_cmd_start(cmdq, &res); ++ if (!cmd) ++ return (u32) -1; ++ cmd->cmd = BCE_CMD_FLUSH_MEMORY_QUEUE; ++ cmd->flags = 0; ++ cmd->qid = qid; ++ bce_cmd_finish(cmdq, &res); ++ return res.status; ++} ++ ++ ++struct bce_queue_cq *bce_create_cq(struct apple_bce_device *dev, u32 el_count) ++{ ++ struct bce_queue_cq *cq; ++ struct bce_queue_memcfg cfg; ++ int qid = ida_simple_get(&dev->queue_ida, BCE_QUEUE_USER_MIN, BCE_QUEUE_USER_MAX, GFP_KERNEL); ++ if (qid < 0) ++ return NULL; ++ cq = bce_alloc_cq(dev, qid, el_count); ++ if (!cq) ++ return NULL; ++ bce_get_cq_memcfg(cq, &cfg); ++ if (bce_cmd_register_queue(dev->cmd_cmdq, &cfg, NULL, false) != 0) { ++ pr_err("apple-bce: CQ registration failed (%i)", qid); ++ bce_free_cq(dev, cq); ++ ida_simple_remove(&dev->queue_ida, (uint) qid); ++ return NULL; ++ } ++ dev->queues[qid] = (struct bce_queue *) cq; ++ return cq; ++} ++ ++struct bce_queue_sq *bce_create_sq(struct apple_bce_device *dev, struct bce_queue_cq *cq, const char *name, u32 el_count, ++ int direction, 
bce_sq_completion compl, void *userdata) ++{ ++ struct bce_queue_sq *sq; ++ struct bce_queue_memcfg cfg; ++ int qid; ++ if (cq == NULL) ++ return NULL; /* cq can not be null */ ++ if (name == NULL) ++ return NULL; /* name can not be null */ ++ if (direction != DMA_TO_DEVICE && direction != DMA_FROM_DEVICE) ++ return NULL; /* unsupported direction */ ++ qid = ida_simple_get(&dev->queue_ida, BCE_QUEUE_USER_MIN, BCE_QUEUE_USER_MAX, GFP_KERNEL); ++ if (qid < 0) ++ return NULL; ++ sq = bce_alloc_sq(dev, qid, sizeof(struct bce_qe_submission), el_count, compl, userdata); ++ if (!sq) ++ return NULL; ++ bce_get_sq_memcfg(sq, cq, &cfg); ++ if (bce_cmd_register_queue(dev->cmd_cmdq, &cfg, name, direction != DMA_FROM_DEVICE) != 0) { ++ pr_err("apple-bce: SQ registration failed (%i)", qid); ++ bce_free_sq(dev, sq); ++ ida_simple_remove(&dev->queue_ida, (uint) qid); ++ return NULL; ++ } ++ spin_lock(&dev->queues_lock); ++ dev->queues[qid] = (struct bce_queue *) sq; ++ spin_unlock(&dev->queues_lock); ++ return sq; ++} ++ ++void bce_destroy_cq(struct apple_bce_device *dev, struct bce_queue_cq *cq) ++{ ++ if (!dev->is_being_removed && bce_cmd_unregister_memory_queue(dev->cmd_cmdq, (u16) cq->qid)) ++ pr_err("apple-bce: CQ unregister failed"); ++ spin_lock(&dev->queues_lock); ++ dev->queues[cq->qid] = NULL; ++ spin_unlock(&dev->queues_lock); ++ ida_simple_remove(&dev->queue_ida, (uint) cq->qid); ++ bce_free_cq(dev, cq); ++} ++ ++void bce_destroy_sq(struct apple_bce_device *dev, struct bce_queue_sq *sq) ++{ ++ if (!dev->is_being_removed && bce_cmd_unregister_memory_queue(dev->cmd_cmdq, (u16) sq->qid)) ++ pr_err("apple-bce: CQ unregister failed"); ++ spin_lock(&dev->queues_lock); ++ dev->queues[sq->qid] = NULL; ++ spin_unlock(&dev->queues_lock); ++ ida_simple_remove(&dev->queue_ida, (uint) sq->qid); ++ bce_free_sq(dev, sq); ++} +\ No newline at end of file +diff --git a/drivers/staging/apple-bce/queue.h b/drivers/staging/apple-bce/queue.h +new file mode 100644 +index 
000000000000..8368ac5dfca8
+--- /dev/null
++++ b/drivers/staging/apple-bce/queue.h
+@@ -0,0 +1,177 @@
++#ifndef BCE_QUEUE_H
++#define BCE_QUEUE_H
++
++#include
++#include
++
++/* Element size passed to bce_alloc_sq() for the command queue. */
++#define BCE_CMD_SIZE 0x40
++
++struct apple_bce_device;
++
++enum bce_queue_type {
++    BCE_QUEUE_CQ, BCE_QUEUE_SQ
++};
++/* Common leading fields of bce_queue_cq and bce_queue_sq; both are stored as struct bce_queue pointers. */
++struct bce_queue {
++    int qid;
++    int type;
++};
++struct bce_queue_cq {
++    int qid;
++    int type;
++    u32 el_count;
++    dma_addr_t dma_handle;
++    void *data;
++
++    u32 index;
++};
++struct bce_queue_sq;
++typedef void (*bce_sq_completion)(struct bce_queue_sq *q);
++struct bce_sq_completion_data {
++    u32 status;
++    u64 data_size;
++    u64 result;
++};
++struct bce_queue_sq {
++    int qid;
++    int type;
++    u32 el_size;
++    u32 el_count;
++    dma_addr_t dma_handle;
++    void *data;
++    void *userdata;
++    void __iomem *reg_mem_dma;
++
++    /* free-slot counter and the completion/waiter count used to wake threads waiting for a slot */
++    atomic_t available_commands;
++    struct completion available_command_completion;
++    atomic_t available_command_completion_waiting_count;
++    /* ring indices, both advanced modulo el_count */
++    u32 head, tail;
++
++    /* completion_data is consumed from completion_cidx up to completion_tail */
++    u32 completion_cidx, completion_tail;
++    struct bce_sq_completion_data *completion_data;
++    bool has_pending_completions;
++    bce_sq_completion completion;
++};
++
++/* Per-command waiter: filled in and completed by the command queue completion handler. */
++struct bce_queue_cmdq_result_el {
++    struct completion cmpl;
++    u32 status;
++    u64 result;
++};
++struct bce_queue_cmdq {
++    struct bce_queue_sq *sq;
++    struct spinlock lck;
++    /* one waiter pointer per SQ slot, indexed by slot number */
++    struct bce_queue_cmdq_result_el **tres;
++};
++
++struct bce_queue_memcfg {
++    u16 qid;
++    u16 el_count;
++    u16 vector_or_cq;
++    u16 _pad;
++    u64 addr;
++    u64 length;
++};
++
++enum bce_qe_completion_status {
++    BCE_COMPLETION_SUCCESS = 0,
++    BCE_COMPLETION_ERROR = 1,
++    BCE_COMPLETION_ABORTED = 2,
++    BCE_COMPLETION_NO_SPACE = 3,
++    BCE_COMPLETION_OVERRUN = 4
++};
++enum bce_qe_completion_flags {
++    BCE_COMPLETION_FLAG_PENDING = 0x8000
++};
++struct bce_qe_completion {
++    u64 result;
++    u64 data_size;
++    u16 qid;
++    u16 completion_index;
++    u16 status; // bce_qe_completion_status
++    u16 flags; // bce_qe_completion_flags
++};
++
++struct bce_qe_submission {
++    u64 length;
++    u64 addr;
++
++    u64 segl_addr;
++    u64 segl_length;
++};
++
++enum bce_cmdq_command {
++    BCE_CMD_REGISTER_MEMORY_QUEUE = 0x20,
++    BCE_CMD_UNREGISTER_MEMORY_QUEUE = 0x30,
++    BCE_CMD_FLUSH_MEMORY_QUEUE = 0x40,
++    BCE_CMD_SET_MEMORY_QUEUE_PROPERTY = 0x50
++};
++struct bce_cmdq_simple_memory_queue_cmd {
++    u16 cmd; // bce_cmdq_command
++    u16 flags;
++    u16 qid;
++};
++struct bce_cmdq_register_memory_queue_cmd {
++    u16 cmd; // bce_cmdq_command
++    u16 flags;
++    u16 qid;
++    u16 _pad;
++    u16 el_count;
++    u16 vector_or_cq;
++    u16 _pad2;
++    u16 name_len;
++    char name[0x20];
++    u64 addr;
++    u64 length;
++};
++
++/* Address of element i in the SQ buffer; elements are el_size bytes apart. */
++static __always_inline void *bce_sq_element(struct bce_queue_sq *q, int i) {
++    return (void *) ((u8 *) q->data + q->el_size * i);
++}
++/* Address of element i in the CQ buffer, viewed as an array of struct bce_qe_completion. */
++static __always_inline void *bce_cq_element(struct bce_queue_cq *q, int i) {
++    return (void *) ((struct bce_qe_completion *) q->data + i);
++}
++
++/*
++ * Returns the next unconsumed completion record and advances completion_cidx
++ * (modulo el_count), or NULL once completion_cidx has caught up with completion_tail.
++ */
++static __always_inline struct bce_sq_completion_data *bce_next_completion(struct bce_queue_sq *sq) {
++    struct bce_sq_completion_data *res;
++    rmb();
++    if (sq->completion_cidx == sq->completion_tail)
++        return NULL;
++    res = &sq->completion_data[sq->completion_cidx];
++    sq->completion_cidx = (sq->completion_cidx + 1) % sq->el_count;
++    return res;
++}
++
++struct bce_queue_cq *bce_alloc_cq(struct apple_bce_device *dev, int qid, u32 el_count);
++void bce_get_cq_memcfg(struct bce_queue_cq *cq, struct bce_queue_memcfg *cfg);
++void bce_free_cq(struct apple_bce_device *dev, struct bce_queue_cq *cq);
++void bce_handle_cq_completions(struct apple_bce_device *dev, struct bce_queue_cq *cq);
++
++struct bce_queue_sq *bce_alloc_sq(struct apple_bce_device *dev, int qid, u32 el_size, u32 el_count,
++        bce_sq_completion compl, void *userdata);
++void bce_get_sq_memcfg(struct bce_queue_sq *sq, struct bce_queue_cq *cq, struct bce_queue_memcfg *cfg);
++void bce_free_sq(struct apple_bce_device *dev, struct bce_queue_sq *sq);
++int bce_reserve_submission(struct bce_queue_sq *sq, unsigned long *timeout);
++void bce_cancel_submission_reservation(struct bce_queue_sq *sq);
++void *bce_next_submission(struct bce_queue_sq *sq);
++void bce_submit_to_device(struct bce_queue_sq *sq);
++void bce_notify_submission_complete(struct bce_queue_sq *sq);
++
++void bce_set_submission_single(struct bce_qe_submission *element, dma_addr_t addr, size_t size);
++
++struct bce_queue_cmdq *bce_alloc_cmdq(struct apple_bce_device *dev, int qid, u32 el_count);
++void bce_free_cmdq(struct apple_bce_device *dev, struct bce_queue_cmdq *cmdq);
++
++u32 bce_cmd_register_queue(struct bce_queue_cmdq *cmdq, struct bce_queue_memcfg *cfg, const char *name, bool isdirout);
++u32 bce_cmd_unregister_memory_queue(struct bce_queue_cmdq *cmdq, u16 qid);
++u32 bce_cmd_flush_memory_queue(struct bce_queue_cmdq *cmdq, u16 qid);
++
++
++/* User API - Creates and registers the queue */
++
++struct bce_queue_cq *bce_create_cq(struct apple_bce_device *dev, u32 el_count);
++struct bce_queue_sq *bce_create_sq(struct apple_bce_device *dev, struct bce_queue_cq *cq, const char *name, u32 el_count,
++        int direction, bce_sq_completion compl, void *userdata);
++void bce_destroy_cq(struct apple_bce_device *dev, struct bce_queue_cq *cq);
++void bce_destroy_sq(struct apple_bce_device *dev, struct bce_queue_sq *sq);
++
++#endif //BCE_QUEUE_H
+diff --git a/drivers/staging/apple-bce/queue_dma.c b/drivers/staging/apple-bce/queue_dma.c
+new file mode 100644
+index 000000000000..b236613285c0
+--- /dev/null
++++ b/drivers/staging/apple-bce/queue_dma.c
+@@ -0,0 +1,220 @@
++#include "queue_dma.h"
++#include
++#include
++#include "queue.h"
++
++static int bce_alloc_scatterlist_from_vm(struct sg_table *tbl, void *data, size_t len);
++static struct bce_segment_list_element_hostinfo *bce_map_segment_list(
++        struct device *dev, struct scatterlist *pages, int pagen);
++static void bce_unmap_segement_list(struct device *dev, struct bce_segment_list_element_hostinfo *list);
++
++int bce_map_dma_buffer(struct device *dev,
struct bce_dma_buffer *buf, struct sg_table scatterlist, ++ enum dma_data_direction dir) ++{ ++ int cnt; ++ ++ buf->direction = dir; ++ buf->scatterlist = scatterlist; ++ buf->seglist_hostinfo = NULL; ++ ++ cnt = dma_map_sg(dev, buf->scatterlist.sgl, buf->scatterlist.nents, dir); ++ if (cnt != buf->scatterlist.nents) { ++ pr_err("apple-bce: DMA scatter list mapping returned an unexpected count: %i\n", cnt); ++ dma_unmap_sg(dev, buf->scatterlist.sgl, buf->scatterlist.nents, dir); ++ return -EIO; ++ } ++ if (cnt == 1) ++ return 0; ++ ++ buf->seglist_hostinfo = bce_map_segment_list(dev, buf->scatterlist.sgl, buf->scatterlist.nents); ++ if (!buf->seglist_hostinfo) { ++ pr_err("apple-bce: Creating segment list failed\n"); ++ dma_unmap_sg(dev, buf->scatterlist.sgl, buf->scatterlist.nents, dir); ++ return -EIO; ++ } ++ return 0; ++} ++ ++int bce_map_dma_buffer_vm(struct device *dev, struct bce_dma_buffer *buf, void *data, size_t len, ++ enum dma_data_direction dir) ++{ ++ int status; ++ struct sg_table scatterlist; ++ if ((status = bce_alloc_scatterlist_from_vm(&scatterlist, data, len))) ++ return status; ++ if ((status = bce_map_dma_buffer(dev, buf, scatterlist, dir))) { ++ sg_free_table(&scatterlist); ++ return status; ++ } ++ return 0; ++} ++ ++int bce_map_dma_buffer_km(struct device *dev, struct bce_dma_buffer *buf, void *data, size_t len, ++ enum dma_data_direction dir) ++{ ++ /* Kernel memory is continuous which is great for us. 
*/ ++ int status; ++ struct sg_table scatterlist; ++ if ((status = sg_alloc_table(&scatterlist, 1, GFP_KERNEL))) { ++ sg_free_table(&scatterlist); ++ return status; ++ } ++ sg_set_buf(scatterlist.sgl, data, (uint) len); ++ if ((status = bce_map_dma_buffer(dev, buf, scatterlist, dir))) { ++ sg_free_table(&scatterlist); ++ return status; ++ } ++ return 0; ++} ++ ++void bce_unmap_dma_buffer(struct device *dev, struct bce_dma_buffer *buf) ++{ ++ dma_unmap_sg(dev, buf->scatterlist.sgl, buf->scatterlist.nents, buf->direction); ++ bce_unmap_segement_list(dev, buf->seglist_hostinfo); ++} ++ ++ ++static int bce_alloc_scatterlist_from_vm(struct sg_table *tbl, void *data, size_t len) ++{ ++ int status, i; ++ struct page **pages; ++ size_t off, start_page, end_page, page_count; ++ off = (size_t) data % PAGE_SIZE; ++ start_page = (size_t) data / PAGE_SIZE; ++ end_page = ((size_t) data + len - 1) / PAGE_SIZE; ++ page_count = end_page - start_page + 1; ++ ++ if (page_count > PAGE_SIZE / sizeof(struct page *)) ++ pages = vmalloc(page_count * sizeof(struct page *)); ++ else ++ pages = kmalloc(page_count * sizeof(struct page *), GFP_KERNEL); ++ ++ for (i = 0; i < page_count; i++) ++ pages[i] = vmalloc_to_page((void *) ((start_page + i) * PAGE_SIZE)); ++ ++ if ((status = sg_alloc_table_from_pages(tbl, pages, page_count, (unsigned int) off, len, GFP_KERNEL))) { ++ sg_free_table(tbl); ++ } ++ ++ if (page_count > PAGE_SIZE / sizeof(struct page *)) ++ vfree(pages); ++ else ++ kfree(pages); ++ return status; ++} ++ ++#define BCE_ELEMENTS_PER_PAGE ((PAGE_SIZE - sizeof(struct bce_segment_list_header)) \ ++ / sizeof(struct bce_segment_list_element)) ++#define BCE_ELEMENTS_PER_ADDITIONAL_PAGE (PAGE_SIZE / sizeof(struct bce_segment_list_element)) ++ ++static struct bce_segment_list_element_hostinfo *bce_map_segment_list( ++ struct device *dev, struct scatterlist *pages, int pagen) ++{ ++ size_t ptr, pptr = 0; ++ struct bce_segment_list_header theader; /* a temp header, to store the initial seg 
*/ ++ struct bce_segment_list_header *header; ++ struct bce_segment_list_element *el, *el_end; ++ struct bce_segment_list_element_hostinfo *out, *pout, *out_root; ++ struct scatterlist *sg; ++ int i; ++ header = &theader; ++ out = out_root = NULL; ++ el = el_end = NULL; ++ for_each_sg(pages, sg, pagen, i) { ++ if (el >= el_end) { ++ /* allocate a new page, this will be also done for the first element */ ++ ptr = __get_free_page(GFP_KERNEL); ++ if (pptr && ptr == pptr + PAGE_SIZE) { ++ out->page_count++; ++ header->element_count += BCE_ELEMENTS_PER_ADDITIONAL_PAGE; ++ el_end += BCE_ELEMENTS_PER_ADDITIONAL_PAGE; ++ } else { ++ header = (void *) ptr; ++ header->element_count = BCE_ELEMENTS_PER_PAGE; ++ header->data_size = 0; ++ header->next_segl_addr = 0; ++ header->next_segl_length = 0; ++ el = (void *) (header + 1); ++ el_end = el + BCE_ELEMENTS_PER_PAGE; ++ ++ if (out) { ++ out->next = kmalloc(sizeof(struct bce_segment_list_element_hostinfo), GFP_KERNEL); ++ out = out->next; ++ } else { ++ out_root = out = kmalloc(sizeof(struct bce_segment_list_element_hostinfo), GFP_KERNEL); ++ } ++ out->page_start = (void *) ptr; ++ out->page_count = 1; ++ out->dma_start = DMA_MAPPING_ERROR; ++ out->next = NULL; ++ } ++ pptr = ptr; ++ } ++ el->addr = sg->dma_address; ++ el->length = sg->length; ++ header->data_size += el->length; ++ } ++ ++ /* DMA map */ ++ out = out_root; ++ pout = NULL; ++ while (out) { ++ out->dma_start = dma_map_single(dev, out->page_start, out->page_count * PAGE_SIZE, DMA_TO_DEVICE); ++ if (dma_mapping_error(dev, out->dma_start)) ++ goto error; ++ if (pout) { ++ header = pout->page_start; ++ header->next_segl_addr = out->dma_start; ++ header->next_segl_length = out->page_count * PAGE_SIZE; ++ } ++ pout = out; ++ out = out->next; ++ } ++ return out_root; ++ ++ error: ++ bce_unmap_segement_list(dev, out_root); ++ return NULL; ++} ++ ++static void bce_unmap_segement_list(struct device *dev, struct bce_segment_list_element_hostinfo *list) ++{ ++ struct 
bce_segment_list_element_hostinfo *next; ++ while (list) { ++ if (list->dma_start != DMA_MAPPING_ERROR) ++ dma_unmap_single(dev, list->dma_start, list->page_count * PAGE_SIZE, DMA_TO_DEVICE); ++ next = list->next; ++ kfree(list); ++ list = next; ++ } ++} ++ ++int bce_set_submission_buf(struct bce_qe_submission *element, struct bce_dma_buffer *buf, size_t offset, size_t length) ++{ ++ struct bce_segment_list_element_hostinfo *seg; ++ struct bce_segment_list_header *seg_header; ++ ++ seg = buf->seglist_hostinfo; ++ if (!seg) { ++ element->addr = buf->scatterlist.sgl->dma_address + offset; ++ element->length = length; ++ element->segl_addr = 0; ++ element->segl_length = 0; ++ return 0; ++ } ++ ++ while (seg) { ++ seg_header = seg->page_start; ++ if (offset <= seg_header->data_size) ++ break; ++ offset -= seg_header->data_size; ++ seg = seg->next; ++ } ++ if (!seg) ++ return -EINVAL; ++ element->addr = offset; ++ element->length = buf->scatterlist.sgl->dma_length; ++ element->segl_addr = seg->dma_start; ++ element->segl_length = seg->page_count * PAGE_SIZE; ++ return 0; ++} +\ No newline at end of file +diff --git a/drivers/staging/apple-bce/queue_dma.h b/drivers/staging/apple-bce/queue_dma.h +new file mode 100644 +index 000000000000..f8a57e50e7a3 +--- /dev/null ++++ b/drivers/staging/apple-bce/queue_dma.h +@@ -0,0 +1,50 @@ ++#ifndef BCE_QUEUE_DMA_H ++#define BCE_QUEUE_DMA_H ++ ++#include ++ ++struct bce_qe_submission; ++ ++struct bce_segment_list_header { ++ u64 element_count; ++ u64 data_size; ++ ++ u64 next_segl_addr; ++ u64 next_segl_length; ++}; ++struct bce_segment_list_element { ++ u64 addr; ++ u64 length; ++}; ++ ++struct bce_segment_list_element_hostinfo { ++ struct bce_segment_list_element_hostinfo *next; ++ void *page_start; ++ size_t page_count; ++ dma_addr_t dma_start; ++}; ++ ++ ++struct bce_dma_buffer { ++ enum dma_data_direction direction; ++ struct sg_table scatterlist; ++ struct bce_segment_list_element_hostinfo *seglist_hostinfo; ++}; ++ ++/* NOTE: 
Takes ownership of the sg_table if it succeeds. Ownership is not transferred on failure. */ ++int bce_map_dma_buffer(struct device *dev, struct bce_dma_buffer *buf, struct sg_table scatterlist, ++ enum dma_data_direction dir); ++ ++/* Creates a buffer from virtual memory (vmalloc) */ ++int bce_map_dma_buffer_vm(struct device *dev, struct bce_dma_buffer *buf, void *data, size_t len, ++ enum dma_data_direction dir); ++ ++/* Creates a buffer from kernel memory (kmalloc) */ ++int bce_map_dma_buffer_km(struct device *dev, struct bce_dma_buffer *buf, void *data, size_t len, ++ enum dma_data_direction dir); ++ ++void bce_unmap_dma_buffer(struct device *dev, struct bce_dma_buffer *buf); ++ ++int bce_set_submission_buf(struct bce_qe_submission *element, struct bce_dma_buffer *buf, size_t offset, size_t length); ++ ++#endif //BCE_QUEUE_DMA_H +diff --git a/drivers/staging/apple-bce/vhci/command.h b/drivers/staging/apple-bce/vhci/command.h +new file mode 100644 +index 000000000000..26619e0bccfa +--- /dev/null ++++ b/drivers/staging/apple-bce/vhci/command.h +@@ -0,0 +1,204 @@ ++#ifndef BCE_VHCI_COMMAND_H ++#define BCE_VHCI_COMMAND_H ++ ++#include "queue.h" ++#include ++#include ++ ++#define BCE_VHCI_CMD_TIMEOUT_SHORT msecs_to_jiffies(2000) ++#define BCE_VHCI_CMD_TIMEOUT_LONG msecs_to_jiffies(30000) ++ ++#define BCE_VHCI_BULK_MAX_ACTIVE_URBS_POW2 2 ++#define BCE_VHCI_BULK_MAX_ACTIVE_URBS (1 << BCE_VHCI_BULK_MAX_ACTIVE_URBS_POW2) ++ ++typedef u8 bce_vhci_port_t; ++typedef u8 bce_vhci_device_t; ++ ++enum bce_vhci_command { ++ BCE_VHCI_CMD_CONTROLLER_ENABLE = 1, ++ BCE_VHCI_CMD_CONTROLLER_DISABLE = 2, ++ BCE_VHCI_CMD_CONTROLLER_START = 3, ++ BCE_VHCI_CMD_CONTROLLER_PAUSE = 4, ++ ++ BCE_VHCI_CMD_PORT_POWER_ON = 0x10, ++ BCE_VHCI_CMD_PORT_POWER_OFF = 0x11, ++ BCE_VHCI_CMD_PORT_RESUME = 0x12, ++ BCE_VHCI_CMD_PORT_SUSPEND = 0x13, ++ BCE_VHCI_CMD_PORT_RESET = 0x14, ++ BCE_VHCI_CMD_PORT_DISABLE = 0x15, ++ BCE_VHCI_CMD_PORT_STATUS = 0x16, ++ ++ BCE_VHCI_CMD_DEVICE_CREATE = 0x30, ++ 
BCE_VHCI_CMD_DEVICE_DESTROY = 0x31, ++ ++ BCE_VHCI_CMD_ENDPOINT_CREATE = 0x40, ++ BCE_VHCI_CMD_ENDPOINT_DESTROY = 0x41, ++ BCE_VHCI_CMD_ENDPOINT_SET_STATE = 0x42, ++ BCE_VHCI_CMD_ENDPOINT_RESET = 0x44, ++ ++ /* Device to host only */ ++ BCE_VHCI_CMD_ENDPOINT_REQUEST_STATE = 0x43, ++ BCE_VHCI_CMD_TRANSFER_REQUEST = 0x1000, ++ BCE_VHCI_CMD_CONTROL_TRANSFER_STATUS = 0x1005 ++}; ++ ++enum bce_vhci_endpoint_state { ++ BCE_VHCI_ENDPOINT_ACTIVE = 0, ++ BCE_VHCI_ENDPOINT_PAUSED = 1, ++ BCE_VHCI_ENDPOINT_STALLED = 2 ++}; ++ ++static inline int bce_vhci_cmd_controller_enable(struct bce_vhci_command_queue *q, u8 busNum, u16 *portMask) ++{ ++ int status; ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_CONTROLLER_ENABLE; ++ cmd.param1 = 0x7100u | busNum; ++ status = bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); ++ if (!status) ++ *portMask = (u16) res.param2; ++ return status; ++} ++static inline int bce_vhci_cmd_controller_disable(struct bce_vhci_command_queue *q) ++{ ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_CONTROLLER_DISABLE; ++ return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); ++} ++static inline int bce_vhci_cmd_controller_start(struct bce_vhci_command_queue *q) ++{ ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_CONTROLLER_START; ++ return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); ++} ++static inline int bce_vhci_cmd_controller_pause(struct bce_vhci_command_queue *q) ++{ ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_CONTROLLER_PAUSE; ++ return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); ++} ++ ++static inline int bce_vhci_cmd_port_power_on(struct bce_vhci_command_queue *q, bce_vhci_port_t port) ++{ ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_PORT_POWER_ON; ++ cmd.param1 = port; ++ return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); ++} 
++static inline int bce_vhci_cmd_port_power_off(struct bce_vhci_command_queue *q, bce_vhci_port_t port) ++{ ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_PORT_POWER_OFF; ++ cmd.param1 = port; ++ return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); ++} ++static inline int bce_vhci_cmd_port_resume(struct bce_vhci_command_queue *q, bce_vhci_port_t port) ++{ ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_PORT_RESUME; ++ cmd.param1 = port; ++ return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); ++} ++static inline int bce_vhci_cmd_port_suspend(struct bce_vhci_command_queue *q, bce_vhci_port_t port) ++{ ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_PORT_SUSPEND; ++ cmd.param1 = port; ++ return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); ++} ++static inline int bce_vhci_cmd_port_reset(struct bce_vhci_command_queue *q, bce_vhci_port_t port, u32 timeout) ++{ ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_PORT_RESET; ++ cmd.param1 = port; ++ cmd.param2 = timeout; ++ return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); ++} ++static inline int bce_vhci_cmd_port_disable(struct bce_vhci_command_queue *q, bce_vhci_port_t port) ++{ ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_PORT_DISABLE; ++ cmd.param1 = port; ++ return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); ++} ++static inline int bce_vhci_cmd_port_status(struct bce_vhci_command_queue *q, bce_vhci_port_t port, ++ u32 clearFlags, u32 *resStatus) ++{ ++ int status; ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_PORT_STATUS; ++ cmd.param1 = port; ++ cmd.param2 = clearFlags & 0x560000; ++ status = bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); ++ if (status >= 0) ++ *resStatus = (u32) res.param2; ++ return status; ++} ++ ++static inline int 
bce_vhci_cmd_device_create(struct bce_vhci_command_queue *q, bce_vhci_port_t port, ++ bce_vhci_device_t *dev) ++{ ++ int status; ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_DEVICE_CREATE; ++ cmd.param1 = port; ++ status = bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); ++ if (!status) ++ *dev = (bce_vhci_device_t) res.param2; ++ return status; ++} ++static inline int bce_vhci_cmd_device_destroy(struct bce_vhci_command_queue *q, bce_vhci_device_t dev) ++{ ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_DEVICE_DESTROY; ++ cmd.param1 = dev; ++ return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); ++} ++ ++static inline int bce_vhci_cmd_endpoint_create(struct bce_vhci_command_queue *q, bce_vhci_device_t dev, ++ struct usb_endpoint_descriptor *desc) ++{ ++ struct bce_vhci_message cmd, res; ++ int endpoint_type = usb_endpoint_type(desc); ++ int maxp = usb_endpoint_maxp(desc); ++ int maxp_burst = usb_endpoint_maxp_mult(desc) * maxp; ++ u8 max_active_requests_pow2 = 0; ++ cmd.cmd = BCE_VHCI_CMD_ENDPOINT_CREATE; ++ cmd.param1 = dev | ((desc->bEndpointAddress & 0x8Fu) << 8); ++ if (endpoint_type == USB_ENDPOINT_XFER_BULK) ++ max_active_requests_pow2 = BCE_VHCI_BULK_MAX_ACTIVE_URBS_POW2; ++ cmd.param2 = endpoint_type | ((max_active_requests_pow2 & 0xf) << 4) | (maxp << 16) | ((u64) maxp_burst << 32); ++ if (endpoint_type == USB_ENDPOINT_XFER_INT) ++ cmd.param2 |= (desc->bInterval - 1) << 8; ++ return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); ++} ++static inline int bce_vhci_cmd_endpoint_destroy(struct bce_vhci_command_queue *q, bce_vhci_device_t dev, u8 endpoint) ++{ ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_ENDPOINT_DESTROY; ++ cmd.param1 = dev | (endpoint << 8); ++ return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); ++} ++static inline int bce_vhci_cmd_endpoint_set_state(struct bce_vhci_command_queue *q, 
bce_vhci_device_t dev, u8 endpoint, ++ enum bce_vhci_endpoint_state newState, enum bce_vhci_endpoint_state *retState) ++{ ++ int status; ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_ENDPOINT_SET_STATE; ++ cmd.param1 = dev | (endpoint << 8); ++ cmd.param2 = (u64) newState; ++ status = bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); ++ if (status != BCE_VHCI_INTERNAL_ERROR && status != BCE_VHCI_NO_POWER) ++ *retState = (enum bce_vhci_endpoint_state) res.param2; ++ return status; ++} ++static inline int bce_vhci_cmd_endpoint_reset(struct bce_vhci_command_queue *q, bce_vhci_device_t dev, u8 endpoint) ++{ ++ struct bce_vhci_message cmd, res; ++ cmd.cmd = BCE_VHCI_CMD_ENDPOINT_RESET; ++ cmd.param1 = dev | (endpoint << 8); ++ return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); ++} ++ ++ ++#endif //BCE_VHCI_COMMAND_H +diff --git a/drivers/staging/apple-bce/vhci/queue.c b/drivers/staging/apple-bce/vhci/queue.c +new file mode 100644 +index 000000000000..7b0b5027157b +--- /dev/null ++++ b/drivers/staging/apple-bce/vhci/queue.c +@@ -0,0 +1,268 @@ ++#include "queue.h" ++#include "vhci.h" ++#include "../apple_bce.h" ++ ++ ++static void bce_vhci_message_queue_completion(struct bce_queue_sq *sq); ++ ++int bce_vhci_message_queue_create(struct bce_vhci *vhci, struct bce_vhci_message_queue *ret, const char *name) ++{ ++ int status; ++ ret->cq = bce_create_cq(vhci->dev, VHCI_EVENT_QUEUE_EL_COUNT); ++ if (!ret->cq) ++ return -EINVAL; ++ ret->sq = bce_create_sq(vhci->dev, ret->cq, name, VHCI_EVENT_QUEUE_EL_COUNT, DMA_TO_DEVICE, ++ bce_vhci_message_queue_completion, ret); ++ if (!ret->sq) { ++ status = -EINVAL; ++ goto fail_cq; ++ } ++ ret->data = dma_alloc_coherent(&vhci->dev->pci->dev, sizeof(struct bce_vhci_message) * VHCI_EVENT_QUEUE_EL_COUNT, ++ &ret->dma_addr, GFP_KERNEL); ++ if (!ret->data) { ++ status = -EINVAL; ++ goto fail_sq; ++ } ++ return 0; ++ ++fail_sq: ++ bce_destroy_sq(vhci->dev, ret->sq); ++ 
ret->sq = NULL; ++fail_cq: ++ bce_destroy_cq(vhci->dev, ret->cq); ++ ret->cq = NULL; ++ return status; ++} ++ ++void bce_vhci_message_queue_destroy(struct bce_vhci *vhci, struct bce_vhci_message_queue *q) ++{ ++ if (!q->cq) ++ return; ++ dma_free_coherent(&vhci->dev->pci->dev, sizeof(struct bce_vhci_message) * VHCI_EVENT_QUEUE_EL_COUNT, ++ q->data, q->dma_addr); ++ bce_destroy_sq(vhci->dev, q->sq); ++ bce_destroy_cq(vhci->dev, q->cq); ++} ++ ++void bce_vhci_message_queue_write(struct bce_vhci_message_queue *q, struct bce_vhci_message *req) ++{ ++ int sidx; ++ struct bce_qe_submission *s; ++ sidx = q->sq->tail; ++ s = bce_next_submission(q->sq); ++ pr_debug("bce-vhci: Send message: %x s=%x p1=%x p2=%llx\n", req->cmd, req->status, req->param1, req->param2); ++ q->data[sidx] = *req; ++ bce_set_submission_single(s, q->dma_addr + sizeof(struct bce_vhci_message) * sidx, ++ sizeof(struct bce_vhci_message)); ++ bce_submit_to_device(q->sq); ++} ++ ++static void bce_vhci_message_queue_completion(struct bce_queue_sq *sq) ++{ ++ while (bce_next_completion(sq)) ++ bce_notify_submission_complete(sq); ++} ++ ++ ++ ++static void bce_vhci_event_queue_completion(struct bce_queue_sq *sq); ++ ++int __bce_vhci_event_queue_create(struct bce_vhci *vhci, struct bce_vhci_event_queue *ret, const char *name, ++ bce_sq_completion compl) ++{ ++ ret->vhci = vhci; ++ ++ ret->sq = bce_create_sq(vhci->dev, vhci->ev_cq, name, VHCI_EVENT_QUEUE_EL_COUNT, DMA_FROM_DEVICE, compl, ret); ++ if (!ret->sq) ++ return -EINVAL; ++ ret->data = dma_alloc_coherent(&vhci->dev->pci->dev, sizeof(struct bce_vhci_message) * VHCI_EVENT_QUEUE_EL_COUNT, ++ &ret->dma_addr, GFP_KERNEL); ++ if (!ret->data) { ++ bce_destroy_sq(vhci->dev, ret->sq); ++ ret->sq = NULL; ++ return -EINVAL; ++ } ++ ++ init_completion(&ret->queue_empty_completion); ++ bce_vhci_event_queue_submit_pending(ret, VHCI_EVENT_PENDING_COUNT); ++ return 0; ++} ++ ++int bce_vhci_event_queue_create(struct bce_vhci *vhci, struct bce_vhci_event_queue *ret, 
const char *name, ++ bce_vhci_event_queue_callback cb) ++{ ++ ret->cb = cb; ++ return __bce_vhci_event_queue_create(vhci, ret, name, bce_vhci_event_queue_completion); ++} ++ ++void bce_vhci_event_queue_destroy(struct bce_vhci *vhci, struct bce_vhci_event_queue *q) ++{ ++ if (!q->sq) ++ return; ++ dma_free_coherent(&vhci->dev->pci->dev, sizeof(struct bce_vhci_message) * VHCI_EVENT_QUEUE_EL_COUNT, ++ q->data, q->dma_addr); ++ bce_destroy_sq(vhci->dev, q->sq); ++} ++ ++static void bce_vhci_event_queue_completion(struct bce_queue_sq *sq) ++{ ++ struct bce_sq_completion_data *cd; ++ struct bce_vhci_event_queue *ev = sq->userdata; ++ struct bce_vhci_message *msg; ++ size_t cnt = 0; ++ ++ while ((cd = bce_next_completion(sq))) { ++ if (cd->status == BCE_COMPLETION_ABORTED) { /* We flushed the queue */ ++ bce_notify_submission_complete(sq); ++ continue; ++ } ++ msg = &ev->data[sq->head]; ++ pr_debug("bce-vhci: Got event: %x s=%x p1=%x p2=%llx\n", msg->cmd, msg->status, msg->param1, msg->param2); ++ ev->cb(ev, msg); ++ ++ bce_notify_submission_complete(sq); ++ ++cnt; ++ } ++ bce_vhci_event_queue_submit_pending(ev, cnt); ++ if (atomic_read(&sq->available_commands) == sq->el_count - 1) ++ complete(&ev->queue_empty_completion); ++} ++ ++void bce_vhci_event_queue_submit_pending(struct bce_vhci_event_queue *q, size_t count) ++{ ++ int idx; ++ struct bce_qe_submission *s; ++ while (count--) { ++ if (bce_reserve_submission(q->sq, NULL)) { ++ pr_err("bce-vhci: Failed to reserve an event queue submission\n"); ++ break; ++ } ++ idx = q->sq->tail; ++ s = bce_next_submission(q->sq); ++ bce_set_submission_single(s, ++ q->dma_addr + idx * sizeof(struct bce_vhci_message), sizeof(struct bce_vhci_message)); ++ } ++ bce_submit_to_device(q->sq); ++} ++ ++void bce_vhci_event_queue_pause(struct bce_vhci_event_queue *q) ++{ ++ unsigned long timeout; ++ reinit_completion(&q->queue_empty_completion); ++ if (bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, q->sq->qid)) ++ pr_warn("bce-vhci: 
failed to flush event queue\n"); ++ timeout = msecs_to_jiffies(5000); ++ while (atomic_read(&q->sq->available_commands) != q->sq->el_count - 1) { ++ timeout = wait_for_completion_timeout(&q->queue_empty_completion, timeout); ++ if (timeout == 0) { ++ pr_err("bce-vhci: waiting for queue to be flushed timed out\n"); ++ break; ++ } ++ } ++} ++ ++void bce_vhci_event_queue_resume(struct bce_vhci_event_queue *q) ++{ ++ if (atomic_read(&q->sq->available_commands) != q->sq->el_count - 1) { ++ pr_err("bce-vhci: resume of a queue with pending submissions\n"); ++ return; ++ } ++ bce_vhci_event_queue_submit_pending(q, VHCI_EVENT_PENDING_COUNT); ++} ++ ++void bce_vhci_command_queue_create(struct bce_vhci_command_queue *ret, struct bce_vhci_message_queue *mq) ++{ ++ ret->mq = mq; ++ ret->completion.result = NULL; ++ init_completion(&ret->completion.completion); ++ spin_lock_init(&ret->completion_lock); ++ mutex_init(&ret->mutex); ++} ++ ++void bce_vhci_command_queue_destroy(struct bce_vhci_command_queue *cq) ++{ ++ spin_lock(&cq->completion_lock); ++ if (cq->completion.result) { ++ memset(cq->completion.result, 0, sizeof(struct bce_vhci_message)); ++ cq->completion.result->status = BCE_VHCI_ABORT; ++ complete(&cq->completion.completion); ++ cq->completion.result = NULL; ++ } ++ spin_unlock(&cq->completion_lock); ++ mutex_lock(&cq->mutex); ++ mutex_unlock(&cq->mutex); ++ mutex_destroy(&cq->mutex); ++} ++ ++void bce_vhci_command_queue_deliver_completion(struct bce_vhci_command_queue *cq, struct bce_vhci_message *msg) ++{ ++ struct bce_vhci_command_queue_completion *c = &cq->completion; ++ ++ spin_lock(&cq->completion_lock); ++ if (c->result) { ++ *c->result = *msg; ++ complete(&c->completion); ++ c->result = NULL; ++ } ++ spin_unlock(&cq->completion_lock); ++} ++ ++static int __bce_vhci_command_queue_execute(struct bce_vhci_command_queue *cq, struct bce_vhci_message *req, ++ struct bce_vhci_message *res, unsigned long timeout) ++{ ++ int status; ++ struct 
bce_vhci_command_queue_completion *c; ++ struct bce_vhci_message creq; ++ c = &cq->completion; ++ ++ if ((status = bce_reserve_submission(cq->mq->sq, &timeout))) ++ return status; ++ ++ spin_lock(&cq->completion_lock); ++ c->result = res; ++ reinit_completion(&c->completion); ++ spin_unlock(&cq->completion_lock); ++ ++ bce_vhci_message_queue_write(cq->mq, req); ++ ++ if (!wait_for_completion_timeout(&c->completion, timeout)) { ++ /* we ran out of time, send cancellation */ ++ pr_debug("bce-vhci: command timed out req=%x\n", req->cmd); ++ if ((status = bce_reserve_submission(cq->mq->sq, &timeout))) ++ return status; ++ ++ creq = *req; ++ creq.cmd |= 0x4000; ++ bce_vhci_message_queue_write(cq->mq, &creq); ++ ++ if (!wait_for_completion_timeout(&c->completion, 1000)) { ++ pr_err("bce-vhci: Possible desync, cmd cancel timed out\n"); ++ ++ spin_lock(&cq->completion_lock); ++ c->result = NULL; ++ spin_unlock(&cq->completion_lock); ++ return -ETIMEDOUT; ++ } ++ if ((res->cmd & ~0x8000) == creq.cmd) ++ return -ETIMEDOUT; ++ /* reply for the previous command most likely arrived */ ++ } ++ ++ if ((res->cmd & ~0x8000) != req->cmd) { ++ pr_err("bce-vhci: Possible desync, cmd reply mismatch req=%x, res=%x\n", req->cmd, res->cmd); ++ return -EIO; ++ } ++ if (res->status == BCE_VHCI_SUCCESS) ++ return 0; ++ return res->status; ++} ++ ++int bce_vhci_command_queue_execute(struct bce_vhci_command_queue *cq, struct bce_vhci_message *req, ++ struct bce_vhci_message *res, unsigned long timeout) ++{ ++ int status; ++ mutex_lock(&cq->mutex); ++ status = __bce_vhci_command_queue_execute(cq, req, res, timeout); ++ mutex_unlock(&cq->mutex); ++ return status; ++} +diff --git a/drivers/staging/apple-bce/vhci/queue.h b/drivers/staging/apple-bce/vhci/queue.h +new file mode 100644 +index 000000000000..adb705b6ba1d +--- /dev/null ++++ b/drivers/staging/apple-bce/vhci/queue.h +@@ -0,0 +1,76 @@ ++#ifndef BCE_VHCI_QUEUE_H ++#define BCE_VHCI_QUEUE_H ++ ++#include ++#include "../queue.h" ++ ++#define 
VHCI_EVENT_QUEUE_EL_COUNT 256 ++#define VHCI_EVENT_PENDING_COUNT 32 ++ ++struct bce_vhci; ++struct bce_vhci_event_queue; ++ ++enum bce_vhci_message_status { ++ BCE_VHCI_SUCCESS = 1, ++ BCE_VHCI_ERROR = 2, ++ BCE_VHCI_USB_PIPE_STALL = 3, ++ BCE_VHCI_ABORT = 4, ++ BCE_VHCI_BAD_ARGUMENT = 5, ++ BCE_VHCI_OVERRUN = 6, ++ BCE_VHCI_INTERNAL_ERROR = 7, ++ BCE_VHCI_NO_POWER = 8, ++ BCE_VHCI_UNSUPPORTED = 9 ++}; ++struct bce_vhci_message { ++ u16 cmd; ++ u16 status; // bce_vhci_message_status ++ u32 param1; ++ u64 param2; ++}; ++ ++struct bce_vhci_message_queue { ++ struct bce_queue_cq *cq; ++ struct bce_queue_sq *sq; ++ struct bce_vhci_message *data; ++ dma_addr_t dma_addr; ++}; ++typedef void (*bce_vhci_event_queue_callback)(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg); ++struct bce_vhci_event_queue { ++ struct bce_vhci *vhci; ++ struct bce_queue_sq *sq; ++ struct bce_vhci_message *data; ++ dma_addr_t dma_addr; ++ bce_vhci_event_queue_callback cb; ++ struct completion queue_empty_completion; ++}; ++struct bce_vhci_command_queue_completion { ++ struct bce_vhci_message *result; ++ struct completion completion; ++}; ++struct bce_vhci_command_queue { ++ struct bce_vhci_message_queue *mq; ++ struct bce_vhci_command_queue_completion completion; ++ struct spinlock completion_lock; ++ struct mutex mutex; ++}; ++ ++int bce_vhci_message_queue_create(struct bce_vhci *vhci, struct bce_vhci_message_queue *ret, const char *name); ++void bce_vhci_message_queue_destroy(struct bce_vhci *vhci, struct bce_vhci_message_queue *q); ++void bce_vhci_message_queue_write(struct bce_vhci_message_queue *q, struct bce_vhci_message *req); ++ ++int __bce_vhci_event_queue_create(struct bce_vhci *vhci, struct bce_vhci_event_queue *ret, const char *name, ++ bce_sq_completion compl); ++int bce_vhci_event_queue_create(struct bce_vhci *vhci, struct bce_vhci_event_queue *ret, const char *name, ++ bce_vhci_event_queue_callback cb); ++void bce_vhci_event_queue_destroy(struct bce_vhci *vhci, 
struct bce_vhci_event_queue *q); ++void bce_vhci_event_queue_submit_pending(struct bce_vhci_event_queue *q, size_t count); ++void bce_vhci_event_queue_pause(struct bce_vhci_event_queue *q); ++void bce_vhci_event_queue_resume(struct bce_vhci_event_queue *q); ++ ++void bce_vhci_command_queue_create(struct bce_vhci_command_queue *ret, struct bce_vhci_message_queue *mq); ++void bce_vhci_command_queue_destroy(struct bce_vhci_command_queue *cq); ++int bce_vhci_command_queue_execute(struct bce_vhci_command_queue *cq, struct bce_vhci_message *req, ++ struct bce_vhci_message *res, unsigned long timeout); ++void bce_vhci_command_queue_deliver_completion(struct bce_vhci_command_queue *cq, struct bce_vhci_message *msg); ++ ++#endif //BCE_VHCI_QUEUE_H +diff --git a/drivers/staging/apple-bce/vhci/transfer.c b/drivers/staging/apple-bce/vhci/transfer.c +new file mode 100644 +index 000000000000..8226363d69c8 +--- /dev/null ++++ b/drivers/staging/apple-bce/vhci/transfer.c +@@ -0,0 +1,661 @@ ++#include "transfer.h" ++#include "../queue.h" ++#include "vhci.h" ++#include "../apple_bce.h" ++#include ++ ++static void bce_vhci_transfer_queue_completion(struct bce_queue_sq *sq); ++static void bce_vhci_transfer_queue_giveback(struct bce_vhci_transfer_queue *q); ++static void bce_vhci_transfer_queue_remove_pending(struct bce_vhci_transfer_queue *q); ++ ++static int bce_vhci_urb_init(struct bce_vhci_urb *vurb); ++static int bce_vhci_urb_update(struct bce_vhci_urb *urb, struct bce_vhci_message *msg); ++static int bce_vhci_urb_transfer_completion(struct bce_vhci_urb *urb, struct bce_sq_completion_data *c); ++ ++static void bce_vhci_transfer_queue_reset_w(struct work_struct *work); ++ ++void bce_vhci_create_transfer_queue(struct bce_vhci *vhci, struct bce_vhci_transfer_queue *q, ++ struct usb_host_endpoint *endp, bce_vhci_device_t dev_addr, enum dma_data_direction dir) ++{ ++ char name[0x21]; ++ INIT_LIST_HEAD(&q->evq); ++ INIT_LIST_HEAD(&q->giveback_urb_list); ++ spin_lock_init(&q->urb_lock); 
++ mutex_init(&q->pause_lock); ++ q->vhci = vhci; ++ q->endp = endp; ++ q->dev_addr = dev_addr; ++ q->endp_addr = (u8) (endp->desc.bEndpointAddress & 0x8F); ++ q->state = BCE_VHCI_ENDPOINT_ACTIVE; ++ q->active = true; ++ q->stalled = false; ++ q->max_active_requests = 1; ++ if (usb_endpoint_type(&endp->desc) == USB_ENDPOINT_XFER_BULK) ++ q->max_active_requests = BCE_VHCI_BULK_MAX_ACTIVE_URBS; ++ q->remaining_active_requests = q->max_active_requests; ++ q->cq = bce_create_cq(vhci->dev, 0x100); ++ INIT_WORK(&q->w_reset, bce_vhci_transfer_queue_reset_w); ++ q->sq_in = NULL; ++ if (dir == DMA_FROM_DEVICE || dir == DMA_BIDIRECTIONAL) { ++ snprintf(name, sizeof(name), "VHC1-%i-%02x", dev_addr, 0x80 | usb_endpoint_num(&endp->desc)); ++ q->sq_in = bce_create_sq(vhci->dev, q->cq, name, 0x100, DMA_FROM_DEVICE, ++ bce_vhci_transfer_queue_completion, q); ++ } ++ q->sq_out = NULL; ++ if (dir == DMA_TO_DEVICE || dir == DMA_BIDIRECTIONAL) { ++ snprintf(name, sizeof(name), "VHC1-%i-%02x", dev_addr, usb_endpoint_num(&endp->desc)); ++ q->sq_out = bce_create_sq(vhci->dev, q->cq, name, 0x100, DMA_TO_DEVICE, ++ bce_vhci_transfer_queue_completion, q); ++ } ++} ++ ++void bce_vhci_destroy_transfer_queue(struct bce_vhci *vhci, struct bce_vhci_transfer_queue *q) ++{ ++ bce_vhci_transfer_queue_giveback(q); ++ bce_vhci_transfer_queue_remove_pending(q); ++ if (q->sq_in) ++ bce_destroy_sq(vhci->dev, q->sq_in); ++ if (q->sq_out) ++ bce_destroy_sq(vhci->dev, q->sq_out); ++ bce_destroy_cq(vhci->dev, q->cq); ++} ++ ++static inline bool bce_vhci_transfer_queue_can_init_urb(struct bce_vhci_transfer_queue *q) ++{ ++ return q->remaining_active_requests > 0; ++} ++ ++static void bce_vhci_transfer_queue_defer_event(struct bce_vhci_transfer_queue *q, struct bce_vhci_message *msg) ++{ ++ struct bce_vhci_list_message *lm; ++ lm = kmalloc(sizeof(struct bce_vhci_list_message), GFP_KERNEL); ++ INIT_LIST_HEAD(&lm->list); ++ lm->msg = *msg; ++ list_add_tail(&lm->list, &q->evq); ++} ++ ++static void 
bce_vhci_transfer_queue_giveback(struct bce_vhci_transfer_queue *q) ++{ ++ unsigned long flags; ++ struct urb *urb; ++ spin_lock_irqsave(&q->urb_lock, flags); ++ while (!list_empty(&q->giveback_urb_list)) { ++ urb = list_first_entry(&q->giveback_urb_list, struct urb, urb_list); ++ list_del(&urb->urb_list); ++ ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++ usb_hcd_giveback_urb(q->vhci->hcd, urb, urb->status); ++ spin_lock_irqsave(&q->urb_lock, flags); ++ } ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++} ++ ++static void bce_vhci_transfer_queue_init_pending_urbs(struct bce_vhci_transfer_queue *q); ++ ++static void bce_vhci_transfer_queue_deliver_pending(struct bce_vhci_transfer_queue *q) ++{ ++ struct urb *urb; ++ struct bce_vhci_list_message *lm; ++ ++ while (!list_empty(&q->endp->urb_list) && !list_empty(&q->evq)) { ++ urb = list_first_entry(&q->endp->urb_list, struct urb, urb_list); ++ ++ lm = list_first_entry(&q->evq, struct bce_vhci_list_message, list); ++ if (bce_vhci_urb_update(urb->hcpriv, &lm->msg) == -EAGAIN) ++ break; ++ list_del(&lm->list); ++ kfree(lm); ++ } ++ ++ /* some of the URBs could have been completed, so initialize more URBs if possible */ ++ bce_vhci_transfer_queue_init_pending_urbs(q); ++} ++ ++static void bce_vhci_transfer_queue_remove_pending(struct bce_vhci_transfer_queue *q) ++{ ++ unsigned long flags; ++ struct bce_vhci_list_message *lm; ++ spin_lock_irqsave(&q->urb_lock, flags); ++ while (!list_empty(&q->evq)) { ++ lm = list_first_entry(&q->evq, struct bce_vhci_list_message, list); ++ list_del(&lm->list); ++ kfree(lm); ++ } ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++} ++ ++void bce_vhci_transfer_queue_event(struct bce_vhci_transfer_queue *q, struct bce_vhci_message *msg) ++{ ++ unsigned long flags; ++ struct bce_vhci_urb *turb; ++ struct urb *urb; ++ spin_lock_irqsave(&q->urb_lock, flags); ++ bce_vhci_transfer_queue_deliver_pending(q); ++ ++ if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST && ++ (!list_empty(&q->evq) || 
list_empty(&q->endp->urb_list))) { ++ bce_vhci_transfer_queue_defer_event(q, msg); ++ goto complete; ++ } ++ if (list_empty(&q->endp->urb_list)) { ++ pr_err("bce-vhci: [%02x] Unexpected transfer queue event\n", q->endp_addr); ++ goto complete; ++ } ++ urb = list_first_entry(&q->endp->urb_list, struct urb, urb_list); ++ turb = urb->hcpriv; ++ if (bce_vhci_urb_update(turb, msg) == -EAGAIN) { ++ bce_vhci_transfer_queue_defer_event(q, msg); ++ } else { ++ bce_vhci_transfer_queue_init_pending_urbs(q); ++ } ++ ++complete: ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++ bce_vhci_transfer_queue_giveback(q); ++} ++ ++static void bce_vhci_transfer_queue_completion(struct bce_queue_sq *sq) ++{ ++ unsigned long flags; ++ struct bce_sq_completion_data *c; ++ struct urb *urb; ++ struct bce_vhci_transfer_queue *q = sq->userdata; ++ spin_lock_irqsave(&q->urb_lock, flags); ++ while ((c = bce_next_completion(sq))) { ++ if (c->status == BCE_COMPLETION_ABORTED) { /* We flushed the queue */ ++ pr_debug("bce-vhci: [%02x] Got an abort completion\n", q->endp_addr); ++ bce_notify_submission_complete(sq); ++ continue; ++ } ++ if (list_empty(&q->endp->urb_list)) { ++ pr_err("bce-vhci: [%02x] Got a completion while no requests are pending\n", q->endp_addr); ++ continue; ++ } ++ pr_debug("bce-vhci: [%02x] Got a transfer queue completion\n", q->endp_addr); ++ urb = list_first_entry(&q->endp->urb_list, struct urb, urb_list); ++ bce_vhci_urb_transfer_completion(urb->hcpriv, c); ++ bce_notify_submission_complete(sq); ++ } ++ bce_vhci_transfer_queue_deliver_pending(q); ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++ bce_vhci_transfer_queue_giveback(q); ++} ++ ++int bce_vhci_transfer_queue_do_pause(struct bce_vhci_transfer_queue *q) ++{ ++ unsigned long flags; ++ int status; ++ u8 endp_addr = (u8) (q->endp->desc.bEndpointAddress & 0x8F); ++ spin_lock_irqsave(&q->urb_lock, flags); ++ q->active = false; ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++ if (q->sq_out) { ++ pr_err("bce-vhci: Not 
implemented: wait for pending output requests\n"); ++ } ++ bce_vhci_transfer_queue_remove_pending(q); ++ if ((status = bce_vhci_cmd_endpoint_set_state( ++ &q->vhci->cq, q->dev_addr, endp_addr, BCE_VHCI_ENDPOINT_PAUSED, &q->state))) ++ return status; ++ if (q->state != BCE_VHCI_ENDPOINT_PAUSED) ++ return -EINVAL; ++ if (q->sq_in) ++ bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, (u16) q->sq_in->qid); ++ if (q->sq_out) ++ bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, (u16) q->sq_out->qid); ++ return 0; ++} ++ ++static void bce_vhci_urb_resume(struct bce_vhci_urb *urb); ++ ++int bce_vhci_transfer_queue_do_resume(struct bce_vhci_transfer_queue *q) ++{ ++ unsigned long flags; ++ int status; ++ struct urb *urb, *urbt; ++ struct bce_vhci_urb *vurb; ++ u8 endp_addr = (u8) (q->endp->desc.bEndpointAddress & 0x8F); ++ if ((status = bce_vhci_cmd_endpoint_set_state( ++ &q->vhci->cq, q->dev_addr, endp_addr, BCE_VHCI_ENDPOINT_ACTIVE, &q->state))) ++ return status; ++ if (q->state != BCE_VHCI_ENDPOINT_ACTIVE) ++ return -EINVAL; ++ spin_lock_irqsave(&q->urb_lock, flags); ++ q->active = true; ++ list_for_each_entry_safe(urb, urbt, &q->endp->urb_list, urb_list) { ++ vurb = urb->hcpriv; ++ if (vurb->state == BCE_VHCI_URB_INIT_PENDING) { ++ if (!bce_vhci_transfer_queue_can_init_urb(q)) ++ break; ++ bce_vhci_urb_init(vurb); ++ } else { ++ bce_vhci_urb_resume(vurb); ++ } ++ } ++ bce_vhci_transfer_queue_deliver_pending(q); ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++ return 0; ++} ++ ++int bce_vhci_transfer_queue_pause(struct bce_vhci_transfer_queue *q, enum bce_vhci_pause_source src) ++{ ++ int ret = 0; ++ mutex_lock(&q->pause_lock); ++ if ((q->paused_by & src) != src) { ++ if (!q->paused_by) ++ ret = bce_vhci_transfer_queue_do_pause(q); ++ if (!ret) ++ q->paused_by |= src; ++ } ++ mutex_unlock(&q->pause_lock); ++ return ret; ++} ++ ++int bce_vhci_transfer_queue_resume(struct bce_vhci_transfer_queue *q, enum bce_vhci_pause_source src) ++{ ++ int ret = 0; ++ 
mutex_lock(&q->pause_lock); ++ if (q->paused_by & src) { ++ if (!(q->paused_by & ~src)) ++ ret = bce_vhci_transfer_queue_do_resume(q); ++ if (!ret) ++ q->paused_by &= ~src; ++ } ++ mutex_unlock(&q->pause_lock); ++ return ret; ++} ++ ++static void bce_vhci_transfer_queue_reset_w(struct work_struct *work) ++{ ++ unsigned long flags; ++ struct bce_vhci_transfer_queue *q = container_of(work, struct bce_vhci_transfer_queue, w_reset); ++ ++ mutex_lock(&q->pause_lock); ++ spin_lock_irqsave(&q->urb_lock, flags); ++ if (!q->stalled) { ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++ mutex_unlock(&q->pause_lock); ++ return; ++ } ++ q->active = false; ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++ q->paused_by |= BCE_VHCI_PAUSE_INTERNAL_WQ; ++ bce_vhci_transfer_queue_remove_pending(q); ++ if (q->sq_in) ++ bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, (u16) q->sq_in->qid); ++ if (q->sq_out) ++ bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, (u16) q->sq_out->qid); ++ bce_vhci_cmd_endpoint_reset(&q->vhci->cq, q->dev_addr, (u8) (q->endp->desc.bEndpointAddress & 0x8F)); ++ spin_lock_irqsave(&q->urb_lock, flags); ++ q->stalled = false; ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++ mutex_unlock(&q->pause_lock); ++ bce_vhci_transfer_queue_resume(q, BCE_VHCI_PAUSE_INTERNAL_WQ); ++} ++ ++void bce_vhci_transfer_queue_request_reset(struct bce_vhci_transfer_queue *q) ++{ ++ queue_work(q->vhci->tq_state_wq, &q->w_reset); ++} ++ ++static void bce_vhci_transfer_queue_init_pending_urbs(struct bce_vhci_transfer_queue *q) ++{ ++ struct urb *urb, *urbt; ++ struct bce_vhci_urb *vurb; ++ list_for_each_entry_safe(urb, urbt, &q->endp->urb_list, urb_list) { ++ vurb = urb->hcpriv; ++ if (!bce_vhci_transfer_queue_can_init_urb(q)) ++ break; ++ if (vurb->state == BCE_VHCI_URB_INIT_PENDING) ++ bce_vhci_urb_init(vurb); ++ } ++} ++ ++ ++ ++static int bce_vhci_urb_data_start(struct bce_vhci_urb *urb, unsigned long *timeout); ++ ++int bce_vhci_urb_create(struct bce_vhci_transfer_queue *q, 
struct urb *urb) ++{ ++ unsigned long flags; ++ int status = 0; ++ struct bce_vhci_urb *vurb; ++ vurb = kzalloc(sizeof(struct bce_vhci_urb), GFP_KERNEL); ++ urb->hcpriv = vurb; ++ ++ vurb->q = q; ++ vurb->urb = urb; ++ vurb->dir = usb_urb_dir_in(urb) ? DMA_FROM_DEVICE : DMA_TO_DEVICE; ++ vurb->is_control = (usb_endpoint_num(&urb->ep->desc) == 0); ++ ++ spin_lock_irqsave(&q->urb_lock, flags); ++ status = usb_hcd_link_urb_to_ep(q->vhci->hcd, urb); ++ if (status) { ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++ urb->hcpriv = NULL; ++ kfree(vurb); ++ return status; ++ } ++ ++ if (q->active) { ++ if (bce_vhci_transfer_queue_can_init_urb(vurb->q)) ++ status = bce_vhci_urb_init(vurb); ++ else ++ vurb->state = BCE_VHCI_URB_INIT_PENDING; ++ } else { ++ if (q->stalled) ++ bce_vhci_transfer_queue_request_reset(q); ++ vurb->state = BCE_VHCI_URB_INIT_PENDING; ++ } ++ if (status) { ++ usb_hcd_unlink_urb_from_ep(q->vhci->hcd, urb); ++ urb->hcpriv = NULL; ++ kfree(vurb); ++ } else { ++ bce_vhci_transfer_queue_deliver_pending(q); ++ } ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++ pr_debug("bce-vhci: [%02x] URB enqueued (dir = %s, size = %i)\n", q->endp_addr, ++ usb_urb_dir_in(urb) ? 
"IN" : "OUT", urb->transfer_buffer_length); ++ return status; ++} ++ ++static int bce_vhci_urb_init(struct bce_vhci_urb *vurb) ++{ ++ int status = 0; ++ ++ if (vurb->q->remaining_active_requests == 0) { ++ pr_err("bce-vhci: cannot init request (remaining_active_requests = 0)\n"); ++ return -EINVAL; ++ } ++ ++ if (vurb->is_control) { ++ vurb->state = BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_REQUEST; ++ } else { ++ status = bce_vhci_urb_data_start(vurb, NULL); ++ } ++ ++ if (!status) { ++ --vurb->q->remaining_active_requests; ++ } ++ return status; ++} ++ ++static void bce_vhci_urb_complete(struct bce_vhci_urb *urb, int status) ++{ ++ struct bce_vhci_transfer_queue *q = urb->q; ++ struct bce_vhci *vhci = q->vhci; ++ struct urb *real_urb = urb->urb; ++ pr_debug("bce-vhci: [%02x] URB complete %i\n", q->endp_addr, status); ++ usb_hcd_unlink_urb_from_ep(vhci->hcd, real_urb); ++ real_urb->hcpriv = NULL; ++ real_urb->status = status; ++ if (urb->state != BCE_VHCI_URB_INIT_PENDING) ++ ++urb->q->remaining_active_requests; ++ kfree(urb); ++ list_add_tail(&real_urb->urb_list, &q->giveback_urb_list); ++} ++ ++int bce_vhci_urb_request_cancel(struct bce_vhci_transfer_queue *q, struct urb *urb, int status) ++{ ++ struct bce_vhci_urb *vurb; ++ unsigned long flags; ++ int ret; ++ ++ spin_lock_irqsave(&q->urb_lock, flags); ++ if ((ret = usb_hcd_check_unlink_urb(q->vhci->hcd, urb, status))) { ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++ return ret; ++ } ++ ++ vurb = urb->hcpriv; ++ /* If the URB wasn't posted to the device yet, we can still remove it on the host without pausing the queue. 
*/ ++ if (vurb->state != BCE_VHCI_URB_INIT_PENDING) { ++ pr_debug("bce-vhci: [%02x] Cancelling URB\n", q->endp_addr); ++ ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++ bce_vhci_transfer_queue_pause(q, BCE_VHCI_PAUSE_INTERNAL_WQ); ++ spin_lock_irqsave(&q->urb_lock, flags); ++ ++ ++q->remaining_active_requests; ++ } ++ ++ usb_hcd_unlink_urb_from_ep(q->vhci->hcd, urb); ++ ++ spin_unlock_irqrestore(&q->urb_lock, flags); ++ ++ usb_hcd_giveback_urb(q->vhci->hcd, urb, status); ++ ++ if (vurb->state != BCE_VHCI_URB_INIT_PENDING) ++ bce_vhci_transfer_queue_resume(q, BCE_VHCI_PAUSE_INTERNAL_WQ); ++ ++ kfree(vurb); ++ ++ return 0; ++} ++ ++static int bce_vhci_urb_data_transfer_in(struct bce_vhci_urb *urb, unsigned long *timeout) ++{ ++ struct bce_vhci_message msg; ++ struct bce_qe_submission *s; ++ u32 tr_len; ++ int reservation1, reservation2 = -EFAULT; ++ ++ pr_debug("bce-vhci: [%02x] DMA from device %llx %x\n", urb->q->endp_addr, ++ (u64) urb->urb->transfer_dma, urb->urb->transfer_buffer_length); ++ ++ /* Reserve both a message and a submission, so we don't run into issues later. 
*/ ++ reservation1 = bce_reserve_submission(urb->q->vhci->msg_asynchronous.sq, timeout); ++ if (!reservation1) ++ reservation2 = bce_reserve_submission(urb->q->sq_in, timeout); ++ if (reservation1 || reservation2) { ++ pr_err("bce-vhci: Failed to reserve a submission for URB data transfer\n"); ++ if (!reservation1) ++ bce_cancel_submission_reservation(urb->q->vhci->msg_asynchronous.sq); ++ return -ENOMEM; ++ } ++ ++ urb->send_offset = urb->receive_offset; ++ ++ tr_len = urb->urb->transfer_buffer_length - urb->send_offset; ++ ++ spin_lock(&urb->q->vhci->msg_asynchronous_lock); ++ msg.cmd = BCE_VHCI_CMD_TRANSFER_REQUEST; ++ msg.status = 0; ++ msg.param1 = ((urb->urb->ep->desc.bEndpointAddress & 0x8Fu) << 8) | urb->q->dev_addr; ++ msg.param2 = tr_len; ++ bce_vhci_message_queue_write(&urb->q->vhci->msg_asynchronous, &msg); ++ spin_unlock(&urb->q->vhci->msg_asynchronous_lock); ++ ++ s = bce_next_submission(urb->q->sq_in); ++ bce_set_submission_single(s, urb->urb->transfer_dma + urb->send_offset, tr_len); ++ bce_submit_to_device(urb->q->sq_in); ++ ++ urb->state = BCE_VHCI_URB_WAITING_FOR_COMPLETION; ++ return 0; ++} ++ ++static int bce_vhci_urb_data_start(struct bce_vhci_urb *urb, unsigned long *timeout) ++{ ++ if (urb->dir == DMA_TO_DEVICE) { ++ if (urb->urb->transfer_buffer_length > 0) ++ urb->state = BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST; ++ else ++ urb->state = BCE_VHCI_URB_DATA_TRANSFER_COMPLETE; ++ return 0; ++ } else { ++ return bce_vhci_urb_data_transfer_in(urb, timeout); ++ } ++} ++ ++static int bce_vhci_urb_send_out_data(struct bce_vhci_urb *urb, dma_addr_t addr, size_t size) ++{ ++ struct bce_qe_submission *s; ++ unsigned long timeout = 0; ++ if (bce_reserve_submission(urb->q->sq_out, &timeout)) { ++ pr_err("bce-vhci: Failed to reserve a submission for URB data transfer\n"); ++ return -EPIPE; ++ } ++ ++ pr_debug("bce-vhci: [%02x] DMA to device %llx %lx\n", urb->q->endp_addr, (u64) addr, size); ++ ++ s = bce_next_submission(urb->q->sq_out); ++ 
bce_set_submission_single(s, addr, size); ++ bce_submit_to_device(urb->q->sq_out); ++ return 0; ++} ++ ++static int bce_vhci_urb_data_update(struct bce_vhci_urb *urb, struct bce_vhci_message *msg) ++{ ++ u32 tr_len; ++ int status; ++ if (urb->state == BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST) { ++ if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST) { ++ tr_len = min(urb->urb->transfer_buffer_length - urb->send_offset, (u32) msg->param2); ++ if ((status = bce_vhci_urb_send_out_data(urb, urb->urb->transfer_dma + urb->send_offset, tr_len))) ++ return status; ++ urb->send_offset += tr_len; ++ urb->state = BCE_VHCI_URB_WAITING_FOR_COMPLETION; ++ return 0; ++ } ++ } ++ ++ /* 0x1000 in out queues aren't really unexpected */ ++ if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST && urb->q->sq_out != NULL) ++ return -EAGAIN; ++ pr_err("bce-vhci: [%02x] %s URB unexpected message (state = %x, msg: %x %x %x %llx)\n", ++ urb->q->endp_addr, (urb->is_control ? "Control (data update)" : "Data"), urb->state, ++ msg->cmd, msg->status, msg->param1, msg->param2); ++ return -EAGAIN; ++} ++ ++static int bce_vhci_urb_data_transfer_completion(struct bce_vhci_urb *urb, struct bce_sq_completion_data *c) ++{ ++ if (urb->state == BCE_VHCI_URB_WAITING_FOR_COMPLETION) { ++ urb->receive_offset += c->data_size; ++ if (urb->dir == DMA_FROM_DEVICE || urb->receive_offset >= urb->urb->transfer_buffer_length) { ++ urb->urb->actual_length = (u32) urb->receive_offset; ++ urb->state = BCE_VHCI_URB_DATA_TRANSFER_COMPLETE; ++ if (!urb->is_control) { ++ bce_vhci_urb_complete(urb, 0); ++ return -ENOENT; ++ } ++ } ++ } else { ++ pr_err("bce-vhci: [%02x] Data URB unexpected completion\n", urb->q->endp_addr); ++ } ++ return 0; ++} ++ ++ ++static int bce_vhci_urb_control_check_status(struct bce_vhci_urb *urb) ++{ ++ struct bce_vhci_transfer_queue *q = urb->q; ++ if (urb->received_status == 0) ++ return 0; ++ if (urb->state == BCE_VHCI_URB_DATA_TRANSFER_COMPLETE || ++ (urb->received_status != BCE_VHCI_SUCCESS && urb->state 
!= BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_REQUEST && ++ urb->state != BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_COMPLETION)) { ++ urb->state = BCE_VHCI_URB_CONTROL_COMPLETE; ++ if (urb->received_status != BCE_VHCI_SUCCESS) { ++ pr_err("bce-vhci: [%02x] URB failed: %x\n", urb->q->endp_addr, urb->received_status); ++ urb->q->active = false; ++ urb->q->stalled = true; ++ bce_vhci_urb_complete(urb, -EPIPE); ++ if (!list_empty(&q->endp->urb_list)) ++ bce_vhci_transfer_queue_request_reset(q); ++ return -ENOENT; ++ } ++ bce_vhci_urb_complete(urb, 0); ++ return -ENOENT; ++ } ++ return 0; ++} ++ ++static int bce_vhci_urb_control_update(struct bce_vhci_urb *urb, struct bce_vhci_message *msg) ++{ ++ int status; ++ if (msg->cmd == BCE_VHCI_CMD_CONTROL_TRANSFER_STATUS) { ++ urb->received_status = msg->status; ++ return bce_vhci_urb_control_check_status(urb); ++ } ++ ++ if (urb->state == BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_REQUEST) { ++ if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST) { ++ if (bce_vhci_urb_send_out_data(urb, urb->urb->setup_dma, sizeof(struct usb_ctrlrequest))) { ++ pr_err("bce-vhci: [%02x] Failed to start URB setup transfer\n", urb->q->endp_addr); ++ return 0; /* TODO: fail the URB? 
*/ ++ } ++ urb->state = BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_COMPLETION; ++ pr_debug("bce-vhci: [%02x] Sent setup %llx\n", urb->q->endp_addr, urb->urb->setup_dma); ++ return 0; ++ } ++ } else if (urb->state == BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST || ++ urb->state == BCE_VHCI_URB_WAITING_FOR_COMPLETION) { ++ if ((status = bce_vhci_urb_data_update(urb, msg))) ++ return status; ++ return bce_vhci_urb_control_check_status(urb); ++ } ++ ++ /* 0x1000 in out queues aren't really unexpected */ ++ if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST && urb->q->sq_out != NULL) ++ return -EAGAIN; ++ pr_err("bce-vhci: [%02x] Control URB unexpected message (state = %x, msg: %x %x %x %llx)\n", urb->q->endp_addr, ++ urb->state, msg->cmd, msg->status, msg->param1, msg->param2); ++ return -EAGAIN; ++} ++ ++static int bce_vhci_urb_control_transfer_completion(struct bce_vhci_urb *urb, struct bce_sq_completion_data *c) ++{ ++ int status; ++ unsigned long timeout; ++ ++ if (urb->state == BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_COMPLETION) { ++ if (c->data_size != sizeof(struct usb_ctrlrequest)) ++ pr_err("bce-vhci: [%02x] transfer complete data size mistmatch for usb_ctrlrequest (%llx instead of %lx)\n", ++ urb->q->endp_addr, c->data_size, sizeof(struct usb_ctrlrequest)); ++ ++ timeout = 1000; ++ status = bce_vhci_urb_data_start(urb, &timeout); ++ if (status) { ++ bce_vhci_urb_complete(urb, status); ++ return -ENOENT; ++ } ++ return 0; ++ } else if (urb->state == BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST || ++ urb->state == BCE_VHCI_URB_WAITING_FOR_COMPLETION) { ++ if ((status = bce_vhci_urb_data_transfer_completion(urb, c))) ++ return status; ++ return bce_vhci_urb_control_check_status(urb); ++ } else { ++ pr_err("bce-vhci: [%02x] Control URB unexpected completion (state = %x)\n", urb->q->endp_addr, urb->state); ++ } ++ return 0; ++} ++ ++static int bce_vhci_urb_update(struct bce_vhci_urb *urb, struct bce_vhci_message *msg) ++{ ++ if (urb->state == BCE_VHCI_URB_INIT_PENDING) ++ return 
-EAGAIN; ++ if (urb->is_control) ++ return bce_vhci_urb_control_update(urb, msg); ++ else ++ return bce_vhci_urb_data_update(urb, msg); ++} ++ ++static int bce_vhci_urb_transfer_completion(struct bce_vhci_urb *urb, struct bce_sq_completion_data *c) ++{ ++ if (urb->is_control) ++ return bce_vhci_urb_control_transfer_completion(urb, c); ++ else ++ return bce_vhci_urb_data_transfer_completion(urb, c); ++} ++ ++static void bce_vhci_urb_resume(struct bce_vhci_urb *urb) ++{ ++ int status = 0; ++ if (urb->state == BCE_VHCI_URB_WAITING_FOR_COMPLETION) { ++ status = bce_vhci_urb_data_transfer_in(urb, NULL); ++ } ++ if (status) ++ bce_vhci_urb_complete(urb, status); ++} +diff --git a/drivers/staging/apple-bce/vhci/transfer.h b/drivers/staging/apple-bce/vhci/transfer.h +new file mode 100644 +index 000000000000..89ecad6bcf8f +--- /dev/null ++++ b/drivers/staging/apple-bce/vhci/transfer.h +@@ -0,0 +1,73 @@ ++#ifndef BCEDRIVER_TRANSFER_H ++#define BCEDRIVER_TRANSFER_H ++ ++#include ++#include "queue.h" ++#include "command.h" ++#include "../queue.h" ++ ++struct bce_vhci_list_message { ++ struct list_head list; ++ struct bce_vhci_message msg; ++}; ++enum bce_vhci_pause_source { ++ BCE_VHCI_PAUSE_INTERNAL_WQ = 1, ++ BCE_VHCI_PAUSE_FIRMWARE = 2, ++ BCE_VHCI_PAUSE_SUSPEND = 4, ++ BCE_VHCI_PAUSE_SHUTDOWN = 8 ++}; ++struct bce_vhci_transfer_queue { ++ struct bce_vhci *vhci; ++ struct usb_host_endpoint *endp; ++ enum bce_vhci_endpoint_state state; ++ u32 max_active_requests, remaining_active_requests; ++ bool active, stalled; ++ u32 paused_by; ++ bce_vhci_device_t dev_addr; ++ u8 endp_addr; ++ struct bce_queue_cq *cq; ++ struct bce_queue_sq *sq_in; ++ struct bce_queue_sq *sq_out; ++ struct list_head evq; ++ struct spinlock urb_lock; ++ struct mutex pause_lock; ++ struct list_head giveback_urb_list; ++ ++ struct work_struct w_reset; ++}; ++enum bce_vhci_urb_state { ++ BCE_VHCI_URB_INIT_PENDING, ++ ++ BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST, ++ BCE_VHCI_URB_WAITING_FOR_COMPLETION, ++ 
BCE_VHCI_URB_DATA_TRANSFER_COMPLETE, ++ ++ BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_REQUEST, ++ BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_COMPLETION, ++ BCE_VHCI_URB_CONTROL_COMPLETE ++}; ++struct bce_vhci_urb { ++ struct urb *urb; ++ struct bce_vhci_transfer_queue *q; ++ enum dma_data_direction dir; ++ bool is_control; ++ enum bce_vhci_urb_state state; ++ int received_status; ++ u32 send_offset; ++ u32 receive_offset; ++}; ++ ++void bce_vhci_create_transfer_queue(struct bce_vhci *vhci, struct bce_vhci_transfer_queue *q, ++ struct usb_host_endpoint *endp, bce_vhci_device_t dev_addr, enum dma_data_direction dir); ++void bce_vhci_destroy_transfer_queue(struct bce_vhci *vhci, struct bce_vhci_transfer_queue *q); ++void bce_vhci_transfer_queue_event(struct bce_vhci_transfer_queue *q, struct bce_vhci_message *msg); ++int bce_vhci_transfer_queue_do_pause(struct bce_vhci_transfer_queue *q); ++int bce_vhci_transfer_queue_do_resume(struct bce_vhci_transfer_queue *q); ++int bce_vhci_transfer_queue_pause(struct bce_vhci_transfer_queue *q, enum bce_vhci_pause_source src); ++int bce_vhci_transfer_queue_resume(struct bce_vhci_transfer_queue *q, enum bce_vhci_pause_source src); ++void bce_vhci_transfer_queue_request_reset(struct bce_vhci_transfer_queue *q); ++ ++int bce_vhci_urb_create(struct bce_vhci_transfer_queue *q, struct urb *urb); ++int bce_vhci_urb_request_cancel(struct bce_vhci_transfer_queue *q, struct urb *urb, int status); ++ ++#endif //BCEDRIVER_TRANSFER_H +diff --git a/drivers/staging/apple-bce/vhci/vhci.c b/drivers/staging/apple-bce/vhci/vhci.c +new file mode 100644 +index 000000000000..eb26f55000d8 +--- /dev/null ++++ b/drivers/staging/apple-bce/vhci/vhci.c +@@ -0,0 +1,759 @@ ++#include "vhci.h" ++#include "../apple_bce.h" ++#include "command.h" ++#include ++#include ++#include ++#include ++ ++static dev_t bce_vhci_chrdev; ++static struct class *bce_vhci_class; ++static const struct hc_driver bce_vhci_driver; ++static u16 bce_vhci_port_mask = U16_MAX; ++ ++static int 
bce_vhci_create_event_queues(struct bce_vhci *vhci); ++static void bce_vhci_destroy_event_queues(struct bce_vhci *vhci); ++static int bce_vhci_create_message_queues(struct bce_vhci *vhci); ++static void bce_vhci_destroy_message_queues(struct bce_vhci *vhci); ++static void bce_vhci_handle_firmware_events_w(struct work_struct *ws); ++static void bce_vhci_firmware_event_completion(struct bce_queue_sq *sq); ++ ++int bce_vhci_create(struct apple_bce_device *dev, struct bce_vhci *vhci) ++{ ++ int status; ++ ++ spin_lock_init(&vhci->hcd_spinlock); ++ ++ vhci->dev = dev; ++ ++ vhci->vdevt = bce_vhci_chrdev; ++ vhci->vdev = device_create(bce_vhci_class, dev->dev, vhci->vdevt, NULL, "bce-vhci"); ++ if (IS_ERR_OR_NULL(vhci->vdev)) { ++ status = PTR_ERR(vhci->vdev); ++ goto fail_dev; ++ } ++ ++ if ((status = bce_vhci_create_message_queues(vhci))) ++ goto fail_mq; ++ if ((status = bce_vhci_create_event_queues(vhci))) ++ goto fail_eq; ++ ++ vhci->tq_state_wq = alloc_ordered_workqueue("bce-vhci-tq-state", 0); ++ INIT_WORK(&vhci->w_fw_events, bce_vhci_handle_firmware_events_w); ++ ++ vhci->hcd = usb_create_hcd(&bce_vhci_driver, vhci->vdev, "bce-vhci"); ++ if (!vhci->hcd) { ++ status = -ENOMEM; ++ goto fail_hcd; ++ } ++ vhci->hcd->self.sysdev = &dev->pci->dev; ++#if LINUX_VERSION_CODE < KERNEL_VERSION(5,4,0) ++ vhci->hcd->self.uses_dma = 1; ++#endif ++ *((struct bce_vhci **) vhci->hcd->hcd_priv) = vhci; ++ vhci->hcd->speed = HCD_USB2; ++ ++ if ((status = usb_add_hcd(vhci->hcd, 0, 0))) ++ goto fail_hcd; ++ ++ return 0; ++ ++fail_hcd: ++ bce_vhci_destroy_event_queues(vhci); ++fail_eq: ++ bce_vhci_destroy_message_queues(vhci); ++fail_mq: ++ device_destroy(bce_vhci_class, vhci->vdevt); ++fail_dev: ++ if (!status) ++ status = -EINVAL; ++ return status; ++} ++ ++void bce_vhci_destroy(struct bce_vhci *vhci) ++{ ++ usb_remove_hcd(vhci->hcd); ++ bce_vhci_destroy_event_queues(vhci); ++ bce_vhci_destroy_message_queues(vhci); ++ device_destroy(bce_vhci_class, vhci->vdevt); ++} ++ ++struct 
bce_vhci *bce_vhci_from_hcd(struct usb_hcd *hcd) ++{ ++ return *((struct bce_vhci **) hcd->hcd_priv); ++} ++ ++int bce_vhci_start(struct usb_hcd *hcd) ++{ ++ struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); ++ int status; ++ u16 port_mask = 0; ++ bce_vhci_port_t port_no = 0; ++ if ((status = bce_vhci_cmd_controller_enable(&vhci->cq, 1, &port_mask))) ++ return status; ++ vhci->port_mask = port_mask; ++ vhci->port_power_mask = 0; ++ if ((status = bce_vhci_cmd_controller_start(&vhci->cq))) ++ return status; ++ port_mask = vhci->port_mask; ++ while (port_mask) { ++ port_no += 1; ++ port_mask >>= 1; ++ } ++ vhci->port_count = port_no; ++ return 0; ++} ++ ++void bce_vhci_stop(struct usb_hcd *hcd) ++{ ++ struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); ++ bce_vhci_cmd_controller_disable(&vhci->cq); ++} ++ ++static int bce_vhci_hub_status_data(struct usb_hcd *hcd, char *buf) ++{ ++ return 0; ++} ++ ++static int bce_vhci_reset_device(struct bce_vhci *vhci, int index, u16 timeout); ++ ++static int bce_vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue, u16 wIndex, char *buf, u16 wLength) ++{ ++ struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); ++ int status; ++ struct usb_hub_descriptor *hd; ++ struct usb_hub_status *hs; ++ struct usb_port_status *ps; ++ u32 port_status; ++ // pr_info("bce-vhci: bce_vhci_hub_control %x %i %i [bufl=%i]\n", typeReq, wValue, wIndex, wLength); ++ if (typeReq == GetHubDescriptor && wLength >= sizeof(struct usb_hub_descriptor)) { ++ hd = (struct usb_hub_descriptor *) buf; ++ memset(hd, 0, sizeof(*hd)); ++ hd->bDescLength = sizeof(struct usb_hub_descriptor); ++ hd->bDescriptorType = USB_DT_HUB; ++ hd->bNbrPorts = (u8) vhci->port_count; ++ hd->wHubCharacteristics = HUB_CHAR_INDV_PORT_LPSM | HUB_CHAR_INDV_PORT_OCPM; ++ hd->bPwrOn2PwrGood = 0; ++ hd->bHubContrCurrent = 0; ++ return 0; ++ } else if (typeReq == GetHubStatus && wLength >= sizeof(struct usb_hub_status)) { ++ hs = (struct usb_hub_status *) buf; ++ memset(hs, 0, sizeof(*hs)); ++ 
hs->wHubStatus = 0; ++ hs->wHubChange = 0; ++ return 0; ++ } else if (typeReq == GetPortStatus && wLength >= 4 /* usb 2.0 */) { ++ ps = (struct usb_port_status *) buf; ++ ps->wPortStatus = 0; ++ ps->wPortChange = 0; ++ ++ if (vhci->port_power_mask & BIT(wIndex)) ++ ps->wPortStatus |= USB_PORT_STAT_POWER; ++ ++ if (!(bce_vhci_port_mask & BIT(wIndex))) ++ return 0; ++ ++ if ((status = bce_vhci_cmd_port_status(&vhci->cq, (u8) wIndex, 0, &port_status))) ++ return status; ++ ++ if (port_status & 16) ++ ps->wPortStatus |= USB_PORT_STAT_ENABLE | USB_PORT_STAT_HIGH_SPEED; ++ if (port_status & 4) ++ ps->wPortStatus |= USB_PORT_STAT_CONNECTION; ++ if (port_status & 2) ++ ps->wPortStatus |= USB_PORT_STAT_OVERCURRENT; ++ if (port_status & 8) ++ ps->wPortStatus |= USB_PORT_STAT_RESET; ++ if (port_status & 0x60) ++ ps->wPortStatus |= USB_PORT_STAT_SUSPEND; ++ ++ if (port_status & 0x40000) ++ ps->wPortChange |= USB_PORT_STAT_C_CONNECTION; ++ ++ pr_debug("bce-vhci: Translated status %x to %x:%x\n", port_status, ps->wPortStatus, ps->wPortChange); ++ return 0; ++ } else if (typeReq == SetPortFeature) { ++ if (wValue == USB_PORT_FEAT_POWER) { ++ status = bce_vhci_cmd_port_power_on(&vhci->cq, (u8) wIndex); ++ /* As far as I am aware, power status is not part of the port status so store it separately */ ++ if (!status) ++ vhci->port_power_mask |= BIT(wIndex); ++ return status; ++ } ++ if (wValue == USB_PORT_FEAT_RESET) { ++ return bce_vhci_reset_device(vhci, wIndex, wValue); ++ } ++ if (wValue == USB_PORT_FEAT_SUSPEND) { ++ /* TODO: Am I supposed to also suspend the endpoints? 
*/ ++ pr_debug("bce-vhci: Suspending port %i\n", wIndex); ++ return bce_vhci_cmd_port_suspend(&vhci->cq, (u8) wIndex); ++ } ++ } else if (typeReq == ClearPortFeature) { ++ if (wValue == USB_PORT_FEAT_ENABLE) ++ return bce_vhci_cmd_port_disable(&vhci->cq, (u8) wIndex); ++ if (wValue == USB_PORT_FEAT_POWER) { ++ status = bce_vhci_cmd_port_power_off(&vhci->cq, (u8) wIndex); ++ if (!status) ++ vhci->port_power_mask &= ~BIT(wIndex); ++ return status; ++ } ++ if (wValue == USB_PORT_FEAT_C_CONNECTION) ++ return bce_vhci_cmd_port_status(&vhci->cq, (u8) wIndex, 0x40000, &port_status); ++ if (wValue == USB_PORT_FEAT_C_RESET) { /* I don't think I can transfer it in any way */ ++ return 0; ++ } ++ if (wValue == USB_PORT_FEAT_SUSPEND) { ++ pr_debug("bce-vhci: Resuming port %i\n", wIndex); ++ return bce_vhci_cmd_port_resume(&vhci->cq, (u8) wIndex); ++ } ++ } ++ pr_err("bce-vhci: bce_vhci_hub_control unhandled request: %x %i %i [bufl=%i]\n", typeReq, wValue, wIndex, wLength); ++ dump_stack(); ++ return -EIO; ++} ++ ++static int bce_vhci_enable_device(struct usb_hcd *hcd, struct usb_device *udev) ++{ ++ struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); ++ struct bce_vhci_device *vdev; ++ bce_vhci_device_t devid; ++ pr_info("bce_vhci_enable_device\n"); ++ ++ if (vhci->port_to_device[udev->portnum]) ++ return 0; ++ ++ /* We need to early address the device */ ++ if (bce_vhci_cmd_device_create(&vhci->cq, udev->portnum, &devid)) ++ return -EIO; ++ ++ pr_info("bce_vhci_cmd_device_create %i -> %i\n", udev->portnum, devid); ++ ++ vdev = kzalloc(sizeof(struct bce_vhci_device), GFP_KERNEL); ++ vhci->port_to_device[udev->portnum] = devid; ++ vhci->devices[devid] = vdev; ++ ++ bce_vhci_create_transfer_queue(vhci, &vdev->tq[0], &udev->ep0, devid, DMA_BIDIRECTIONAL); ++ udev->ep0.hcpriv = &vdev->tq[0]; ++ vdev->tq_mask |= BIT(0); ++ ++ bce_vhci_cmd_endpoint_create(&vhci->cq, devid, &udev->ep0.desc); ++ return 0; ++} ++ ++static int bce_vhci_address_device(struct usb_hcd *hcd, struct usb_device 
*udev, unsigned int timeout_ms) //TODO: follow timeout ++{ ++ /* This is the same as enable_device, but instead in the old scheme */ ++ return bce_vhci_enable_device(hcd, udev); ++} ++ ++static void bce_vhci_free_device(struct usb_hcd *hcd, struct usb_device *udev) ++{ ++ struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); ++ int i; ++ bce_vhci_device_t devid; ++ struct bce_vhci_device *dev; ++ pr_info("bce_vhci_free_device %i\n", udev->portnum); ++ if (!vhci->port_to_device[udev->portnum]) ++ return; ++ devid = vhci->port_to_device[udev->portnum]; ++ dev = vhci->devices[devid]; ++ for (i = 0; i < 32; i++) { ++ if (dev->tq_mask & BIT(i)) { ++ bce_vhci_transfer_queue_pause(&dev->tq[i], BCE_VHCI_PAUSE_SHUTDOWN); ++ bce_vhci_cmd_endpoint_destroy(&vhci->cq, devid, (u8) i); ++ bce_vhci_destroy_transfer_queue(vhci, &dev->tq[i]); ++ } ++ } ++ vhci->devices[devid] = NULL; ++ vhci->port_to_device[udev->portnum] = 0; ++ bce_vhci_cmd_device_destroy(&vhci->cq, devid); ++ kfree(dev); ++} ++ ++static int bce_vhci_reset_device(struct bce_vhci *vhci, int index, u16 timeout) ++{ ++ struct bce_vhci_device *dev = NULL; ++ bce_vhci_device_t devid; ++ int i; ++ int status; ++ enum dma_data_direction dir; ++ pr_info("bce_vhci_reset_device %i\n", index); ++ ++ devid = vhci->port_to_device[index]; ++ if (devid) { ++ dev = vhci->devices[devid]; ++ ++ for (i = 0; i < 32; i++) { ++ if (dev->tq_mask & BIT(i)) { ++ bce_vhci_transfer_queue_pause(&dev->tq[i], BCE_VHCI_PAUSE_SHUTDOWN); ++ bce_vhci_cmd_endpoint_destroy(&vhci->cq, devid, (u8) i); ++ bce_vhci_destroy_transfer_queue(vhci, &dev->tq[i]); ++ } ++ } ++ vhci->devices[devid] = NULL; ++ vhci->port_to_device[index] = 0; ++ bce_vhci_cmd_device_destroy(&vhci->cq, devid); ++ } ++ status = bce_vhci_cmd_port_reset(&vhci->cq, (u8) index, timeout); ++ ++ if (dev) { ++ if ((status = bce_vhci_cmd_device_create(&vhci->cq, index, &devid))) ++ return status; ++ vhci->devices[devid] = dev; ++ vhci->port_to_device[index] = devid; ++ ++ for (i = 0; i < 32; 
i++) { ++ if (dev->tq_mask & BIT(i)) { ++ dir = usb_endpoint_dir_in(&dev->tq[i].endp->desc) ? DMA_FROM_DEVICE : DMA_TO_DEVICE; ++ if (i == 0) ++ dir = DMA_BIDIRECTIONAL; ++ bce_vhci_create_transfer_queue(vhci, &dev->tq[i], dev->tq[i].endp, devid, dir); ++ bce_vhci_cmd_endpoint_create(&vhci->cq, devid, &dev->tq[i].endp->desc); ++ } ++ } ++ } ++ ++ return status; ++} ++ ++static int bce_vhci_check_bandwidth(struct usb_hcd *hcd, struct usb_device *udev) ++{ ++ return 0; ++} ++ ++static int bce_vhci_get_frame_number(struct usb_hcd *hcd) ++{ ++ return 0; ++} ++ ++static int bce_vhci_bus_suspend(struct usb_hcd *hcd) ++{ ++ int i, j; ++ int status; ++ struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); ++ pr_info("bce_vhci: suspend started\n"); ++ ++ pr_info("bce_vhci: suspend endpoints\n"); ++ for (i = 0; i < 16; i++) { ++ if (!vhci->port_to_device[i]) ++ continue; ++ for (j = 0; j < 32; j++) { ++ if (!(vhci->devices[vhci->port_to_device[i]]->tq_mask & BIT(j))) ++ continue; ++ bce_vhci_transfer_queue_pause(&vhci->devices[vhci->port_to_device[i]]->tq[j], ++ BCE_VHCI_PAUSE_SUSPEND); ++ } ++ } ++ ++ pr_info("bce_vhci: suspend ports\n"); ++ for (i = 0; i < 16; i++) { ++ if (!vhci->port_to_device[i]) ++ continue; ++ bce_vhci_cmd_port_suspend(&vhci->cq, i); ++ } ++ pr_info("bce_vhci: suspend controller\n"); ++ if ((status = bce_vhci_cmd_controller_pause(&vhci->cq))) ++ return status; ++ ++ bce_vhci_event_queue_pause(&vhci->ev_commands); ++ bce_vhci_event_queue_pause(&vhci->ev_system); ++ bce_vhci_event_queue_pause(&vhci->ev_isochronous); ++ bce_vhci_event_queue_pause(&vhci->ev_interrupt); ++ bce_vhci_event_queue_pause(&vhci->ev_asynchronous); ++ pr_info("bce_vhci: suspend done\n"); ++ return 0; ++} ++ ++static int bce_vhci_bus_resume(struct usb_hcd *hcd) ++{ ++ int i, j; ++ int status; ++ struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); ++ pr_info("bce_vhci: resume started\n"); ++ ++ bce_vhci_event_queue_resume(&vhci->ev_system); ++ 
bce_vhci_event_queue_resume(&vhci->ev_isochronous); ++ bce_vhci_event_queue_resume(&vhci->ev_interrupt); ++ bce_vhci_event_queue_resume(&vhci->ev_asynchronous); ++ bce_vhci_event_queue_resume(&vhci->ev_commands); ++ ++ pr_info("bce_vhci: resume controller\n"); ++ if ((status = bce_vhci_cmd_controller_start(&vhci->cq))) ++ return status; ++ ++ pr_info("bce_vhci: resume ports\n"); ++ for (i = 0; i < 16; i++) { ++ if (!vhci->port_to_device[i]) ++ continue; ++ bce_vhci_cmd_port_resume(&vhci->cq, i); ++ } ++ pr_info("bce_vhci: resume endpoints\n"); ++ for (i = 0; i < 16; i++) { ++ if (!vhci->port_to_device[i]) ++ continue; ++ for (j = 0; j < 32; j++) { ++ if (!(vhci->devices[vhci->port_to_device[i]]->tq_mask & BIT(j))) ++ continue; ++ bce_vhci_transfer_queue_resume(&vhci->devices[vhci->port_to_device[i]]->tq[j], ++ BCE_VHCI_PAUSE_SUSPEND); ++ } ++ } ++ ++ pr_info("bce_vhci: resume done\n"); ++ return 0; ++} ++ ++static int bce_vhci_urb_enqueue(struct usb_hcd *hcd, struct urb *urb, gfp_t mem_flags) ++{ ++ struct bce_vhci_transfer_queue *q = urb->ep->hcpriv; ++ if (!q) ++ return -ENOENT; /* must run before the pr_debug below reads q->dev_addr */ ++ pr_debug("bce_vhci_urb_enqueue %i:%x\n", q->dev_addr, urb->ep->desc.bEndpointAddress); ++ return bce_vhci_urb_create(q, urb); ++} ++ ++static int bce_vhci_urb_dequeue(struct usb_hcd *hcd, struct urb *urb, int status) ++{ ++ struct bce_vhci_transfer_queue *q = urb->ep->hcpriv; ++ pr_debug("bce_vhci_urb_dequeue %x\n", urb->ep->desc.bEndpointAddress); ++ return bce_vhci_urb_request_cancel(q, urb, status); ++} ++ ++static void bce_vhci_endpoint_reset(struct usb_hcd *hcd, struct usb_host_endpoint *ep) ++{ ++ struct bce_vhci_transfer_queue *q = ep->hcpriv; ++ pr_debug("bce_vhci_endpoint_reset\n"); ++ if (q) ++ bce_vhci_transfer_queue_request_reset(q); ++} ++ ++static u8 bce_vhci_endpoint_index(u8 addr) ++{ ++ if (addr & 0x80) /* direction bit set: use slots 0x10..0x1f */ ++ return (u8) (0x10 + (addr & 0xf)); ++ return (u8) (addr & 0xf); ++} ++ ++static int bce_vhci_add_endpoint(struct usb_hcd *hcd, struct usb_device *udev, struct
usb_host_endpoint *endp) ++{ ++ u8 endp_index = bce_vhci_endpoint_index(endp->desc.bEndpointAddress); ++ struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); ++ bce_vhci_device_t devid = vhci->port_to_device[udev->portnum]; ++ struct bce_vhci_device *vdev = vhci->devices[devid]; ++ pr_debug("bce_vhci_add_endpoint %x/%x:%x\n", udev->portnum, devid, endp_index); ++ ++ if (udev->bus->root_hub == udev) /* The USB hub */ ++ return 0; ++ if (vdev == NULL) ++ return -ENODEV; ++ if (vdev->tq_mask & BIT(endp_index)) { ++ endp->hcpriv = &vdev->tq[endp_index]; ++ return 0; ++ } ++ ++ bce_vhci_create_transfer_queue(vhci, &vdev->tq[endp_index], endp, devid, ++ usb_endpoint_dir_in(&endp->desc) ? DMA_FROM_DEVICE : DMA_TO_DEVICE); ++ endp->hcpriv = &vdev->tq[endp_index]; ++ vdev->tq_mask |= BIT(endp_index); ++ ++ bce_vhci_cmd_endpoint_create(&vhci->cq, devid, &endp->desc); ++ return 0; ++} ++ ++static int bce_vhci_drop_endpoint(struct usb_hcd *hcd, struct usb_device *udev, struct usb_host_endpoint *endp) ++{ ++ u8 endp_index = bce_vhci_endpoint_index(endp->desc.bEndpointAddress); ++ struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); ++ bce_vhci_device_t devid = vhci->port_to_device[udev->portnum]; ++ struct bce_vhci_transfer_queue *q = endp->hcpriv; ++ struct bce_vhci_device *vdev = vhci->devices[devid]; ++ pr_info("bce_vhci_drop_endpoint %x:%x\n", udev->portnum, endp_index); ++ if (!q) { ++ if (vdev && vdev->tq_mask & BIT(endp_index)) { ++ pr_err("something deleted the hcpriv?\n"); ++ q = &vdev->tq[endp_index]; ++ } else { ++ return 0; ++ } ++ } ++ ++ bce_vhci_cmd_endpoint_destroy(&vhci->cq, devid, (u8) (endp->desc.bEndpointAddress & 0x8Fu)); ++ vhci->devices[devid]->tq_mask &= ~BIT(endp_index); ++ bce_vhci_destroy_transfer_queue(vhci, q); ++ return 0; ++} ++ ++static int bce_vhci_create_message_queues(struct bce_vhci *vhci) ++{ ++ if (bce_vhci_message_queue_create(vhci, &vhci->msg_commands, "VHC1HostCommands") || ++ bce_vhci_message_queue_create(vhci, &vhci->msg_system, 
"VHC1HostSystemEvents") || ++ bce_vhci_message_queue_create(vhci, &vhci->msg_isochronous, "VHC1HostIsochronousEvents") || ++ bce_vhci_message_queue_create(vhci, &vhci->msg_interrupt, "VHC1HostInterruptEvents") || ++ bce_vhci_message_queue_create(vhci, &vhci->msg_asynchronous, "VHC1HostAsynchronousEvents")) { ++ bce_vhci_destroy_message_queues(vhci); ++ return -EINVAL; ++ } ++ spin_lock_init(&vhci->msg_asynchronous_lock); ++ bce_vhci_command_queue_create(&vhci->cq, &vhci->msg_commands); ++ return 0; ++} ++ ++static void bce_vhci_destroy_message_queues(struct bce_vhci *vhci) ++{ ++ bce_vhci_command_queue_destroy(&vhci->cq); ++ bce_vhci_message_queue_destroy(vhci, &vhci->msg_commands); ++ bce_vhci_message_queue_destroy(vhci, &vhci->msg_system); ++ bce_vhci_message_queue_destroy(vhci, &vhci->msg_isochronous); ++ bce_vhci_message_queue_destroy(vhci, &vhci->msg_interrupt); ++ bce_vhci_message_queue_destroy(vhci, &vhci->msg_asynchronous); ++} ++ ++static void bce_vhci_handle_system_event(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg); ++static void bce_vhci_handle_usb_event(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg); ++ ++static int bce_vhci_create_event_queues(struct bce_vhci *vhci) ++{ ++ vhci->ev_cq = bce_create_cq(vhci->dev, 0x100); ++ if (!vhci->ev_cq) ++ return -EINVAL; ++#define CREATE_EVENT_QUEUE(field, name, cb) bce_vhci_event_queue_create(vhci, &vhci->field, name, cb) ++ if (__bce_vhci_event_queue_create(vhci, &vhci->ev_commands, "VHC1FirmwareCommands", ++ bce_vhci_firmware_event_completion) || ++ CREATE_EVENT_QUEUE(ev_system, "VHC1FirmwareSystemEvents", bce_vhci_handle_system_event) || ++ CREATE_EVENT_QUEUE(ev_isochronous, "VHC1FirmwareIsochronousEvents", bce_vhci_handle_usb_event) || ++ CREATE_EVENT_QUEUE(ev_interrupt, "VHC1FirmwareInterruptEvents", bce_vhci_handle_usb_event) || ++ CREATE_EVENT_QUEUE(ev_asynchronous, "VHC1FirmwareAsynchronousEvents", bce_vhci_handle_usb_event)) { ++ bce_vhci_destroy_event_queues(vhci); ++ 
return -EINVAL; ++ } ++#undef CREATE_EVENT_QUEUE ++ return 0; ++} ++ ++static void bce_vhci_destroy_event_queues(struct bce_vhci *vhci) ++{ ++ bce_vhci_event_queue_destroy(vhci, &vhci->ev_commands); ++ bce_vhci_event_queue_destroy(vhci, &vhci->ev_system); ++ bce_vhci_event_queue_destroy(vhci, &vhci->ev_isochronous); ++ bce_vhci_event_queue_destroy(vhci, &vhci->ev_interrupt); ++ bce_vhci_event_queue_destroy(vhci, &vhci->ev_asynchronous); ++ if (vhci->ev_cq) ++ bce_destroy_cq(vhci->dev, vhci->ev_cq); ++} ++ ++static void bce_vhci_send_fw_event_response(struct bce_vhci *vhci, struct bce_vhci_message *req, u16 status) ++{ ++ unsigned long timeout = 1000; ++ struct bce_vhci_message r = *req; ++ r.cmd = (u16) (req->cmd | 0x8000u); ++ r.status = status; ++ r.param1 = req->param1; ++ r.param2 = 0; ++ ++ if (bce_reserve_submission(vhci->msg_system.sq, &timeout)) { ++ pr_err("bce-vhci: Cannot reserve submision for FW event reply\n"); ++ return; ++ } ++ bce_vhci_message_queue_write(&vhci->msg_system, &r); ++} ++ ++static int bce_vhci_handle_firmware_event(struct bce_vhci *vhci, struct bce_vhci_message *msg) ++{ ++ unsigned long flags; ++ bce_vhci_device_t devid; ++ u8 endp; ++ struct bce_vhci_device *dev; ++ struct bce_vhci_transfer_queue *tq; ++ if (msg->cmd == BCE_VHCI_CMD_ENDPOINT_REQUEST_STATE || msg->cmd == BCE_VHCI_CMD_ENDPOINT_SET_STATE) { ++ devid = (bce_vhci_device_t) (msg->param1 & 0xff); ++ endp = bce_vhci_endpoint_index((u8) ((msg->param1 >> 8) & 0xff)); ++ dev = devid < ARRAY_SIZE(vhci->devices) ? vhci->devices[devid] : NULL; /* devid may be up to 0xff; never index past devices[] */ ++ if (!dev || !(dev->tq_mask & BIT(endp))) ++ return BCE_VHCI_BAD_ARGUMENT; ++ tq = &dev->tq[endp]; ++ } ++ ++ if (msg->cmd == BCE_VHCI_CMD_ENDPOINT_REQUEST_STATE) { ++ if (msg->param2 == BCE_VHCI_ENDPOINT_ACTIVE) { ++ bce_vhci_transfer_queue_resume(tq, BCE_VHCI_PAUSE_FIRMWARE); ++ return BCE_VHCI_SUCCESS; ++ } else if (msg->param2 == BCE_VHCI_ENDPOINT_PAUSED) { ++ bce_vhci_transfer_queue_pause(tq, BCE_VHCI_PAUSE_FIRMWARE); ++ return BCE_VHCI_SUCCESS; ++ } ++ return
BCE_VHCI_BAD_ARGUMENT; ++ } else if (msg->cmd == BCE_VHCI_CMD_ENDPOINT_SET_STATE) { ++ if (msg->param2 == BCE_VHCI_ENDPOINT_STALLED) { ++ tq->state = msg->param2; ++ spin_lock_irqsave(&tq->urb_lock, flags); ++ tq->stalled = true; ++ spin_unlock_irqrestore(&tq->urb_lock, flags); ++ return BCE_VHCI_SUCCESS; ++ } ++ return BCE_VHCI_BAD_ARGUMENT; ++ } ++ pr_warn("bce-vhci: Unhandled firmware event: %x s=%x p1=%x p2=%llx\n", ++ msg->cmd, msg->status, msg->param1, msg->param2); ++ return BCE_VHCI_BAD_ARGUMENT; ++} ++ ++static void bce_vhci_handle_firmware_events_w(struct work_struct *ws) ++{ ++ size_t cnt = 0; ++ int result; ++ struct bce_vhci *vhci = container_of(ws, struct bce_vhci, w_fw_events); ++ struct bce_queue_sq *sq = vhci->ev_commands.sq; ++ struct bce_sq_completion_data *cq; ++ struct bce_vhci_message *msg, *msg2 = NULL; ++ ++ while (true) { ++ if (msg2) { ++ msg = msg2; ++ msg2 = NULL; ++ } else if ((cq = bce_next_completion(sq))) { ++ if (cq->status == BCE_COMPLETION_ABORTED) { ++ bce_notify_submission_complete(sq); ++ continue; ++ } ++ msg = &vhci->ev_commands.data[sq->head]; ++ } else { ++ break; ++ } ++ ++ pr_debug("bce-vhci: Got fw event: %x s=%x p1=%x p2=%llx\n", msg->cmd, msg->status, msg->param1, msg->param2); ++ if ((cq = bce_next_completion(sq))) { ++ msg2 = &vhci->ev_commands.data[(sq->head + 1) % sq->el_count]; ++ pr_debug("bce-vhci: Got second fw event: %x s=%x p1=%x p2=%llx\n", ++ msg2->cmd, msg2->status, msg2->param1, msg2->param2); /* log the second event, not the first one again */ ++ if (cq->status != BCE_COMPLETION_ABORTED && ++ msg2->cmd == (msg->cmd | 0x4000) && msg2->param1 == msg->param1) { ++ /* Take two elements */ ++ pr_debug("bce-vhci: Cancelled\n"); ++ bce_vhci_send_fw_event_response(vhci, msg, BCE_VHCI_ABORT); ++ ++ bce_notify_submission_complete(sq); ++ bce_notify_submission_complete(sq); ++ msg2 = NULL; ++ cnt += 2; ++ continue; ++ } ++ ++ pr_warn("bce-vhci: Handle fw event - unexpected cancellation\n"); ++ } ++ ++ result = bce_vhci_handle_firmware_event(vhci, msg); ++
bce_vhci_send_fw_event_response(vhci, msg, (u16) result); ++ ++ ++ bce_notify_submission_complete(sq); ++ ++cnt; ++ } ++ bce_vhci_event_queue_submit_pending(&vhci->ev_commands, cnt); ++ if (atomic_read(&sq->available_commands) == sq->el_count - 1) { ++ pr_debug("bce-vhci: complete\n"); ++ complete(&vhci->ev_commands.queue_empty_completion); ++ } ++} ++ ++static void bce_vhci_firmware_event_completion(struct bce_queue_sq *sq) ++{ ++ struct bce_vhci_event_queue *q = sq->userdata; ++ queue_work(q->vhci->tq_state_wq, &q->vhci->w_fw_events); ++} ++ ++static void bce_vhci_handle_system_event(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg) ++{ ++ if (msg->cmd & 0x8000) { ++ bce_vhci_command_queue_deliver_completion(&q->vhci->cq, msg); ++ } else { ++ pr_warn("bce-vhci: Unhandled system event: %x s=%x p1=%x p2=%llx\n", ++ msg->cmd, msg->status, msg->param1, msg->param2); ++ } ++} ++ ++static void bce_vhci_handle_usb_event(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg) ++{ ++ bce_vhci_device_t devid; ++ u8 endp; ++ struct bce_vhci_device *dev; ++ if (msg->cmd & 0x8000) { ++ bce_vhci_command_queue_deliver_completion(&q->vhci->cq, msg); ++ } else if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST || msg->cmd == BCE_VHCI_CMD_CONTROL_TRANSFER_STATUS) { ++ devid = (bce_vhci_device_t) (msg->param1 & 0xff); ++ endp = bce_vhci_endpoint_index((u8) ((msg->param1 >> 8) & 0xff)); ++ dev = devid < ARRAY_SIZE(q->vhci->devices) ? q->vhci->devices[devid] : NULL; /* devid may be up to 0xff; never index past devices[] */ ++ if (!dev || (dev->tq_mask & BIT(endp)) == 0) { ++ pr_err("bce-vhci: Didn't find destination for transfer queue event\n"); ++ return; ++ } ++ bce_vhci_transfer_queue_event(&dev->tq[endp], msg); ++ } else { ++ pr_warn("bce-vhci: Unhandled USB event: %x s=%x p1=%x p2=%llx\n", ++ msg->cmd, msg->status, msg->param1, msg->param2); ++ } ++} ++ ++ ++ ++static const struct hc_driver bce_vhci_driver = { ++ .description = "bce-vhci", ++ .product_desc = "BCE VHCI Host Controller", ++ .hcd_priv_size = sizeof(struct bce_vhci *), ++ ++#if LINUX_VERSION_CODE <
KERNEL_VERSION(5,4,0) ++ .flags = HCD_USB2, ++#else ++ .flags = HCD_USB2 | HCD_DMA, ++#endif ++ ++ .start = bce_vhci_start, ++ .stop = bce_vhci_stop, ++ .hub_status_data = bce_vhci_hub_status_data, ++ .hub_control = bce_vhci_hub_control, ++ .urb_enqueue = bce_vhci_urb_enqueue, ++ .urb_dequeue = bce_vhci_urb_dequeue, ++ .enable_device = bce_vhci_enable_device, ++ .free_dev = bce_vhci_free_device, ++ .address_device = bce_vhci_address_device, ++ .add_endpoint = bce_vhci_add_endpoint, ++ .drop_endpoint = bce_vhci_drop_endpoint, ++ .endpoint_reset = bce_vhci_endpoint_reset, ++ .check_bandwidth = bce_vhci_check_bandwidth, ++ .get_frame_number = bce_vhci_get_frame_number, ++ .bus_suspend = bce_vhci_bus_suspend, ++ .bus_resume = bce_vhci_bus_resume ++}; ++ ++ ++int __init bce_vhci_module_init(void) ++{ ++ int result; ++ if ((result = alloc_chrdev_region(&bce_vhci_chrdev, 0, 1, "bce-vhci"))) ++ goto fail_chrdev; ++#if LINUX_VERSION_CODE < KERNEL_VERSION(6,4,0) ++ bce_vhci_class = class_create(THIS_MODULE, "bce-vhci"); ++#else ++ bce_vhci_class = class_create("bce-vhci"); ++#endif ++ if (IS_ERR(bce_vhci_class)) { ++ result = PTR_ERR(bce_vhci_class); ++ goto fail_class; ++ } ++ return 0; ++ ++fail_class: ++ class_destroy(bce_vhci_class); ++fail_chrdev: ++ unregister_chrdev_region(bce_vhci_chrdev, 1); ++ if (!result) ++ result = -EINVAL; ++ return result; ++} ++void __exit bce_vhci_module_exit(void) ++{ ++ class_destroy(bce_vhci_class); ++ unregister_chrdev_region(bce_vhci_chrdev, 1); ++} ++ ++module_param_named(vhci_port_mask, bce_vhci_port_mask, ushort, 0444); ++MODULE_PARM_DESC(vhci_port_mask, "Specifies which VHCI ports are enabled"); +diff --git a/drivers/staging/apple-bce/vhci/vhci.h b/drivers/staging/apple-bce/vhci/vhci.h +new file mode 100644 +index 000000000000..6c2e22622f4c +--- /dev/null ++++ b/drivers/staging/apple-bce/vhci/vhci.h +@@ -0,0 +1,52 @@ ++#ifndef BCE_VHCI_H ++#define BCE_VHCI_H ++ ++#include "queue.h" ++#include "transfer.h" ++ ++struct usb_hcd; 
++struct bce_queue_cq; ++ ++struct bce_vhci_device { ++ struct bce_vhci_transfer_queue tq[32]; ++ u32 tq_mask; ++}; ++struct bce_vhci { ++ struct apple_bce_device *dev; ++ dev_t vdevt; ++ struct device *vdev; ++ struct usb_hcd *hcd; ++ struct spinlock hcd_spinlock; ++ struct bce_vhci_message_queue msg_commands; ++ struct bce_vhci_message_queue msg_system; ++ struct bce_vhci_message_queue msg_isochronous; ++ struct bce_vhci_message_queue msg_interrupt; ++ struct bce_vhci_message_queue msg_asynchronous; ++ struct spinlock msg_asynchronous_lock; ++ struct bce_vhci_command_queue cq; ++ struct bce_queue_cq *ev_cq; ++ struct bce_vhci_event_queue ev_commands; ++ struct bce_vhci_event_queue ev_system; ++ struct bce_vhci_event_queue ev_isochronous; ++ struct bce_vhci_event_queue ev_interrupt; ++ struct bce_vhci_event_queue ev_asynchronous; ++ u16 port_mask; ++ u8 port_count; ++ u16 port_power_mask; ++ bce_vhci_device_t port_to_device[16]; ++ struct bce_vhci_device *devices[16]; ++ struct workqueue_struct *tq_state_wq; ++ struct work_struct w_fw_events; ++}; ++ ++int __init bce_vhci_module_init(void); ++void __exit bce_vhci_module_exit(void); ++ ++int bce_vhci_create(struct apple_bce_device *dev, struct bce_vhci *vhci); ++void bce_vhci_destroy(struct bce_vhci *vhci); ++int bce_vhci_start(struct usb_hcd *hcd); ++void bce_vhci_stop(struct usb_hcd *hcd); ++ ++struct bce_vhci *bce_vhci_from_hcd(struct usb_hcd *hcd); ++ ++#endif //BCE_VHCI_H +diff --git a/drivers/usb/core/driver.c b/drivers/usb/core/driver.c +index e02ba15f6e34..b35734d03109 100644 +--- a/drivers/usb/core/driver.c ++++ b/drivers/usb/core/driver.c +@@ -517,6 +517,19 @@ static int usb_unbind_interface(struct device *dev) + return 0; + } + ++static void usb_shutdown_interface(struct device *dev) ++{ ++ struct usb_interface *intf = to_usb_interface(dev); ++ struct usb_driver *driver; ++ ++ if (!dev->driver) ++ return; ++ ++ driver = to_usb_driver(dev->driver); ++ if (driver->shutdown) ++ driver->shutdown(intf); ++} 
++ + /** + * usb_driver_claim_interface - bind a driver to an interface + * @driver: the driver to be bound +@@ -1059,6 +1072,7 @@ int usb_register_driver(struct usb_driver *new_driver, struct module *owner, + new_driver->driver.bus = &usb_bus_type; + new_driver->driver.probe = usb_probe_interface; + new_driver->driver.remove = usb_unbind_interface; ++ new_driver->driver.shutdown = usb_shutdown_interface; + new_driver->driver.owner = owner; + new_driver->driver.mod_name = mod_name; + new_driver->driver.dev_groups = new_driver->dev_groups; +diff --git a/drivers/usb/storage/uas.c b/drivers/usb/storage/uas.c +index b610a2de4ae5..0cdbcf82554f 100644 +--- a/drivers/usb/storage/uas.c ++++ b/drivers/usb/storage/uas.c +@@ -1232,9 +1232,8 @@ static void uas_disconnect(struct usb_interface *intf) + * hang on reboot when the device is still in uas mode. Note the reset is + * necessary as some devices won't revert to usb-storage mode without it. + */ +-static void uas_shutdown(struct device *dev) ++static void uas_shutdown(struct usb_interface *intf) + { +- struct usb_interface *intf = to_usb_interface(dev); + struct usb_device *udev = interface_to_usbdev(intf); + struct Scsi_Host *shost = usb_get_intfdata(intf); + struct uas_dev_info *devinfo = (struct uas_dev_info *)shost->hostdata; +@@ -1257,7 +1256,7 @@ static struct usb_driver uas_driver = { + .suspend = uas_suspend, + .resume = uas_resume, + .reset_resume = uas_reset_resume, +- .driver.shutdown = uas_shutdown, ++ .shutdown = uas_shutdown, + .id_table = uas_usb_ids, + }; + +diff --git a/include/drm/drm_format_helper.h b/include/drm/drm_format_helper.h +index 428d81afe215..aa1604d92c1a 100644 +--- a/include/drm/drm_format_helper.h ++++ b/include/drm/drm_format_helper.h +@@ -96,6 +96,9 @@ void drm_fb_xrgb8888_to_rgba5551(struct iosys_map *dst, const unsigned int *dst_ + void drm_fb_xrgb8888_to_rgb888(struct iosys_map *dst, const unsigned int *dst_pitch, + const struct iosys_map *src, const struct drm_framebuffer *fb, + 
const struct drm_rect *clip, struct drm_format_conv_state *state); ++void drm_fb_xrgb8888_to_bgr888(struct iosys_map *dst, const unsigned int *dst_pitch, ++ const struct iosys_map *src, const struct drm_framebuffer *fb, ++ const struct drm_rect *clip, struct drm_format_conv_state *state); + void drm_fb_xrgb8888_to_argb8888(struct iosys_map *dst, const unsigned int *dst_pitch, + const struct iosys_map *src, const struct drm_framebuffer *fb, + const struct drm_rect *clip, struct drm_format_conv_state *state); +diff --git a/include/linux/efi.h b/include/linux/efi.h +index 418e555459da..3a6c04a9f9aa 100644 +--- a/include/linux/efi.h ++++ b/include/linux/efi.h +@@ -74,10 +74,10 @@ typedef void *efi_handle_t; + */ + typedef guid_t efi_guid_t __aligned(__alignof__(u32)); + +-#define EFI_GUID(a, b, c, d...) (efi_guid_t){ { \ ++#define EFI_GUID(a, b, c, d...) ((efi_guid_t){ { \ + (a) & 0xff, ((a) >> 8) & 0xff, ((a) >> 16) & 0xff, ((a) >> 24) & 0xff, \ + (b) & 0xff, ((b) >> 8) & 0xff, \ +- (c) & 0xff, ((c) >> 8) & 0xff, d } } ++ (c) & 0xff, ((c) >> 8) & 0xff, d } }) + + /* + * Generic EFI table header +@@ -385,6 +385,7 @@ void efi_native_runtime_setup(void); + #define EFI_MEMORY_ATTRIBUTES_TABLE_GUID EFI_GUID(0xdcfa911d, 0x26eb, 0x469f, 0xa2, 0x20, 0x38, 0xb7, 0xdc, 0x46, 0x12, 0x20) + #define EFI_CONSOLE_OUT_DEVICE_GUID EFI_GUID(0xd3b36f2c, 0xd551, 0x11d4, 0x9a, 0x46, 0x00, 0x90, 0x27, 0x3f, 0xc1, 0x4d) + #define APPLE_PROPERTIES_PROTOCOL_GUID EFI_GUID(0x91bd12fe, 0xf6c3, 0x44fb, 0xa5, 0xb7, 0x51, 0x22, 0xab, 0x30, 0x3a, 0xe0) ++#define APPLE_SET_OS_PROTOCOL_GUID EFI_GUID(0xc5c5da95, 0x7d5c, 0x45e6, 0xb2, 0xf1, 0x3f, 0xd5, 0x2b, 0xb1, 0x00, 0x77) + #define EFI_TCG2_PROTOCOL_GUID EFI_GUID(0x607f766c, 0x7455, 0x42be, 0x93, 0x0b, 0xe4, 0xd7, 0x6d, 0xb2, 0x72, 0x0f) + #define EFI_TCG2_FINAL_EVENTS_TABLE_GUID EFI_GUID(0x1e2ed096, 0x30e2, 0x4254, 0xbd, 0x89, 0x86, 0x3b, 0xbe, 0xf8, 0x23, 0x25) + #define EFI_LOAD_FILE_PROTOCOL_GUID EFI_GUID(0x56ec3091, 0x954c, 0x11d2, 0x8e, 0x3f, 
0x00, 0xa0, 0xc9, 0x69, 0x72, 0x3b) +diff --git a/include/linux/hid.h b/include/linux/hid.h +index 8e06d89698e6..6cdb5a451453 100644 +--- a/include/linux/hid.h ++++ b/include/linux/hid.h +@@ -940,6 +940,8 @@ extern void hidinput_report_event(struct hid_device *hid, struct hid_report *rep + extern int hidinput_connect(struct hid_device *hid, unsigned int force); + extern void hidinput_disconnect(struct hid_device *); + ++struct hid_field *hid_find_field(struct hid_device *hdev, unsigned int report_type, ++ unsigned int application, unsigned int usage); + int hid_set_field(struct hid_field *, unsigned, __s32); + int hid_input_report(struct hid_device *hid, enum hid_report_type type, u8 *data, u32 size, + int interrupt); +diff --git a/include/linux/usb.h b/include/linux/usb.h +index 1913a13833f2..832997a9da0a 100644 +--- a/include/linux/usb.h ++++ b/include/linux/usb.h +@@ -1171,6 +1171,7 @@ extern ssize_t usb_show_dynids(struct usb_dynids *dynids, char *buf); + * post_reset method is called. + * @post_reset: Called by usb_reset_device() after the device + * has been reset ++ * @shutdown: Called at shut-down time to quiesce the device. + * @id_table: USB drivers use ID table to support hotplugging. + * Export this with MODULE_DEVICE_TABLE(usb,...). This must be set + * or your driver's probe function will never get called. 
+@@ -1222,6 +1223,8 @@ struct usb_driver { + int (*pre_reset)(struct usb_interface *intf); + int (*post_reset)(struct usb_interface *intf); + ++ void (*shutdown)(struct usb_interface *intf); ++ + const struct usb_device_id *id_table; + const struct attribute_group **dev_groups; + +diff --git a/lib/test_printf.c b/lib/test_printf.c +index 69b6a5e177f2..a318bb72a165 100644 +--- a/lib/test_printf.c ++++ b/lib/test_printf.c +@@ -745,18 +745,26 @@ static void __init fwnode_pointer(void) + static void __init fourcc_pointer(void) + { + struct { ++ char type; + u32 code; + char *str; + } const try[] = { +- { 0x3231564e, "NV12 little-endian (0x3231564e)", }, +- { 0xb231564e, "NV12 big-endian (0xb231564e)", }, +- { 0x10111213, ".... little-endian (0x10111213)", }, +- { 0x20303159, "Y10 little-endian (0x20303159)", }, ++ { 'c', 0x3231564e, "NV12 little-endian (0x3231564e)", }, ++ { 'c', 0xb231564e, "NV12 big-endian (0xb231564e)", }, ++ { 'c', 0x10111213, ".... little-endian (0x10111213)", }, ++ { 'c', 0x20303159, "Y10 little-endian (0x20303159)", }, ++ { 'h', 0x67503030, "gP00 (0x67503030)", }, ++ { 'r', 0x30305067, "gP00 (0x67503030)", }, ++ { 'l', cpu_to_le32(0x67503030), "gP00 (0x67503030)", }, ++ { 'b', cpu_to_be32(0x67503030), "gP00 (0x67503030)", }, + }; + unsigned int i; + +- for (i = 0; i < ARRAY_SIZE(try); i++) +- test(try[i].str, "%p4cc", &try[i].code); ++ for (i = 0; i < ARRAY_SIZE(try); i++) { ++ char fmt[] = { '%', 'p', '4', 'c', try[i].type, '\0' }; ++ ++ test(try[i].str, fmt, &try[i].code); ++ } + } + + static void __init +diff --git a/lib/vsprintf.c b/lib/vsprintf.c +index cdd4e2314bfc..4feaea1815fa 100644 +--- a/lib/vsprintf.c ++++ b/lib/vsprintf.c +@@ -1760,27 +1760,50 @@ char *fourcc_string(char *buf, char *end, const u32 *fourcc, + char output[sizeof("0123 little-endian (0x01234567)")]; + char *p = output; + unsigned int i; ++ bool pix_fmt = false; + u32 orig, val; + +- if (fmt[1] != 'c' || fmt[2] != 'c') ++ if (fmt[1] != 'c') + return error_string(buf, 
end, "(%p4?)", spec); + + if (check_pointer(&buf, end, fourcc, spec)) + return buf; + + orig = get_unaligned(fourcc); +- val = orig & ~BIT(31); ++ switch (fmt[2]) { ++ case 'h': ++ val = orig; ++ break; ++ case 'r': ++ val = orig = swab32(orig); ++ break; ++ case 'l': ++ val = orig = le32_to_cpu(orig); ++ break; ++ case 'b': ++ val = orig = be32_to_cpu(orig); ++ break; ++ case 'c': ++ /* Pixel formats are printed LSB-first */ ++ val = swab32(orig & ~BIT(31)); ++ pix_fmt = true; ++ break; ++ default: ++ return error_string(buf, end, "(%p4?)", spec); ++ } + + for (i = 0; i < sizeof(u32); i++) { +- unsigned char c = val >> (i * 8); ++ unsigned char c = val >> ((3 - i) * 8); + + /* Print non-control ASCII characters as-is, dot otherwise */ + *p++ = isascii(c) && isprint(c) ? c : '.'; + } + +- *p++ = ' '; +- strcpy(p, orig & BIT(31) ? "big-endian" : "little-endian"); +- p += strlen(p); ++ if (pix_fmt) { ++ *p++ = ' '; ++ strcpy(p, orig & BIT(31) ? "big-endian" : "little-endian"); ++ p += strlen(p); ++ } + + *p++ = ' '; + *p++ = '('; +@@ -2355,6 +2378,7 @@ char *rust_fmt_argument(char *buf, char *end, void *ptr); + * read the documentation (path below) first. + * - 'NF' For a netdev_features_t + * - '4cc' V4L2 or DRM FourCC code, with endianness and raw numerical value. ++ * - '4c[hlbr]' Generic FourCC code. 
+ * - 'h[CDN]' For a variable-length buffer, it prints it as a hex string with + * a certain separator (' ' by default): + * C colon +diff --git a/scripts/checkpatch.pl b/scripts/checkpatch.pl +index 2b812210b412..4c3a8cc6ef15 100755 +--- a/scripts/checkpatch.pl ++++ b/scripts/checkpatch.pl +@@ -6909,7 +6909,7 @@ sub process { + ($extension eq "f" && + defined $qualifier && $qualifier !~ /^w/) || + ($extension eq "4" && +- defined $qualifier && $qualifier !~ /^cc/)) { ++ defined $qualifier && $qualifier !~ /^c[chlbr]/)) { + $bad_specifier = $specifier; + last; + } +-- +2.46.0 + +From 22944ae0c983c003b10a918fef8a578e6d7689db Mon Sep 17 00:00:00 2001 +From: Peter Jung +Date: Wed, 14 Aug 2024 16:20:35 +0200 +Subject: [PATCH 12/12] zstd + +Signed-off-by: Peter Jung +--- + include/linux/zstd.h | 2 +- + include/linux/zstd_errors.h | 23 +- + include/linux/zstd_lib.h | 850 +++++-- + lib/zstd/Makefile | 2 +- + lib/zstd/common/allocations.h | 56 + + lib/zstd/common/bits.h | 149 ++ + lib/zstd/common/bitstream.h | 127 +- + lib/zstd/common/compiler.h | 134 +- + lib/zstd/common/cpu.h | 3 +- + lib/zstd/common/debug.c | 9 +- + lib/zstd/common/debug.h | 34 +- + lib/zstd/common/entropy_common.c | 42 +- + lib/zstd/common/error_private.c | 12 +- + lib/zstd/common/error_private.h | 84 +- + lib/zstd/common/fse.h | 94 +- + lib/zstd/common/fse_decompress.c | 130 +- + lib/zstd/common/huf.h | 237 +- + lib/zstd/common/mem.h | 3 +- + lib/zstd/common/portability_macros.h | 28 +- + lib/zstd/common/zstd_common.c | 38 +- + lib/zstd/common/zstd_deps.h | 16 +- + lib/zstd/common/zstd_internal.h | 109 +- + lib/zstd/compress/clevels.h | 3 +- + lib/zstd/compress/fse_compress.c | 74 +- + lib/zstd/compress/hist.c | 3 +- + lib/zstd/compress/hist.h | 3 +- + lib/zstd/compress/huf_compress.c | 441 ++-- + lib/zstd/compress/zstd_compress.c | 2111 ++++++++++++----- + lib/zstd/compress/zstd_compress_internal.h | 359 ++- + lib/zstd/compress/zstd_compress_literals.c | 155 +- + 
lib/zstd/compress/zstd_compress_literals.h | 25 +- + lib/zstd/compress/zstd_compress_sequences.c | 7 +- + lib/zstd/compress/zstd_compress_sequences.h | 3 +- + lib/zstd/compress/zstd_compress_superblock.c | 376 ++- + lib/zstd/compress/zstd_compress_superblock.h | 3 +- + lib/zstd/compress/zstd_cwksp.h | 169 +- + lib/zstd/compress/zstd_double_fast.c | 143 +- + lib/zstd/compress/zstd_double_fast.h | 17 +- + lib/zstd/compress/zstd_fast.c | 596 +++-- + lib/zstd/compress/zstd_fast.h | 6 +- + lib/zstd/compress/zstd_lazy.c | 732 +++--- + lib/zstd/compress/zstd_lazy.h | 138 +- + lib/zstd/compress/zstd_ldm.c | 21 +- + lib/zstd/compress/zstd_ldm.h | 3 +- + lib/zstd/compress/zstd_ldm_geartab.h | 3 +- + lib/zstd/compress/zstd_opt.c | 497 ++-- + lib/zstd/compress/zstd_opt.h | 41 +- + lib/zstd/decompress/huf_decompress.c | 887 ++++--- + lib/zstd/decompress/zstd_ddict.c | 9 +- + lib/zstd/decompress/zstd_ddict.h | 3 +- + lib/zstd/decompress/zstd_decompress.c | 358 ++- + lib/zstd/decompress/zstd_decompress_block.c | 708 +++--- + lib/zstd/decompress/zstd_decompress_block.h | 10 +- + .../decompress/zstd_decompress_internal.h | 9 +- + lib/zstd/decompress_sources.h | 2 +- + lib/zstd/zstd_common_module.c | 5 +- + lib/zstd/zstd_compress_module.c | 2 +- + lib/zstd/zstd_decompress_module.c | 4 +- + 58 files changed, 6577 insertions(+), 3531 deletions(-) + create mode 100644 lib/zstd/common/allocations.h + create mode 100644 lib/zstd/common/bits.h + +diff --git a/include/linux/zstd.h b/include/linux/zstd.h +index 113408eef6ec..f109d49f43f8 100644 +--- a/include/linux/zstd.h ++++ b/include/linux/zstd.h +@@ -1,6 +1,6 @@ + /* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +diff --git a/include/linux/zstd_errors.h b/include/linux/zstd_errors.h +index 58b6dd45a969..6d5cf55f0bf3 100644 +--- a/include/linux/zstd_errors.h ++++ b/include/linux/zstd_errors.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -17,8 +18,17 @@ + + + /* ===== ZSTDERRORLIB_API : control library symbols visibility ===== */ +-#define ZSTDERRORLIB_VISIBILITY +-#define ZSTDERRORLIB_API ZSTDERRORLIB_VISIBILITY ++#define ZSTDERRORLIB_VISIBLE ++ ++#ifndef ZSTDERRORLIB_HIDDEN ++# if (__GNUC__ >= 4) && !defined(__MINGW32__) ++# define ZSTDERRORLIB_HIDDEN __attribute__ ((visibility ("hidden"))) ++# else ++# define ZSTDERRORLIB_HIDDEN ++# endif ++#endif ++ ++#define ZSTDERRORLIB_API ZSTDERRORLIB_VISIBLE + + /*-********************************************* + * Error codes list +@@ -43,14 +53,17 @@ typedef enum { + ZSTD_error_frameParameter_windowTooLarge = 16, + ZSTD_error_corruption_detected = 20, + ZSTD_error_checksum_wrong = 22, ++ ZSTD_error_literals_headerWrong = 24, + ZSTD_error_dictionary_corrupted = 30, + ZSTD_error_dictionary_wrong = 32, + ZSTD_error_dictionaryCreation_failed = 34, + ZSTD_error_parameter_unsupported = 40, ++ ZSTD_error_parameter_combination_unsupported = 41, + ZSTD_error_parameter_outOfBound = 42, + ZSTD_error_tableLog_tooLarge = 44, + ZSTD_error_maxSymbolValue_tooLarge = 46, + ZSTD_error_maxSymbolValue_tooSmall = 48, ++ ZSTD_error_stabilityCondition_notRespected = 50, + ZSTD_error_stage_wrong = 60, + ZSTD_error_init_missing = 62, + ZSTD_error_memory_allocation = 64, +@@ -58,11 +71,15 @@ typedef enum { + ZSTD_error_dstSize_tooSmall = 70, + ZSTD_error_srcSize_wrong = 72, + ZSTD_error_dstBuffer_null = 74, ++ 
ZSTD_error_noForwardProgress_destFull = 80, ++ ZSTD_error_noForwardProgress_inputEmpty = 82, + /* following error codes are __NOT STABLE__, they can be removed or changed in future versions */ + ZSTD_error_frameIndex_tooLarge = 100, + ZSTD_error_seekableIO = 102, + ZSTD_error_dstBuffer_wrong = 104, + ZSTD_error_srcBuffer_wrong = 105, ++ ZSTD_error_sequenceProducer_failed = 106, ++ ZSTD_error_externalSequences_invalid = 107, + ZSTD_error_maxCode = 120 /* never EVER use this value directly, it can change in future versions! Use ZSTD_isError() instead */ + } ZSTD_ErrorCode; + +diff --git a/include/linux/zstd_lib.h b/include/linux/zstd_lib.h +index 79d55465d5c1..6320fedcf8a4 100644 +--- a/include/linux/zstd_lib.h ++++ b/include/linux/zstd_lib.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -11,23 +12,42 @@ + #ifndef ZSTD_H_235446 + #define ZSTD_H_235446 + +-/* ====== Dependency ======*/ ++/* ====== Dependencies ======*/ + #include /* INT_MAX */ + #include /* size_t */ + + + /* ===== ZSTDLIB_API : control library symbols visibility ===== */ +-#ifndef ZSTDLIB_VISIBLE ++#define ZSTDLIB_VISIBLE ++ ++#ifndef ZSTDLIB_HIDDEN + # if (__GNUC__ >= 4) && !defined(__MINGW32__) +-# define ZSTDLIB_VISIBLE __attribute__ ((visibility ("default"))) + # define ZSTDLIB_HIDDEN __attribute__ ((visibility ("hidden"))) + # else +-# define ZSTDLIB_VISIBLE + # define ZSTDLIB_HIDDEN + # endif + #endif ++ + #define ZSTDLIB_API ZSTDLIB_VISIBLE + ++/* Deprecation warnings : ++ * Should these warnings be a problem, it is generally possible to disable them, ++ * typically with -Wno-deprecated-declarations for gcc or _CRT_SECURE_NO_WARNINGS in Visual. ++ * Otherwise, it's also possible to define ZSTD_DISABLE_DEPRECATE_WARNINGS. 
++ */ ++#ifdef ZSTD_DISABLE_DEPRECATE_WARNINGS ++# define ZSTD_DEPRECATED(message) /* disable deprecation warnings */ ++#else ++# if (defined(GNUC) && (GNUC > 4 || (GNUC == 4 && GNUC_MINOR >= 5))) || defined(__clang__) ++# define ZSTD_DEPRECATED(message) __attribute__((deprecated(message))) ++# elif (__GNUC__ >= 3) ++# define ZSTD_DEPRECATED(message) __attribute__((deprecated)) ++# else ++# pragma message("WARNING: You need to implement ZSTD_DEPRECATED for this compiler") ++# define ZSTD_DEPRECATED(message) ++# endif ++#endif /* ZSTD_DISABLE_DEPRECATE_WARNINGS */ ++ + + /* ***************************************************************************** + Introduction +@@ -65,7 +85,7 @@ + /*------ Version ------*/ + #define ZSTD_VERSION_MAJOR 1 + #define ZSTD_VERSION_MINOR 5 +-#define ZSTD_VERSION_RELEASE 2 ++#define ZSTD_VERSION_RELEASE 6 + #define ZSTD_VERSION_NUMBER (ZSTD_VERSION_MAJOR *100*100 + ZSTD_VERSION_MINOR *100 + ZSTD_VERSION_RELEASE) + + /*! ZSTD_versionNumber() : +@@ -107,7 +127,8 @@ ZSTDLIB_API const char* ZSTD_versionString(void); + ***************************************/ + /*! ZSTD_compress() : + * Compresses `src` content as a single zstd compressed frame into already allocated `dst`. +- * Hint : compression runs faster if `dstCapacity` >= `ZSTD_compressBound(srcSize)`. ++ * NOTE: Providing `dstCapacity >= ZSTD_compressBound(srcSize)` guarantees that zstd will have ++ * enough space to successfully compress the data. + * @return : compressed size written into `dst` (<= `dstCapacity), + * or an error code if it fails (which can be tested using ZSTD_isError()). */ + ZSTDLIB_API size_t ZSTD_compress( void* dst, size_t dstCapacity, +@@ -156,7 +177,9 @@ ZSTDLIB_API unsigned long long ZSTD_getFrameContentSize(const void *src, size_t + * "empty", "unknown" and "error" results to the same return value (0), + * while ZSTD_getFrameContentSize() gives them separate return values. 
+ * @return : decompressed size of `src` frame content _if known and not empty_, 0 otherwise. */ +-ZSTDLIB_API unsigned long long ZSTD_getDecompressedSize(const void* src, size_t srcSize); ++ZSTD_DEPRECATED("Replaced by ZSTD_getFrameContentSize") ++ZSTDLIB_API ++unsigned long long ZSTD_getDecompressedSize(const void* src, size_t srcSize); + + /*! ZSTD_findFrameCompressedSize() : Requires v1.4.0+ + * `src` should point to the start of a ZSTD frame or skippable frame. +@@ -168,8 +191,30 @@ ZSTDLIB_API size_t ZSTD_findFrameCompressedSize(const void* src, size_t srcSize) + + + /*====== Helper functions ======*/ +-#define ZSTD_COMPRESSBOUND(srcSize) ((srcSize) + ((srcSize)>>8) + (((srcSize) < (128<<10)) ? (((128<<10) - (srcSize)) >> 11) /* margin, from 64 to 0 */ : 0)) /* this formula ensures that bound(A) + bound(B) <= bound(A+B) as long as A and B >= 128 KB */ +-ZSTDLIB_API size_t ZSTD_compressBound(size_t srcSize); /*!< maximum compressed size in worst case single-pass scenario */ ++/* ZSTD_compressBound() : ++ * maximum compressed size in worst case single-pass scenario. ++ * When invoking `ZSTD_compress()` or any other one-pass compression function, ++ * it's recommended to provide @dstCapacity >= ZSTD_compressBound(srcSize) ++ * as it eliminates one potential failure scenario, ++ * aka not enough room in dst buffer to write the compressed frame. ++ * Note : ZSTD_compressBound() itself can fail, if @srcSize > ZSTD_MAX_INPUT_SIZE . ++ * In which case, ZSTD_compressBound() will return an error code ++ * which can be tested using ZSTD_isError(). ++ * ++ * ZSTD_COMPRESSBOUND() : ++ * same as ZSTD_compressBound(), but as a macro. ++ * It can be used to produce constants, which can be useful for static allocation, ++ * for example to size a static array on stack. ++ * Will produce constant value 0 if srcSize too large. ++ */ ++#define ZSTD_MAX_INPUT_SIZE ((sizeof(size_t)==8) ? 
0xFF00FF00FF00FF00ULL : 0xFF00FF00U) ++#define ZSTD_COMPRESSBOUND(srcSize) (((size_t)(srcSize) >= ZSTD_MAX_INPUT_SIZE) ? 0 : (srcSize) + ((srcSize)>>8) + (((srcSize) < (128<<10)) ? (((128<<10) - (srcSize)) >> 11) /* margin, from 64 to 0 */ : 0)) /* this formula ensures that bound(A) + bound(B) <= bound(A+B) as long as A and B >= 128 KB */ ++ZSTDLIB_API size_t ZSTD_compressBound(size_t srcSize); /*!< maximum compressed size in worst case single-pass scenario */ ++/* ZSTD_isError() : ++ * Most ZSTD_* functions returning a size_t value can be tested for error, ++ * using ZSTD_isError(). ++ * @return 1 if error, 0 otherwise ++ */ + ZSTDLIB_API unsigned ZSTD_isError(size_t code); /*!< tells if a `size_t` function result is an error code */ + ZSTDLIB_API const char* ZSTD_getErrorName(size_t code); /*!< provides readable string from an error code */ + ZSTDLIB_API int ZSTD_minCLevel(void); /*!< minimum negative compression level allowed, requires v1.4.0+ */ +@@ -183,7 +228,7 @@ ZSTDLIB_API int ZSTD_defaultCLevel(void); /*!< default compres + /*= Compression context + * When compressing many times, + * it is recommended to allocate a context just once, +- * and re-use it for each successive compression operation. ++ * and reuse it for each successive compression operation. + * This will make workload friendlier for system's memory. + * Note : re-using context is just a speed / resource optimization. + * It doesn't change the compression ratio, which remains identical. +@@ -196,9 +241,9 @@ ZSTDLIB_API size_t ZSTD_freeCCtx(ZSTD_CCtx* cctx); /* accept NULL pointer * + + /*! ZSTD_compressCCtx() : + * Same as ZSTD_compress(), using an explicit ZSTD_CCtx. +- * Important : in order to behave similarly to `ZSTD_compress()`, +- * this function compresses at requested compression level, +- * __ignoring any other parameter__ . 
++ * Important : in order to mirror `ZSTD_compress()` behavior, ++ * this function compresses at the requested compression level, ++ * __ignoring any other advanced parameter__ . + * If any advanced parameter was set using the advanced API, + * they will all be reset. Only `compressionLevel` remains. + */ +@@ -210,7 +255,7 @@ ZSTDLIB_API size_t ZSTD_compressCCtx(ZSTD_CCtx* cctx, + /*= Decompression context + * When decompressing many times, + * it is recommended to allocate a context only once, +- * and re-use it for each successive compression operation. ++ * and reuse it for each successive compression operation. + * This will make workload friendlier for system's memory. + * Use one context per thread for parallel execution. */ + typedef struct ZSTD_DCtx_s ZSTD_DCtx; +@@ -220,7 +265,7 @@ ZSTDLIB_API size_t ZSTD_freeDCtx(ZSTD_DCtx* dctx); /* accept NULL pointer * + /*! ZSTD_decompressDCtx() : + * Same as ZSTD_decompress(), + * requires an allocated ZSTD_DCtx. +- * Compatible with sticky parameters. ++ * Compatible with sticky parameters (see below). + */ + ZSTDLIB_API size_t ZSTD_decompressDCtx(ZSTD_DCtx* dctx, + void* dst, size_t dstCapacity, +@@ -236,12 +281,12 @@ ZSTDLIB_API size_t ZSTD_decompressDCtx(ZSTD_DCtx* dctx, + * using ZSTD_CCtx_set*() functions. + * Pushed parameters are sticky : they are valid for next compressed frame, and any subsequent frame. + * "sticky" parameters are applicable to `ZSTD_compress2()` and `ZSTD_compressStream*()` ! +- * __They do not apply to "simple" one-shot variants such as ZSTD_compressCCtx()__ . ++ * __They do not apply to one-shot variants such as ZSTD_compressCCtx()__ . + * + * It's possible to reset all parameters to "default" using ZSTD_CCtx_reset(). + * + * This API supersedes all other "advanced" API entry points in the experimental section. +- * In the future, we expect to remove from experimental API entry points which are redundant with this API. 
++ * In the future, we expect to remove API entry points from experimental which are redundant with this API. + */ + + +@@ -324,6 +369,19 @@ typedef enum { + * The higher the value of selected strategy, the more complex it is, + * resulting in stronger and slower compression. + * Special: value 0 means "use default strategy". */ ++ ++ ZSTD_c_targetCBlockSize=130, /* v1.5.6+ ++ * Attempts to fit compressed block size into approximatively targetCBlockSize. ++ * Bound by ZSTD_TARGETCBLOCKSIZE_MIN and ZSTD_TARGETCBLOCKSIZE_MAX. ++ * Note that it's not a guarantee, just a convergence target (default:0). ++ * No target when targetCBlockSize == 0. ++ * This is helpful in low bandwidth streaming environments to improve end-to-end latency, ++ * when a client can make use of partial documents (a prominent example being Chrome). ++ * Note: this parameter is stable since v1.5.6. ++ * It was present as an experimental parameter in earlier versions, ++ * but it's not recommended using it with earlier library versions ++ * due to massive performance regressions. ++ */ + /* LDM mode parameters */ + ZSTD_c_enableLongDistanceMatching=160, /* Enable long distance matching. + * This parameter is designed to improve compression ratio +@@ -403,7 +461,6 @@ typedef enum { + * ZSTD_c_forceMaxWindow + * ZSTD_c_forceAttachDict + * ZSTD_c_literalCompressionMode +- * ZSTD_c_targetCBlockSize + * ZSTD_c_srcSizeHint + * ZSTD_c_enableDedicatedDictSearch + * ZSTD_c_stableInBuffer +@@ -412,6 +469,9 @@ typedef enum { + * ZSTD_c_validateSequences + * ZSTD_c_useBlockSplitter + * ZSTD_c_useRowMatchFinder ++ * ZSTD_c_prefetchCDictTables ++ * ZSTD_c_enableSeqProducerFallback ++ * ZSTD_c_maxBlockSize + * Because they are not stable, it's necessary to define ZSTD_STATIC_LINKING_ONLY to access them. + * note : never ever use experimentalParam? names directly; + * also, the enums values themselves are unstable and can still change. 
+@@ -421,7 +481,7 @@ typedef enum { + ZSTD_c_experimentalParam3=1000, + ZSTD_c_experimentalParam4=1001, + ZSTD_c_experimentalParam5=1002, +- ZSTD_c_experimentalParam6=1003, ++ /* was ZSTD_c_experimentalParam6=1003; is now ZSTD_c_targetCBlockSize */ + ZSTD_c_experimentalParam7=1004, + ZSTD_c_experimentalParam8=1005, + ZSTD_c_experimentalParam9=1006, +@@ -430,7 +490,11 @@ typedef enum { + ZSTD_c_experimentalParam12=1009, + ZSTD_c_experimentalParam13=1010, + ZSTD_c_experimentalParam14=1011, +- ZSTD_c_experimentalParam15=1012 ++ ZSTD_c_experimentalParam15=1012, ++ ZSTD_c_experimentalParam16=1013, ++ ZSTD_c_experimentalParam17=1014, ++ ZSTD_c_experimentalParam18=1015, ++ ZSTD_c_experimentalParam19=1016 + } ZSTD_cParameter; + + typedef struct { +@@ -493,7 +557,7 @@ typedef enum { + * They will be used to compress next frame. + * Resetting session never fails. + * - The parameters : changes all parameters back to "default". +- * This removes any reference to any dictionary too. ++ * This also removes any reference to any dictionary or external sequence producer. + * Parameters can only be changed between 2 sessions (i.e. no compression is currently ongoing) + * otherwise the reset fails, and function returns an error value (which can be tested using ZSTD_isError()) + * - Both : similar to resetting the session, followed by resetting parameters. +@@ -502,11 +566,13 @@ ZSTDLIB_API size_t ZSTD_CCtx_reset(ZSTD_CCtx* cctx, ZSTD_ResetDirective reset); + + /*! ZSTD_compress2() : + * Behave the same as ZSTD_compressCCtx(), but compression parameters are set using the advanced API. ++ * (note that this entry point doesn't even expose a compression level parameter). + * ZSTD_compress2() always starts a new frame. + * Should cctx hold data from a previously unfinished frame, everything about it is forgotten. 
+ * - Compression parameters are pushed into CCtx before starting compression, using ZSTD_CCtx_set*() + * - The function is always blocking, returns when compression is completed. +- * Hint : compression runs faster if `dstCapacity` >= `ZSTD_compressBound(srcSize)`. ++ * NOTE: Providing `dstCapacity >= ZSTD_compressBound(srcSize)` guarantees that zstd will have ++ * enough space to successfully compress the data, though it is possible it fails for other reasons. + * @return : compressed size written into `dst` (<= `dstCapacity), + * or an error code if it fails (which can be tested using ZSTD_isError()). + */ +@@ -543,13 +609,17 @@ typedef enum { + * ZSTD_d_stableOutBuffer + * ZSTD_d_forceIgnoreChecksum + * ZSTD_d_refMultipleDDicts ++ * ZSTD_d_disableHuffmanAssembly ++ * ZSTD_d_maxBlockSize + * Because they are not stable, it's necessary to define ZSTD_STATIC_LINKING_ONLY to access them. + * note : never ever use experimentalParam? names directly + */ + ZSTD_d_experimentalParam1=1000, + ZSTD_d_experimentalParam2=1001, + ZSTD_d_experimentalParam3=1002, +- ZSTD_d_experimentalParam4=1003 ++ ZSTD_d_experimentalParam4=1003, ++ ZSTD_d_experimentalParam5=1004, ++ ZSTD_d_experimentalParam6=1005 + + } ZSTD_dParameter; + +@@ -604,14 +674,14 @@ typedef struct ZSTD_outBuffer_s { + * A ZSTD_CStream object is required to track streaming operation. + * Use ZSTD_createCStream() and ZSTD_freeCStream() to create/release resources. + * ZSTD_CStream objects can be reused multiple times on consecutive compression operations. +-* It is recommended to re-use ZSTD_CStream since it will play nicer with system's memory, by re-using already allocated memory. ++* It is recommended to reuse ZSTD_CStream since it will play nicer with system's memory, by re-using already allocated memory. + * + * For parallel execution, use one separate ZSTD_CStream per thread. + * + * note : since v1.3.0, ZSTD_CStream and ZSTD_CCtx are the same thing. 
+ * + * Parameters are sticky : when starting a new compression on the same context, +-* it will re-use the same sticky parameters as previous compression session. ++* it will reuse the same sticky parameters as previous compression session. + * When in doubt, it's recommended to fully initialize the context before usage. + * Use ZSTD_CCtx_reset() to reset the context and ZSTD_CCtx_setParameter(), + * ZSTD_CCtx_setPledgedSrcSize(), or ZSTD_CCtx_loadDictionary() and friends to +@@ -700,6 +770,11 @@ typedef enum { + * only ZSTD_e_end or ZSTD_e_flush operations are allowed. + * Before starting a new compression job, or changing compression parameters, + * it is required to fully flush internal buffers. ++ * - note: if an operation ends with an error, it may leave @cctx in an undefined state. ++ * Therefore, it's UB to invoke ZSTD_compressStream2() or ZSTD_compressStream() on such a state. ++ * In order to be re-employed after an error, a state must be reset, ++ * which can be done explicitly (ZSTD_CCtx_reset()), ++ * or is sometimes implied by methods starting a new compression job (ZSTD_initCStream(), ZSTD_compressCCtx()) + */ + ZSTDLIB_API size_t ZSTD_compressStream2( ZSTD_CCtx* cctx, + ZSTD_outBuffer* output, +@@ -728,8 +803,6 @@ ZSTDLIB_API size_t ZSTD_CStreamOutSize(void); /*< recommended size for output + * This following is a legacy streaming API, available since v1.0+ . + * It can be replaced by ZSTD_CCtx_reset() and ZSTD_compressStream2(). + * It is redundant, but remains fully supported. +- * Streaming in combination with advanced parameters and dictionary compression +- * can only be used through the new API. + ******************************************************************************/ + + /*! 
+@@ -738,6 +811,9 @@ ZSTDLIB_API size_t ZSTD_CStreamOutSize(void); /*< recommended size for output + * ZSTD_CCtx_reset(zcs, ZSTD_reset_session_only); + * ZSTD_CCtx_refCDict(zcs, NULL); // clear the dictionary (if any) + * ZSTD_CCtx_setParameter(zcs, ZSTD_c_compressionLevel, compressionLevel); ++ * ++ * Note that ZSTD_initCStream() clears any previously set dictionary. Use the new API ++ * to compress with a dictionary. + */ + ZSTDLIB_API size_t ZSTD_initCStream(ZSTD_CStream* zcs, int compressionLevel); + /*! +@@ -758,7 +834,7 @@ ZSTDLIB_API size_t ZSTD_endStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output); + * + * A ZSTD_DStream object is required to track streaming operations. + * Use ZSTD_createDStream() and ZSTD_freeDStream() to create/release resources. +-* ZSTD_DStream objects can be re-used multiple times. ++* ZSTD_DStream objects can be reused multiple times. + * + * Use ZSTD_initDStream() to start a new decompression operation. + * @return : recommended first input size +@@ -788,13 +864,37 @@ ZSTDLIB_API size_t ZSTD_freeDStream(ZSTD_DStream* zds); /* accept NULL pointer + + /*===== Streaming decompression functions =====*/ + +-/* This function is redundant with the advanced API and equivalent to: ++/*! ZSTD_initDStream() : ++ * Initialize/reset DStream state for new decompression operation. ++ * Call before new decompression operation using same DStream. + * ++ * Note : This function is redundant with the advanced API and equivalent to: + * ZSTD_DCtx_reset(zds, ZSTD_reset_session_only); + * ZSTD_DCtx_refDDict(zds, NULL); + */ + ZSTDLIB_API size_t ZSTD_initDStream(ZSTD_DStream* zds); + ++/*! ZSTD_decompressStream() : ++ * Streaming decompression function. ++ * Call repetitively to consume full input updating it as necessary. ++ * Function will update both input and output `pos` fields exposing current state via these fields: ++ * - `input.pos < input.size`, some input remaining and caller should provide remaining input ++ * on the next call. 
++ * - `output.pos < output.size`, decoder finished and flushed all remaining buffers. ++ * - `output.pos == output.size`, potentially unflushed data present in the internal buffers, ++ * call ZSTD_decompressStream() again to flush remaining data to output. ++ * Note : with no additional input, amount of data flushed <= ZSTD_BLOCKSIZE_MAX. ++ * ++ * @return : 0 when a frame is completely decoded and fully flushed, ++ * or an error code, which can be tested using ZSTD_isError(), ++ * or any other value > 0, which means there is some decoding or flushing to do to complete current frame. ++ * ++ * Note: when an operation returns with an error code, the @zds state may be left in undefined state. ++ * It's UB to invoke `ZSTD_decompressStream()` on such a state. ++ * In order to re-use such a state, it must be first reset, ++ * which can be done explicitly (`ZSTD_DCtx_reset()`), ++ * or is implied for operations starting some new decompression job (`ZSTD_initDStream`, `ZSTD_decompressDCtx()`, `ZSTD_decompress_usingDict()`) ++ */ + ZSTDLIB_API size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inBuffer* input); + + ZSTDLIB_API size_t ZSTD_DStreamInSize(void); /*!< recommended size for input buffer */ +@@ -913,7 +1013,7 @@ ZSTDLIB_API unsigned ZSTD_getDictID_fromDDict(const ZSTD_DDict* ddict); + * If @return == 0, the dictID could not be decoded. + * This could for one of the following reasons : + * - The frame does not require a dictionary to be decoded (most common case). +- * - The frame was built with dictID intentionally removed. Whatever dictionary is necessary is a hidden information. ++ * - The frame was built with dictID intentionally removed. Whatever dictionary is necessary is a hidden piece of information. + * Note : this use case also happens when using a non-conformant dictionary. + * - `srcSize` is too small, and as a result, the frame header could not be decoded (only possible if `srcSize < ZSTD_FRAMEHEADERSIZE_MAX`).
+ * - This is not a Zstandard frame. +@@ -925,9 +1025,11 @@ ZSTDLIB_API unsigned ZSTD_getDictID_fromFrame(const void* src, size_t srcSize); + * Advanced dictionary and prefix API (Requires v1.4.0+) + * + * This API allows dictionaries to be used with ZSTD_compress2(), +- * ZSTD_compressStream2(), and ZSTD_decompressDCtx(). Dictionaries are sticky, and +- * only reset with the context is reset with ZSTD_reset_parameters or +- * ZSTD_reset_session_and_parameters. Prefixes are single-use. ++ * ZSTD_compressStream2(), and ZSTD_decompressDCtx(). ++ * Dictionaries are sticky, they remain valid when same context is reused, ++ * they only reset when the context is reset ++ * with ZSTD_reset_parameters or ZSTD_reset_session_and_parameters. ++ * In contrast, Prefixes are single-use. + ******************************************************************************/ + + +@@ -937,8 +1039,9 @@ ZSTDLIB_API unsigned ZSTD_getDictID_fromFrame(const void* src, size_t srcSize); + * @result : 0, or an error code (which can be tested with ZSTD_isError()). + * Special: Loading a NULL (or 0-size) dictionary invalidates previous dictionary, + * meaning "return to no-dictionary mode". +- * Note 1 : Dictionary is sticky, it will be used for all future compressed frames. +- * To return to "no-dictionary" situation, load a NULL dictionary (or reset parameters). ++ * Note 1 : Dictionary is sticky, it will be used for all future compressed frames, ++ * until parameters are reset, a new dictionary is loaded, or the dictionary ++ * is explicitly invalidated by loading a NULL dictionary. + * Note 2 : Loading a dictionary involves building tables. + * It's also a CPU consuming operation, with non-negligible impact on latency. + * Tables are dependent on compression parameters, and for this reason, +@@ -947,11 +1050,15 @@ ZSTDLIB_API unsigned ZSTD_getDictID_fromFrame(const void* src, size_t srcSize); + * Use experimental ZSTD_CCtx_loadDictionary_byReference() to reference content instead. 
+ * In such a case, dictionary buffer must outlive its users. + * Note 4 : Use ZSTD_CCtx_loadDictionary_advanced() +- * to precisely select how dictionary content must be interpreted. */ ++ * to precisely select how dictionary content must be interpreted. ++ * Note 5 : This method does not benefit from LDM (long distance mode). ++ * If you want to employ LDM on some large dictionary content, ++ * prefer employing ZSTD_CCtx_refPrefix() described below. ++ */ + ZSTDLIB_API size_t ZSTD_CCtx_loadDictionary(ZSTD_CCtx* cctx, const void* dict, size_t dictSize); + + /*! ZSTD_CCtx_refCDict() : Requires v1.4.0+ +- * Reference a prepared dictionary, to be used for all next compressed frames. ++ * Reference a prepared dictionary, to be used for all future compressed frames. + * Note that compression parameters are enforced from within CDict, + * and supersede any compression parameter previously set within CCtx. + * The parameters ignored are labelled as "superseded-by-cdict" in the ZSTD_cParameter enum docs. +@@ -970,6 +1077,7 @@ ZSTDLIB_API size_t ZSTD_CCtx_refCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict); + * Decompression will need same prefix to properly regenerate data. + * Compressing with a prefix is similar in outcome as performing a diff and compressing it, + * but performs much faster, especially during decompression (compression speed is tunable with compression level). ++ * This method is compatible with LDM (long distance mode). + * @result : 0, or an error code (which can be tested with ZSTD_isError()). + * Special: Adding any prefix (including NULL) invalidates any previous prefix or dictionary + * Note 1 : Prefix buffer is referenced. It **must** outlive compression. +@@ -986,9 +1094,9 @@ ZSTDLIB_API size_t ZSTD_CCtx_refPrefix(ZSTD_CCtx* cctx, + const void* prefix, size_t prefixSize); + + /*! ZSTD_DCtx_loadDictionary() : Requires v1.4.0+ +- * Create an internal DDict from dict buffer, +- * to be used to decompress next frames. 
+- * The dictionary remains valid for all future frames, until explicitly invalidated. ++ * Create an internal DDict from dict buffer, to be used to decompress all future frames. ++ * The dictionary remains valid for all future frames, until explicitly invalidated, or ++ * a new dictionary is loaded. + * @result : 0, or an error code (which can be tested with ZSTD_isError()). + * Special : Adding a NULL (or 0-size) dictionary invalidates any previous dictionary, + * meaning "return to no-dictionary mode". +@@ -1012,9 +1120,10 @@ ZSTDLIB_API size_t ZSTD_DCtx_loadDictionary(ZSTD_DCtx* dctx, const void* dict, s + * The memory for the table is allocated on the first call to refDDict, and can be + * freed with ZSTD_freeDCtx(). + * ++ * If called with ZSTD_d_refMultipleDDicts disabled (the default), only one dictionary ++ * will be managed, and referencing a dictionary effectively "discards" any previous one. ++ * + * @result : 0, or an error code (which can be tested with ZSTD_isError()). +- * Note 1 : Currently, only one dictionary can be managed. +- * Referencing a new dictionary effectively "discards" any previous one. + * Special: referencing a NULL DDict means "return to no-dictionary mode". + * Note 2 : DDict is just referenced, its lifetime must outlive its usage from DCtx. + */ +@@ -1071,24 +1180,6 @@ ZSTDLIB_API size_t ZSTD_sizeof_DDict(const ZSTD_DDict* ddict); + #define ZSTDLIB_STATIC_API ZSTDLIB_VISIBLE + #endif + +-/* Deprecation warnings : +- * Should these warnings be a problem, it is generally possible to disable them, +- * typically with -Wno-deprecated-declarations for gcc or _CRT_SECURE_NO_WARNINGS in Visual. +- * Otherwise, it's also possible to define ZSTD_DISABLE_DEPRECATE_WARNINGS. 
+- */ +-#ifdef ZSTD_DISABLE_DEPRECATE_WARNINGS +-# define ZSTD_DEPRECATED(message) ZSTDLIB_STATIC_API /* disable deprecation warnings */ +-#else +-# if (defined(GNUC) && (GNUC > 4 || (GNUC == 4 && GNUC_MINOR >= 5))) || defined(__clang__) +-# define ZSTD_DEPRECATED(message) ZSTDLIB_STATIC_API __attribute__((deprecated(message))) +-# elif (__GNUC__ >= 3) +-# define ZSTD_DEPRECATED(message) ZSTDLIB_STATIC_API __attribute__((deprecated)) +-# else +-# pragma message("WARNING: You need to implement ZSTD_DEPRECATED for this compiler") +-# define ZSTD_DEPRECATED(message) ZSTDLIB_STATIC_API +-# endif +-#endif /* ZSTD_DISABLE_DEPRECATE_WARNINGS */ +- + /* ************************************************************************************** + * experimental API (static linking only) + **************************************************************************************** +@@ -1123,6 +1214,7 @@ ZSTDLIB_API size_t ZSTD_sizeof_DDict(const ZSTD_DDict* ddict); + #define ZSTD_TARGETLENGTH_MIN 0 /* note : comparing this constant to an unsigned results in a tautological test */ + #define ZSTD_STRATEGY_MIN ZSTD_fast + #define ZSTD_STRATEGY_MAX ZSTD_btultra2 ++#define ZSTD_BLOCKSIZE_MAX_MIN (1 << 10) /* The minimum valid max blocksize. Maximum blocksizes smaller than this make compressBound() inaccurate. 
*/ + + + #define ZSTD_OVERLAPLOG_MIN 0 +@@ -1146,7 +1238,7 @@ ZSTDLIB_API size_t ZSTD_sizeof_DDict(const ZSTD_DDict* ddict); + #define ZSTD_LDM_HASHRATELOG_MAX (ZSTD_WINDOWLOG_MAX - ZSTD_HASHLOG_MIN) + + /* Advanced parameter bounds */ +-#define ZSTD_TARGETCBLOCKSIZE_MIN 64 ++#define ZSTD_TARGETCBLOCKSIZE_MIN 1340 /* suitable to fit into an ethernet / wifi / 4G transport frame */ + #define ZSTD_TARGETCBLOCKSIZE_MAX ZSTD_BLOCKSIZE_MAX + #define ZSTD_SRCSIZEHINT_MIN 0 + #define ZSTD_SRCSIZEHINT_MAX INT_MAX +@@ -1303,7 +1395,7 @@ typedef enum { + } ZSTD_paramSwitch_e; + + /* ************************************* +-* Frame size functions ++* Frame header and size functions + ***************************************/ + + /*! ZSTD_findDecompressedSize() : +@@ -1350,29 +1442,122 @@ ZSTDLIB_STATIC_API unsigned long long ZSTD_decompressBound(const void* src, size + * or an error code (if srcSize is too small) */ + ZSTDLIB_STATIC_API size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize); + ++typedef enum { ZSTD_frame, ZSTD_skippableFrame } ZSTD_frameType_e; ++typedef struct { ++ unsigned long long frameContentSize; /* if == ZSTD_CONTENTSIZE_UNKNOWN, it means this field is not available. 0 means "empty" */ ++ unsigned long long windowSize; /* can be very large, up to <= frameContentSize */ ++ unsigned blockSizeMax; ++ ZSTD_frameType_e frameType; /* if == ZSTD_skippableFrame, frameContentSize is the size of skippable content */ ++ unsigned headerSize; ++ unsigned dictID; ++ unsigned checksumFlag; ++ unsigned _reserved1; ++ unsigned _reserved2; ++} ZSTD_frameHeader; ++ ++/*! ZSTD_getFrameHeader() : ++ * decode Frame Header, or requires larger `srcSize`. ++ * @return : 0, `zfhPtr` is correctly filled, ++ * >0, `srcSize` is too small, value is wanted `srcSize` amount, ++ * or an error code, which can be tested using ZSTD_isError() */ ++ZSTDLIB_STATIC_API size_t ZSTD_getFrameHeader(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize); /*< doesn't consume input */ ++/*! 
ZSTD_getFrameHeader_advanced() : ++ * same as ZSTD_getFrameHeader(), ++ * with added capability to select a format (like ZSTD_f_zstd1_magicless) */ ++ZSTDLIB_STATIC_API size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format); ++ ++/*! ZSTD_decompressionMargin() : ++ * Zstd supports in-place decompression, where the input and output buffers overlap. ++ * In this case, the output buffer must be at least (Margin + Output_Size) bytes large, ++ * and the input buffer must be at the end of the output buffer. ++ * ++ * _______________________ Output Buffer ________________________ ++ * | | ++ * | ____ Input Buffer ____| ++ * | | | ++ * v v v ++ * |---------------------------------------|-----------|----------| ++ * ^ ^ ^ ++ * |___________________ Output_Size ___________________|_ Margin _| ++ * ++ * NOTE: See also ZSTD_DECOMPRESSION_MARGIN(). ++ * NOTE: This applies only to single-pass decompression through ZSTD_decompress() or ++ * ZSTD_decompressDCtx(). ++ * NOTE: This function supports multi-frame input. ++ * ++ * @param src The compressed frame(s) ++ * @param srcSize The size of the compressed frame(s) ++ * @returns The decompression margin or an error that can be checked with ZSTD_isError(). ++ */ ++ZSTDLIB_STATIC_API size_t ZSTD_decompressionMargin(const void* src, size_t srcSize); ++ ++/*! ZSTD_DECOMPRESSION_MARGIN() : ++ * Similar to ZSTD_decompressionMargin(), but instead of computing the margin from ++ * the compressed frame, compute it from the original size and the block size. ++ * See ZSTD_decompressionMargin() for details. ++ * ++ * WARNING: This macro does not support multi-frame input, the input must be a single ++ * zstd frame. If you need that support use the function, or implement it yourself. ++ * ++ * @param originalSize The original uncompressed size of the data. ++ * @param blockSize The block size == MIN(windowSize, ZSTD_BLOCKSIZE_MAX). 
++ * Unless you explicitly set the windowLog smaller than ++ * ZSTD_BLOCKSIZELOG_MAX you can just use ZSTD_BLOCKSIZE_MAX. ++ */ ++#define ZSTD_DECOMPRESSION_MARGIN(originalSize, blockSize) ((size_t)( \ ++ ZSTD_FRAMEHEADERSIZE_MAX /* Frame header */ + \ ++ 4 /* checksum */ + \ ++ ((originalSize) == 0 ? 0 : 3 * (((originalSize) + (blockSize) - 1) / blockSize)) /* 3 bytes per block */ + \ ++ (blockSize) /* One block of margin */ \ ++ )) ++ + typedef enum { + ZSTD_sf_noBlockDelimiters = 0, /* Representation of ZSTD_Sequence has no block delimiters, sequences only */ + ZSTD_sf_explicitBlockDelimiters = 1 /* Representation of ZSTD_Sequence contains explicit block delimiters */ + } ZSTD_sequenceFormat_e; + ++/*! ZSTD_sequenceBound() : ++ * `srcSize` : size of the input buffer ++ * @return : upper-bound for the number of sequences that can be generated ++ * from a buffer of srcSize bytes ++ * ++ * note : returns number of sequences - to get bytes, multiply by sizeof(ZSTD_Sequence). ++ */ ++ZSTDLIB_STATIC_API size_t ZSTD_sequenceBound(size_t srcSize); ++ + /*! ZSTD_generateSequences() : +- * Generate sequences using ZSTD_compress2, given a source buffer. ++ * WARNING: This function is meant for debugging and informational purposes ONLY! ++ * Its implementation is flawed, and it will be deleted in a future version. ++ * It is not guaranteed to succeed, as there are several cases where it will give ++ * up and fail. You should NOT use this function in production code. ++ * ++ * This function is deprecated, and will be removed in a future version. ++ * ++ * Generate sequences using ZSTD_compress2(), given a source buffer. ++ * ++ * @param zc The compression context to be used for ZSTD_compress2(). Set any ++ * compression parameters you need on this context. ++ * @param outSeqs The output sequences buffer of size @p outSeqsSize ++ * @param outSeqsSize The size of the output sequences buffer. 
++ * ZSTD_sequenceBound(srcSize) is an upper bound on the number ++ * of sequences that can be generated. ++ * @param src The source buffer to generate sequences from of size @p srcSize. ++ * @param srcSize The size of the source buffer. + * + * Each block will end with a dummy sequence + * with offset == 0, matchLength == 0, and litLength == length of last literals. + * litLength may be == 0, and if so, then the sequence of (of: 0 ml: 0 ll: 0) + * simply acts as a block delimiter. + * +- * zc can be used to insert custom compression params. +- * This function invokes ZSTD_compress2 +- * +- * The output of this function can be fed into ZSTD_compressSequences() with CCtx +- * setting of ZSTD_c_blockDelimiters as ZSTD_sf_explicitBlockDelimiters +- * @return : number of sequences generated ++ * @returns The number of sequences generated, necessarily less than ++ * ZSTD_sequenceBound(srcSize), or an error code that can be checked ++ * with ZSTD_isError(). + */ +- +-ZSTDLIB_STATIC_API size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* outSeqs, +- size_t outSeqsSize, const void* src, size_t srcSize); ++ZSTD_DEPRECATED("For debugging only, will be replaced by ZSTD_extractSequences()") ++ZSTDLIB_STATIC_API size_t ++ZSTD_generateSequences(ZSTD_CCtx* zc, ++ ZSTD_Sequence* outSeqs, size_t outSeqsSize, ++ const void* src, size_t srcSize); + + /*! ZSTD_mergeBlockDelimiters() : + * Given an array of ZSTD_Sequence, remove all sequences that represent block delimiters/last literals +@@ -1388,7 +1573,9 @@ ZSTDLIB_STATIC_API size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* o + ZSTDLIB_STATIC_API size_t ZSTD_mergeBlockDelimiters(ZSTD_Sequence* sequences, size_t seqsSize); + + /*! ZSTD_compressSequences() : +- * Compress an array of ZSTD_Sequence, generated from the original source buffer, into dst. ++ * Compress an array of ZSTD_Sequence, associated with @src buffer, into dst. ++ * @src contains the entire input (not just the literals). 
++ * If @srcSize > sum(sequence.length), the remaining bytes are considered all literals + * If a dictionary is included, then the cctx should reference the dict. (see: ZSTD_CCtx_refCDict(), ZSTD_CCtx_loadDictionary(), etc.) + * The entire source is compressed into a single frame. + * +@@ -1413,11 +1600,12 @@ ZSTDLIB_STATIC_API size_t ZSTD_mergeBlockDelimiters(ZSTD_Sequence* sequences, si + * Note: Repcodes are, as of now, always re-calculated within this function, so ZSTD_Sequence::rep is unused. + * Note 2: Once we integrate ability to ingest repcodes, the explicit block delims mode must respect those repcodes exactly, + * and cannot emit an RLE block that disagrees with the repcode history +- * @return : final compressed size or a ZSTD error. ++ * @return : final compressed size, or a ZSTD error code. + */ +-ZSTDLIB_STATIC_API size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* dst, size_t dstSize, +- const ZSTD_Sequence* inSeqs, size_t inSeqsSize, +- const void* src, size_t srcSize); ++ZSTDLIB_STATIC_API size_t ++ZSTD_compressSequences( ZSTD_CCtx* cctx, void* dst, size_t dstSize, ++ const ZSTD_Sequence* inSeqs, size_t inSeqsSize, ++ const void* src, size_t srcSize); + + + /*! ZSTD_writeSkippableFrame() : +@@ -1464,48 +1652,59 @@ ZSTDLIB_API unsigned ZSTD_isSkippableFrame(const void* buffer, size_t size); + /*! ZSTD_estimate*() : + * These functions make it possible to estimate memory usage + * of a future {D,C}Ctx, before its creation. ++ * This is useful in combination with ZSTD_initStatic(), ++ * which makes it possible to employ a static buffer for ZSTD_CCtx* state. + * + * ZSTD_estimateCCtxSize() will provide a memory budget large enough +- * for any compression level up to selected one. +- * Note : Unlike ZSTD_estimateCStreamSize*(), this estimate +- * does not include space for a window buffer. +- * Therefore, the estimation is only guaranteed for single-shot compressions, not streaming. 
++ * to compress data of any size using one-shot compression ZSTD_compressCCtx() or ZSTD_compress2() ++ * associated with any compression level up to max specified one. + * The estimate will assume the input may be arbitrarily large, + * which is the worst case. + * ++ * Note that the size estimation is specific for one-shot compression, ++ * it is not valid for streaming (see ZSTD_estimateCStreamSize*()) ++ * nor other potential ways of using a ZSTD_CCtx* state. ++ * + * When srcSize can be bound by a known and rather "small" value, +- * this fact can be used to provide a tighter estimation +- * because the CCtx compression context will need less memory. +- * This tighter estimation can be provided by more advanced functions ++ * this knowledge can be used to provide a tighter budget estimation ++ * because the ZSTD_CCtx* state will need less memory for small inputs. ++ * This tighter estimation can be provided by employing more advanced functions + * ZSTD_estimateCCtxSize_usingCParams(), which can be used in tandem with ZSTD_getCParams(), + * and ZSTD_estimateCCtxSize_usingCCtxParams(), which can be used in tandem with ZSTD_CCtxParams_setParameter(). + * Both can be used to estimate memory using custom compression parameters and arbitrary srcSize limits. + * +- * Note 2 : only single-threaded compression is supported. ++ * Note : only single-threaded compression is supported. + * ZSTD_estimateCCtxSize_usingCCtxParams() will return an error code if ZSTD_c_nbWorkers is >= 1. + */ +-ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize(int compressionLevel); ++ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize(int maxCompressionLevel); + ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize_usingCParams(ZSTD_compressionParameters cParams); + ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize_usingCCtxParams(const ZSTD_CCtx_params* params); + ZSTDLIB_STATIC_API size_t ZSTD_estimateDCtxSize(void); + + /*! 
ZSTD_estimateCStreamSize() : +- * ZSTD_estimateCStreamSize() will provide a budget large enough for any compression level up to selected one. +- * It will also consider src size to be arbitrarily "large", which is worst case. ++ * ZSTD_estimateCStreamSize() will provide a memory budget large enough for streaming compression ++ * using any compression level up to the max specified one. ++ * It will also consider src size to be arbitrarily "large", which is a worst case scenario. + * If srcSize is known to always be small, ZSTD_estimateCStreamSize_usingCParams() can provide a tighter estimation. + * ZSTD_estimateCStreamSize_usingCParams() can be used in tandem with ZSTD_getCParams() to create cParams from compressionLevel. + * ZSTD_estimateCStreamSize_usingCCtxParams() can be used in tandem with ZSTD_CCtxParams_setParameter(). Only single-threaded compression is supported. This function will return an error code if ZSTD_c_nbWorkers is >= 1. + * Note : CStream size estimation is only correct for single-threaded compression. +- * ZSTD_DStream memory budget depends on window Size. ++ * ZSTD_estimateCStreamSize_usingCCtxParams() will return an error code if ZSTD_c_nbWorkers is >= 1. ++ * Note 2 : ZSTD_estimateCStreamSize* functions are not compatible with the Block-Level Sequence Producer API at this time. ++ * Size estimates assume that no external sequence producer is registered. ++ * ++ * ZSTD_DStream memory budget depends on frame's window Size. + * This information can be passed manually, using ZSTD_estimateDStreamSize, + * or deducted from a valid frame Header, using ZSTD_estimateDStreamSize_fromFrame(); ++ * Any frame requesting a window size larger than max specified one will be rejected. + * Note : if streaming is init with function ZSTD_init?Stream_usingDict(), + * an internal ?Dict will be created, which additional size is not estimated here. 
+- * In this case, get total size by adding ZSTD_estimate?DictSize */ +-ZSTDLIB_STATIC_API size_t ZSTD_estimateCStreamSize(int compressionLevel); ++ * In this case, get total size by adding ZSTD_estimate?DictSize ++ */ ++ZSTDLIB_STATIC_API size_t ZSTD_estimateCStreamSize(int maxCompressionLevel); + ZSTDLIB_STATIC_API size_t ZSTD_estimateCStreamSize_usingCParams(ZSTD_compressionParameters cParams); + ZSTDLIB_STATIC_API size_t ZSTD_estimateCStreamSize_usingCCtxParams(const ZSTD_CCtx_params* params); +-ZSTDLIB_STATIC_API size_t ZSTD_estimateDStreamSize(size_t windowSize); ++ZSTDLIB_STATIC_API size_t ZSTD_estimateDStreamSize(size_t maxWindowSize); + ZSTDLIB_STATIC_API size_t ZSTD_estimateDStreamSize_fromFrame(const void* src, size_t srcSize); + + /*! ZSTD_estimate?DictSize() : +@@ -1649,22 +1848,45 @@ ZSTDLIB_STATIC_API size_t ZSTD_checkCParams(ZSTD_compressionParameters params); + * This function never fails (wide contract) */ + ZSTDLIB_STATIC_API ZSTD_compressionParameters ZSTD_adjustCParams(ZSTD_compressionParameters cPar, unsigned long long srcSize, size_t dictSize); + ++/*! ZSTD_CCtx_setCParams() : ++ * Set all parameters provided within @p cparams into the working @p cctx. ++ * Note : if modifying parameters during compression (MT mode only), ++ * note that changes to the .windowLog parameter will be ignored. ++ * @return 0 on success, or an error code (can be checked with ZSTD_isError()). ++ * On failure, no parameters are updated. ++ */ ++ZSTDLIB_STATIC_API size_t ZSTD_CCtx_setCParams(ZSTD_CCtx* cctx, ZSTD_compressionParameters cparams); ++ ++/*! ZSTD_CCtx_setFParams() : ++ * Set all parameters provided within @p fparams into the working @p cctx. ++ * @return 0 on success, or an error code (can be checked with ZSTD_isError()). ++ */ ++ZSTDLIB_STATIC_API size_t ZSTD_CCtx_setFParams(ZSTD_CCtx* cctx, ZSTD_frameParameters fparams); ++ ++/*! ZSTD_CCtx_setParams() : ++ * Set all parameters provided within @p params into the working @p cctx. 
++ * @return 0 on success, or an error code (can be checked with ZSTD_isError()). ++ */ ++ZSTDLIB_STATIC_API size_t ZSTD_CCtx_setParams(ZSTD_CCtx* cctx, ZSTD_parameters params); ++ + /*! ZSTD_compress_advanced() : + * Note : this function is now DEPRECATED. + * It can be replaced by ZSTD_compress2(), in combination with ZSTD_CCtx_setParameter() and other parameter setters. + * This prototype will generate compilation warnings. */ + ZSTD_DEPRECATED("use ZSTD_compress2") ++ZSTDLIB_STATIC_API + size_t ZSTD_compress_advanced(ZSTD_CCtx* cctx, +- void* dst, size_t dstCapacity, +- const void* src, size_t srcSize, +- const void* dict,size_t dictSize, +- ZSTD_parameters params); ++ void* dst, size_t dstCapacity, ++ const void* src, size_t srcSize, ++ const void* dict,size_t dictSize, ++ ZSTD_parameters params); + + /*! ZSTD_compress_usingCDict_advanced() : + * Note : this function is now DEPRECATED. + * It can be replaced by ZSTD_compress2(), in combination with ZSTD_CCtx_loadDictionary() and other parameter setters. + * This prototype will generate compilation warnings. */ + ZSTD_DEPRECATED("use ZSTD_compress2 with ZSTD_CCtx_loadDictionary") ++ZSTDLIB_STATIC_API + size_t ZSTD_compress_usingCDict_advanced(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize, +@@ -1737,11 +1959,6 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo + */ + #define ZSTD_c_literalCompressionMode ZSTD_c_experimentalParam5 + +-/* Tries to fit compressed block size to be around targetCBlockSize. +- * No target when targetCBlockSize == 0. +- * There is no guarantee on compressed block size (default:0) */ +-#define ZSTD_c_targetCBlockSize ZSTD_c_experimentalParam6 +- + /* User's best guess of source size. + * Hint is not valid when srcSizeHint == 0. 
+ * There is no guarantee that hint is close to actual source size, +@@ -1808,13 +2025,16 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo + * Experimental parameter. + * Default is 0 == disabled. Set to 1 to enable. + * +- * Tells the compressor that the ZSTD_inBuffer will ALWAYS be the same +- * between calls, except for the modifications that zstd makes to pos (the +- * caller must not modify pos). This is checked by the compressor, and +- * compression will fail if it ever changes. This means the only flush +- * mode that makes sense is ZSTD_e_end, so zstd will error if ZSTD_e_end +- * is not used. The data in the ZSTD_inBuffer in the range [src, src + pos) +- * MUST not be modified during compression or you will get data corruption. ++ * Tells the compressor that input data presented with ZSTD_inBuffer ++ * will ALWAYS be the same between calls. ++ * Technically, the @src pointer must never be changed, ++ * and the @pos field can only be updated by zstd. ++ * However, it's possible to increase the @size field, ++ * allowing scenarios where more data can be appended after compression starts. ++ * These conditions are checked by the compressor, ++ * and compression will fail if they are not respected. ++ * Also, data in the ZSTD_inBuffer within the range [src, src + pos) ++ * MUST not be modified during compression or it will result in data corruption. + * + * When this flag is enabled zstd won't allocate an input window buffer, + * because the user guarantees it can reference the ZSTD_inBuffer until +@@ -1822,18 +2042,15 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo + * large enough to fit a block (see ZSTD_c_stableOutBuffer). This will also + * avoid the memcpy() from the input buffer to the input window buffer. + * +- * NOTE: ZSTD_compressStream2() will error if ZSTD_e_end is not used. +- * That means this flag cannot be used with ZSTD_compressStream(). 
+- * + * NOTE: So long as the ZSTD_inBuffer always points to valid memory, using + * this flag is ALWAYS memory safe, and will never access out-of-bounds +- * memory. However, compression WILL fail if you violate the preconditions. ++ * memory. However, compression WILL fail if conditions are not respected. + * +- * WARNING: The data in the ZSTD_inBuffer in the range [dst, dst + pos) MUST +- * not be modified during compression or you will get data corruption. This +- * is because zstd needs to reference data in the ZSTD_inBuffer to find ++ * WARNING: The data in the ZSTD_inBuffer in the range [src, src + pos) MUST ++ * not be modified during compression or it will result in data corruption. ++ * This is because zstd needs to reference data in the ZSTD_inBuffer to find + * matches. Normally zstd maintains its own window buffer for this purpose, +- * but passing this flag tells zstd to use the user provided buffer. ++ * but passing this flag tells zstd to rely on user provided buffer instead. + */ + #define ZSTD_c_stableInBuffer ZSTD_c_experimentalParam9 + +@@ -1878,7 +2095,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo + * Without validation, providing a sequence that does not conform to the zstd spec will cause + * undefined behavior, and may produce a corrupted block. + * +- * With validation enabled, a if sequence is invalid (see doc/zstd_compression_format.md for ++ * With validation enabled, if sequence is invalid (see doc/zstd_compression_format.md for + * specifics regarding offset/matchlength requirements) then the function will bail out and + * return an error. + * +@@ -1928,6 +2145,79 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo + */ + #define ZSTD_c_deterministicRefPrefix ZSTD_c_experimentalParam15 + ++/* ZSTD_c_prefetchCDictTables ++ * Controlled with ZSTD_paramSwitch_e enum. Default is ZSTD_ps_auto. 
++ * ++ * In some situations, zstd uses CDict tables in-place rather than copying them ++ * into the working context. (See docs on ZSTD_dictAttachPref_e above for details). ++ * In such situations, compression speed is seriously impacted when CDict tables are ++ * "cold" (outside CPU cache). This parameter instructs zstd to prefetch CDict tables ++ * when they are used in-place. ++ * ++ * For sufficiently small inputs, the cost of the prefetch will outweigh the benefit. ++ * For sufficiently large inputs, zstd will by default memcpy() CDict tables ++ * into the working context, so there is no need to prefetch. This parameter is ++ * targeted at a middle range of input sizes, where a prefetch is cheap enough to be ++ * useful but memcpy() is too expensive. The exact range of input sizes where this ++ * makes sense is best determined by careful experimentation. ++ * ++ * Note: for this parameter, ZSTD_ps_auto is currently equivalent to ZSTD_ps_disable, ++ * but in the future zstd may conditionally enable this feature via an auto-detection ++ * heuristic for cold CDicts. ++ * Use ZSTD_ps_disable to opt out of prefetching under any circumstances. ++ */ ++#define ZSTD_c_prefetchCDictTables ZSTD_c_experimentalParam16 ++ ++/* ZSTD_c_enableSeqProducerFallback ++ * Allowed values are 0 (disable) and 1 (enable). The default setting is 0. ++ * ++ * Controls whether zstd will fall back to an internal sequence producer if an ++ * external sequence producer is registered and returns an error code. This fallback ++ * is block-by-block: the internal sequence producer will only be called for blocks ++ * where the external sequence producer returns an error code. Fallback parsing will ++ * follow any other cParam settings, such as compression level, the same as in a ++ * normal (fully-internal) compression operation. ++ * ++ * The user is strongly encouraged to read the full Block-Level Sequence Producer API ++ * documentation (below) before setting this parameter. 
*/ ++#define ZSTD_c_enableSeqProducerFallback ZSTD_c_experimentalParam17 ++ ++/* ZSTD_c_maxBlockSize ++ * Allowed values are between 1KB and ZSTD_BLOCKSIZE_MAX (128KB). ++ * The default is ZSTD_BLOCKSIZE_MAX, and setting to 0 will set to the default. ++ * ++ * This parameter can be used to set an upper bound on the blocksize ++ * that overrides the default ZSTD_BLOCKSIZE_MAX. It cannot be used to set upper ++ * bounds greater than ZSTD_BLOCKSIZE_MAX or bounds lower than 1KB (will make ++ * compressBound() inaccurate). Only currently meant to be used for testing. ++ * ++ */ ++#define ZSTD_c_maxBlockSize ZSTD_c_experimentalParam18 ++ ++/* ZSTD_c_searchForExternalRepcodes ++ * This parameter affects how zstd parses external sequences, such as sequences ++ * provided through the compressSequences() API or from an external block-level ++ * sequence producer. ++ * ++ * If set to ZSTD_ps_enable, the library will check for repeated offsets in ++ * external sequences, even if those repcodes are not explicitly indicated in ++ * the "rep" field. Note that this is the only way to exploit repcode matches ++ * while using compressSequences() or an external sequence producer, since zstd ++ * currently ignores the "rep" field of external sequences. ++ * ++ * If set to ZSTD_ps_disable, the library will not exploit repeated offsets in ++ * external sequences, regardless of whether the "rep" field has been set. This ++ * reduces sequence compression overhead by about 25% while sacrificing some ++ * compression ratio. ++ * ++ * The default value is ZSTD_ps_auto, for which the library will enable/disable ++ * based on compression level. ++ * ++ * Note: for now, this param only has an effect if ZSTD_c_blockDelimiters is ++ * set to ZSTD_sf_explicitBlockDelimiters. That may change in the future. ++ */ ++#define ZSTD_c_searchForExternalRepcodes ZSTD_c_experimentalParam19 ++ + /*! 
ZSTD_CCtx_getParameter() : + * Get the requested compression parameter value, selected by enum ZSTD_cParameter, + * and store it into int* value. +@@ -2084,7 +2374,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParamete + * in the range [dst, dst + pos) MUST not be modified during decompression + * or you will get data corruption. + * +- * When this flags is enabled zstd won't allocate an output buffer, because ++ * When this flag is enabled zstd won't allocate an output buffer, because + * it can write directly to the ZSTD_outBuffer, but it will still allocate + * an input buffer large enough to fit any compressed block. This will also + * avoid the memcpy() from the internal output buffer to the ZSTD_outBuffer. +@@ -2137,6 +2427,33 @@ ZSTDLIB_STATIC_API size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParamete + */ + #define ZSTD_d_refMultipleDDicts ZSTD_d_experimentalParam4 + ++/* ZSTD_d_disableHuffmanAssembly ++ * Set to 1 to disable the Huffman assembly implementation. ++ * The default value is 0, which allows zstd to use the Huffman assembly ++ * implementation if available. ++ * ++ * This parameter can be used to disable Huffman assembly at runtime. ++ * If you want to disable it at compile time you can define the macro ++ * ZSTD_DISABLE_ASM. ++ */ ++#define ZSTD_d_disableHuffmanAssembly ZSTD_d_experimentalParam5 ++ ++/* ZSTD_d_maxBlockSize ++ * Allowed values are between 1KB and ZSTD_BLOCKSIZE_MAX (128KB). ++ * The default is ZSTD_BLOCKSIZE_MAX, and setting to 0 will set to the default. ++ * ++ * Forces the decompressor to reject blocks whose content size is ++ * larger than the configured maxBlockSize. When maxBlockSize is ++ * larger than the windowSize, the windowSize is used instead. ++ * This saves memory on the decoder when you know all blocks are small. ++ * ++ * This option is typically used in conjunction with ZSTD_c_maxBlockSize. 
++ * ++ * WARNING: This causes the decoder to reject otherwise valid frames ++ * that have block sizes larger than the configured maxBlockSize. ++ */ ++#define ZSTD_d_maxBlockSize ZSTD_d_experimentalParam6 ++ + + /*! ZSTD_DCtx_setFormat() : + * This function is REDUNDANT. Prefer ZSTD_DCtx_setParameter(). +@@ -2145,6 +2462,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParamete + * such ZSTD_f_zstd1_magicless for example. + * @return : 0, or an error code (which can be tested using ZSTD_isError()). */ + ZSTD_DEPRECATED("use ZSTD_DCtx_setParameter() instead") ++ZSTDLIB_STATIC_API + size_t ZSTD_DCtx_setFormat(ZSTD_DCtx* dctx, ZSTD_format_e format); + + /*! ZSTD_decompressStream_simpleArgs() : +@@ -2181,6 +2499,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_decompressStream_simpleArgs ( + * This prototype will generate compilation warnings. + */ + ZSTD_DEPRECATED("use ZSTD_CCtx_reset, see zstd.h for detailed instructions") ++ZSTDLIB_STATIC_API + size_t ZSTD_initCStream_srcSize(ZSTD_CStream* zcs, + int compressionLevel, + unsigned long long pledgedSrcSize); +@@ -2198,17 +2517,15 @@ size_t ZSTD_initCStream_srcSize(ZSTD_CStream* zcs, + * This prototype will generate compilation warnings. + */ + ZSTD_DEPRECATED("use ZSTD_CCtx_reset, see zstd.h for detailed instructions") ++ZSTDLIB_STATIC_API + size_t ZSTD_initCStream_usingDict(ZSTD_CStream* zcs, + const void* dict, size_t dictSize, + int compressionLevel); + + /*! ZSTD_initCStream_advanced() : +- * This function is DEPRECATED, and is approximately equivalent to: ++ * This function is DEPRECATED, and is equivalent to: + * ZSTD_CCtx_reset(zcs, ZSTD_reset_session_only); +- * // Pseudocode: Set each zstd parameter and leave the rest as-is. 
+- * for ((param, value) : params) { +- * ZSTD_CCtx_setParameter(zcs, param, value); +- * } ++ * ZSTD_CCtx_setParams(zcs, params); + * ZSTD_CCtx_setPledgedSrcSize(zcs, pledgedSrcSize); + * ZSTD_CCtx_loadDictionary(zcs, dict, dictSize); + * +@@ -2218,6 +2535,7 @@ size_t ZSTD_initCStream_usingDict(ZSTD_CStream* zcs, + * This prototype will generate compilation warnings. + */ + ZSTD_DEPRECATED("use ZSTD_CCtx_reset, see zstd.h for detailed instructions") ++ZSTDLIB_STATIC_API + size_t ZSTD_initCStream_advanced(ZSTD_CStream* zcs, + const void* dict, size_t dictSize, + ZSTD_parameters params, +@@ -2232,15 +2550,13 @@ size_t ZSTD_initCStream_advanced(ZSTD_CStream* zcs, + * This prototype will generate compilation warnings. + */ + ZSTD_DEPRECATED("use ZSTD_CCtx_reset and ZSTD_CCtx_refCDict, see zstd.h for detailed instructions") ++ZSTDLIB_STATIC_API + size_t ZSTD_initCStream_usingCDict(ZSTD_CStream* zcs, const ZSTD_CDict* cdict); + + /*! ZSTD_initCStream_usingCDict_advanced() : +- * This function is DEPRECATED, and is approximately equivalent to: ++ * This function is DEPRECATED, and is equivalent to: + * ZSTD_CCtx_reset(zcs, ZSTD_reset_session_only); +- * // Pseudocode: Set each zstd frame parameter and leave the rest as-is. +- * for ((fParam, value) : fParams) { +- * ZSTD_CCtx_setParameter(zcs, fParam, value); +- * } ++ * ZSTD_CCtx_setFParams(zcs, fParams); + * ZSTD_CCtx_setPledgedSrcSize(zcs, pledgedSrcSize); + * ZSTD_CCtx_refCDict(zcs, cdict); + * +@@ -2250,6 +2566,7 @@ size_t ZSTD_initCStream_usingCDict(ZSTD_CStream* zcs, const ZSTD_CDict* cdict); + * This prototype will generate compilation warnings. + */ + ZSTD_DEPRECATED("use ZSTD_CCtx_reset and ZSTD_CCtx_refCDict, see zstd.h for detailed instructions") ++ZSTDLIB_STATIC_API + size_t ZSTD_initCStream_usingCDict_advanced(ZSTD_CStream* zcs, + const ZSTD_CDict* cdict, + ZSTD_frameParameters fParams, +@@ -2264,7 +2581,7 @@ size_t ZSTD_initCStream_usingCDict_advanced(ZSTD_CStream* zcs, + * explicitly specified. 
+ * + * start a new frame, using same parameters from previous frame. +- * This is typically useful to skip dictionary loading stage, since it will re-use it in-place. ++ * This is typically useful to skip dictionary loading stage, since it will reuse it in-place. + * Note that zcs must be init at least once before using ZSTD_resetCStream(). + * If pledgedSrcSize is not known at reset time, use macro ZSTD_CONTENTSIZE_UNKNOWN. + * If pledgedSrcSize > 0, its value must be correct, as it will be written in header, and controlled at the end. +@@ -2274,6 +2591,7 @@ size_t ZSTD_initCStream_usingCDict_advanced(ZSTD_CStream* zcs, + * This prototype will generate compilation warnings. + */ + ZSTD_DEPRECATED("use ZSTD_CCtx_reset, see zstd.h for detailed instructions") ++ZSTDLIB_STATIC_API + size_t ZSTD_resetCStream(ZSTD_CStream* zcs, unsigned long long pledgedSrcSize); + + +@@ -2319,8 +2637,8 @@ ZSTDLIB_STATIC_API size_t ZSTD_toFlushNow(ZSTD_CCtx* cctx); + * ZSTD_DCtx_loadDictionary(zds, dict, dictSize); + * + * note: no dictionary will be used if dict == NULL or dictSize < 8 +- * Note : this prototype will be marked as deprecated and generate compilation warnings on reaching v1.5.x + */ ++ZSTD_DEPRECATED("use ZSTD_DCtx_reset + ZSTD_DCtx_loadDictionary, see zstd.h for detailed instructions") + ZSTDLIB_STATIC_API size_t ZSTD_initDStream_usingDict(ZSTD_DStream* zds, const void* dict, size_t dictSize); + + /*! +@@ -2330,8 +2648,8 @@ ZSTDLIB_STATIC_API size_t ZSTD_initDStream_usingDict(ZSTD_DStream* zds, const vo + * ZSTD_DCtx_refDDict(zds, ddict); + * + * note : ddict is referenced, it must outlive decompression session +- * Note : this prototype will be marked as deprecated and generate compilation warnings on reaching v1.5.x + */ ++ZSTD_DEPRECATED("use ZSTD_DCtx_reset + ZSTD_DCtx_refDDict, see zstd.h for detailed instructions") + ZSTDLIB_STATIC_API size_t ZSTD_initDStream_usingDDict(ZSTD_DStream* zds, const ZSTD_DDict* ddict); + + /*! 
+@@ -2339,18 +2657,202 @@ ZSTDLIB_STATIC_API size_t ZSTD_initDStream_usingDDict(ZSTD_DStream* zds, const Z + * + * ZSTD_DCtx_reset(zds, ZSTD_reset_session_only); + * +- * re-use decompression parameters from previous init; saves dictionary loading +- * Note : this prototype will be marked as deprecated and generate compilation warnings on reaching v1.5.x ++ * reuse decompression parameters from previous init; saves dictionary loading + */ ++ZSTD_DEPRECATED("use ZSTD_DCtx_reset, see zstd.h for detailed instructions") + ZSTDLIB_STATIC_API size_t ZSTD_resetDStream(ZSTD_DStream* zds); + + ++/* ********************* BLOCK-LEVEL SEQUENCE PRODUCER API ********************* ++ * ++ * *** OVERVIEW *** ++ * The Block-Level Sequence Producer API allows users to provide their own custom ++ * sequence producer which libzstd invokes to process each block. The produced list ++ * of sequences (literals and matches) is then post-processed by libzstd to produce ++ * valid compressed blocks. ++ * ++ * This block-level offload API is a more granular complement of the existing ++ * frame-level offload API compressSequences() (introduced in v1.5.1). It offers ++ * an easier migration story for applications already integrated with libzstd: the ++ * user application continues to invoke the same compression functions ++ * ZSTD_compress2() or ZSTD_compressStream2() as usual, and transparently benefits ++ * from the specific advantages of the external sequence producer. For example, ++ * the sequence producer could be tuned to take advantage of known characteristics ++ * of the input, to offer better speed / ratio, or could leverage hardware ++ * acceleration not available within libzstd itself. ++ * ++ * See contrib/externalSequenceProducer for an example program employing the ++ * Block-Level Sequence Producer API. ++ * ++ * *** USAGE *** ++ * The user is responsible for implementing a function of type ++ * ZSTD_sequenceProducer_F. 
For each block, zstd will pass the following ++ * arguments to the user-provided function: ++ * ++ * - sequenceProducerState: a pointer to a user-managed state for the sequence ++ * producer. ++ * ++ * - outSeqs, outSeqsCapacity: an output buffer for the sequence producer. ++ * outSeqsCapacity is guaranteed >= ZSTD_sequenceBound(srcSize). The memory ++ * backing outSeqs is managed by the CCtx. ++ * ++ * - src, srcSize: an input buffer for the sequence producer to parse. ++ * srcSize is guaranteed to be <= ZSTD_BLOCKSIZE_MAX. ++ * ++ * - dict, dictSize: a history buffer, which may be empty, which the sequence ++ * producer may reference as it parses the src buffer. Currently, zstd will ++ * always pass dictSize == 0 into external sequence producers, but this will ++ * change in the future. ++ * ++ * - compressionLevel: a signed integer representing the zstd compression level ++ * set by the user for the current operation. The sequence producer may choose ++ * to use this information to change its compression strategy and speed/ratio ++ * tradeoff. Note: the compression level does not reflect zstd parameters set ++ * through the advanced API. ++ * ++ * - windowSize: a size_t representing the maximum allowed offset for external ++ * sequences. Note that sequence offsets are sometimes allowed to exceed the ++ * windowSize if a dictionary is present, see doc/zstd_compression_format.md ++ * for details. ++ * ++ * The user-provided function shall return a size_t representing the number of ++ * sequences written to outSeqs. This return value will be treated as an error ++ * code if it is greater than outSeqsCapacity. The return value must be non-zero ++ * if srcSize is non-zero. The ZSTD_SEQUENCE_PRODUCER_ERROR macro is provided ++ * for convenience, but any value greater than outSeqsCapacity will be treated as ++ * an error code. 
++ * ++ * If the user-provided function does not return an error code, the sequences ++ * written to outSeqs must be a valid parse of the src buffer. Data corruption may ++ * occur if the parse is not valid. A parse is defined to be valid if the ++ * following conditions hold: ++ * - The sum of matchLengths and literalLengths must equal srcSize. ++ * - All sequences in the parse, except for the final sequence, must have ++ * matchLength >= ZSTD_MINMATCH_MIN. The final sequence must have ++ * matchLength >= ZSTD_MINMATCH_MIN or matchLength == 0. ++ * - All offsets must respect the windowSize parameter as specified in ++ * doc/zstd_compression_format.md. ++ * - If the final sequence has matchLength == 0, it must also have offset == 0. ++ * ++ * zstd will only validate these conditions (and fail compression if they do not ++ * hold) if the ZSTD_c_validateSequences cParam is enabled. Note that sequence ++ * validation has a performance cost. ++ * ++ * If the user-provided function returns an error, zstd will either fall back ++ * to an internal sequence producer or fail the compression operation. The user can ++ * choose between the two behaviors by setting the ZSTD_c_enableSeqProducerFallback ++ * cParam. Fallback compression will follow any other cParam settings, such as ++ * compression level, the same as in a normal compression operation. ++ * ++ * The user shall instruct zstd to use a particular ZSTD_sequenceProducer_F ++ * function by calling ++ * ZSTD_registerSequenceProducer(cctx, ++ * sequenceProducerState, ++ * sequenceProducer) ++ * This setting will persist until the next parameter reset of the CCtx. ++ * ++ * The sequenceProducerState must be initialized by the user before calling ++ * ZSTD_registerSequenceProducer(). The user is responsible for destroying the ++ * sequenceProducerState. ++ * ++ * *** LIMITATIONS *** ++ * This API is compatible with all zstd compression APIs which respect advanced parameters. 
++ * However, there are three limitations: ++ * ++ * First, the ZSTD_c_enableLongDistanceMatching cParam is not currently supported. ++ * COMPRESSION WILL FAIL if it is enabled and the user tries to compress with a block-level ++ * external sequence producer. ++ * - Note that ZSTD_c_enableLongDistanceMatching is auto-enabled by default in some ++ * cases (see its documentation for details). Users must explicitly set ++ * ZSTD_c_enableLongDistanceMatching to ZSTD_ps_disable in such cases if an external ++ * sequence producer is registered. ++ * - As of this writing, ZSTD_c_enableLongDistanceMatching is disabled by default ++ * whenever the window size (1 << ZSTD_c_windowLog) is < 128MB, i.e. ZSTD_c_windowLog < 27, but that's subject to change. Users should ++ * check the docs on ZSTD_c_enableLongDistanceMatching whenever the Block-Level Sequence ++ * Producer API is used in conjunction with advanced settings (like ZSTD_c_windowLog). ++ * ++ * Second, history buffers are not currently supported. Concretely, zstd will always pass ++ * dictSize == 0 to the external sequence producer (for now). This has two implications: ++ * - Dictionaries are not currently supported. Compression will *not* fail if the user ++ * references a dictionary, but the dictionary won't have any effect. ++ * - Stream history is not currently supported. All advanced compression APIs, including ++ * streaming APIs, work with external sequence producers, but each block is treated as ++ * an independent chunk without history from previous blocks. ++ * ++ * Third, multi-threading within a single compression is not currently supported. In other words, ++ * COMPRESSION WILL FAIL if ZSTD_c_nbWorkers > 0 and an external sequence producer is registered. ++ * Multi-threading across compressions is fine: simply create one CCtx per thread. ++ * ++ * Long-term, we plan to overcome all three limitations. There is no technical blocker to ++ * overcoming them. It is purely a question of engineering effort.
++ */ ++ ++#define ZSTD_SEQUENCE_PRODUCER_ERROR ((size_t)(-1)) ++ ++typedef size_t (*ZSTD_sequenceProducer_F) ( ++ void* sequenceProducerState, ++ ZSTD_Sequence* outSeqs, size_t outSeqsCapacity, ++ const void* src, size_t srcSize, ++ const void* dict, size_t dictSize, ++ int compressionLevel, ++ size_t windowSize ++); ++ ++/*! ZSTD_registerSequenceProducer() : ++ * Instruct zstd to use a block-level external sequence producer function. ++ * ++ * The sequenceProducerState must be initialized by the caller, and the caller is ++ * responsible for managing its lifetime. This parameter is sticky across ++ * compressions. It will remain set until the user explicitly resets compression ++ * parameters. ++ * ++ * Sequence producer registration is considered to be an "advanced parameter", ++ * part of the "advanced API". This means it will only have an effect on compression ++ * APIs which respect advanced parameters, such as compress2() and compressStream2(). ++ * Older compression APIs such as compressCCtx(), which predate the introduction of ++ * "advanced parameters", will ignore any external sequence producer setting. ++ * ++ * The sequence producer can be "cleared" by registering a NULL function pointer. This ++ * removes all limitations described above in the "LIMITATIONS" section of the API docs. ++ * ++ * The user is strongly encouraged to read the full API documentation (above) before ++ * calling this function. */ ++ZSTDLIB_STATIC_API void ++ZSTD_registerSequenceProducer( ++ ZSTD_CCtx* cctx, ++ void* sequenceProducerState, ++ ZSTD_sequenceProducer_F sequenceProducer ++); ++ ++/*! ZSTD_CCtxParams_registerSequenceProducer() : ++ * Same as ZSTD_registerSequenceProducer(), but operates on ZSTD_CCtx_params. ++ * This is used for accurate size estimation with ZSTD_estimateCCtxSize_usingCCtxParams(), ++ * which is needed when creating a ZSTD_CCtx with ZSTD_initStaticCCtx(). 
++ * ++ * If you are using the external sequence producer API in a scenario where ZSTD_initStaticCCtx() ++ * is required, then this function is for you. Otherwise, you probably don't need it. ++ * ++ * See tests/zstreamtest.c for example usage. */ ++ZSTDLIB_STATIC_API void ++ZSTD_CCtxParams_registerSequenceProducer( ++ ZSTD_CCtx_params* params, ++ void* sequenceProducerState, ++ ZSTD_sequenceProducer_F sequenceProducer ++); ++ ++ + /* ******************************************************************* +-* Buffer-less and synchronous inner streaming functions ++* Buffer-less and synchronous inner streaming functions (DEPRECATED) ++* ++* This API is deprecated, and will be removed in a future version. ++* It allows streaming (de)compression with user allocated buffers. ++* However, it is hard to use, and not as well tested as the rest of ++* our API. + * +-* This is an advanced API, giving full control over buffer management, for users which need direct control over memory. +-* But it's also a complex one, with several restrictions, documented below. +-* Prefer normal streaming API for an easier experience. ++* Please use the normal streaming API instead: ZSTD_compressStream2, ++* and ZSTD_decompressStream. ++* If there is functionality that you need, but it doesn't provide, ++* please open an issue on our GitHub. + ********************************************************************* */ + + /* +@@ -2358,11 +2860,10 @@ ZSTDLIB_STATIC_API size_t ZSTD_resetDStream(ZSTD_DStream* zds); + + A ZSTD_CCtx object is required to track streaming operations. + Use ZSTD_createCCtx() / ZSTD_freeCCtx() to manage resource. +- ZSTD_CCtx object can be re-used multiple times within successive compression operations. ++ ZSTD_CCtx object can be reused multiple times within successive compression operations. + + Start by initializing a context. + Use ZSTD_compressBegin(), or ZSTD_compressBegin_usingDict() for dictionary compression. 
+- It's also possible to duplicate a reference context which has already been initialized, using ZSTD_copyCCtx() + + Then, consume your input using ZSTD_compressContinue(). + There are some important considerations to keep in mind when using this advanced function : +@@ -2380,36 +2881,46 @@ ZSTDLIB_STATIC_API size_t ZSTD_resetDStream(ZSTD_DStream* zds); + It's possible to use srcSize==0, in which case, it will write a final empty block to end the frame. + Without last block mark, frames are considered unfinished (hence corrupted) by compliant decoders. + +- `ZSTD_CCtx` object can be re-used (ZSTD_compressBegin()) to compress again. ++ `ZSTD_CCtx` object can be reused (ZSTD_compressBegin()) to compress again. + */ + + /*===== Buffer-less streaming compression functions =====*/ ++ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") + ZSTDLIB_STATIC_API size_t ZSTD_compressBegin(ZSTD_CCtx* cctx, int compressionLevel); ++ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") + ZSTDLIB_STATIC_API size_t ZSTD_compressBegin_usingDict(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, int compressionLevel); ++ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") + ZSTDLIB_STATIC_API size_t ZSTD_compressBegin_usingCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict); /*< note: fails if cdict==NULL */ +-ZSTDLIB_STATIC_API size_t ZSTD_copyCCtx(ZSTD_CCtx* cctx, const ZSTD_CCtx* preparedCCtx, unsigned long long pledgedSrcSize); /*< note: if pledgedSrcSize is not known, use ZSTD_CONTENTSIZE_UNKNOWN */ + ++ZSTD_DEPRECATED("This function will likely be removed in a future release. 
It is misleading and has very limited utility.") ++ZSTDLIB_STATIC_API ++size_t ZSTD_copyCCtx(ZSTD_CCtx* cctx, const ZSTD_CCtx* preparedCCtx, unsigned long long pledgedSrcSize); /*< note: if pledgedSrcSize is not known, use ZSTD_CONTENTSIZE_UNKNOWN */ ++ ++ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") + ZSTDLIB_STATIC_API size_t ZSTD_compressContinue(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); ++ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") + ZSTDLIB_STATIC_API size_t ZSTD_compressEnd(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); + + /* The ZSTD_compressBegin_advanced() and ZSTD_compressBegin_usingCDict_advanced() are now DEPRECATED and will generate a compiler warning */ + ZSTD_DEPRECATED("use advanced API to access custom parameters") ++ZSTDLIB_STATIC_API + size_t ZSTD_compressBegin_advanced(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, ZSTD_parameters params, unsigned long long pledgedSrcSize); /*< pledgedSrcSize : If srcSize is not known at init time, use ZSTD_CONTENTSIZE_UNKNOWN */ + ZSTD_DEPRECATED("use advanced API to access custom parameters") ++ZSTDLIB_STATIC_API + size_t ZSTD_compressBegin_usingCDict_advanced(ZSTD_CCtx* const cctx, const ZSTD_CDict* const cdict, ZSTD_frameParameters const fParams, unsigned long long const pledgedSrcSize); /* compression parameters are already set within cdict. pledgedSrcSize must be correct. If srcSize is not known, use macro ZSTD_CONTENTSIZE_UNKNOWN */ + /* + Buffer-less streaming decompression (synchronous mode) + + A ZSTD_DCtx object is required to track streaming operations. + Use ZSTD_createDCtx() / ZSTD_freeDCtx() to manage it. +- A ZSTD_DCtx object can be re-used multiple times. ++ A ZSTD_DCtx object can be reused multiple times. + + First typical operation is to retrieve frame parameters, using ZSTD_getFrameHeader(). 
+ Frame header is extracted from the beginning of compressed frame, so providing only the frame's beginning is enough. + Data fragment must be large enough to ensure successful decoding. + `ZSTD_frameHeaderSize_max` bytes is guaranteed to always be large enough. +- @result : 0 : successful decoding, the `ZSTD_frameHeader` structure is correctly filled. +- >0 : `srcSize` is too small, please provide at least @result bytes on next attempt. ++ result : 0 : successful decoding, the `ZSTD_frameHeader` structure is correctly filled. ++ >0 : `srcSize` is too small, please provide at least result bytes on next attempt. + errorCode, which can be tested using ZSTD_isError(). + + It fills a ZSTD_frameHeader structure with important information to correctly decode the frame, +@@ -2428,7 +2939,7 @@ size_t ZSTD_compressBegin_usingCDict_advanced(ZSTD_CCtx* const cctx, const ZSTD_ + + The most memory efficient way is to use a round buffer of sufficient size. + Sufficient size is determined by invoking ZSTD_decodingBufferSize_min(), +- which can @return an error code if required value is too large for current system (in 32-bits mode). ++ which can return an error code if required value is too large for current system (in 32-bits mode). + In a round buffer methodology, ZSTD_decompressContinue() decompresses each block next to previous one, + up to the moment there is not enough room left in the buffer to guarantee decoding another full block, + which maximum size is provided in `ZSTD_frameHeader` structure, field `blockSizeMax`. +@@ -2448,7 +2959,7 @@ size_t ZSTD_compressBegin_usingCDict_advanced(ZSTD_CCtx* const cctx, const ZSTD_ + ZSTD_nextSrcSizeToDecompress() tells how many bytes to provide as 'srcSize' to ZSTD_decompressContinue(). + ZSTD_decompressContinue() requires this _exact_ amount of bytes, or it will fail. + +- @result of ZSTD_decompressContinue() is the number of bytes regenerated within 'dst' (necessarily <= dstCapacity). 
++ result of ZSTD_decompressContinue() is the number of bytes regenerated within 'dst' (necessarily <= dstCapacity). + It can be zero : it just means ZSTD_decompressContinue() has decoded some metadata item. + It can also be an error code, which can be tested with ZSTD_isError(). + +@@ -2471,27 +2982,7 @@ size_t ZSTD_compressBegin_usingCDict_advanced(ZSTD_CCtx* const cctx, const ZSTD_ + */ + + /*===== Buffer-less streaming decompression functions =====*/ +-typedef enum { ZSTD_frame, ZSTD_skippableFrame } ZSTD_frameType_e; +-typedef struct { +- unsigned long long frameContentSize; /* if == ZSTD_CONTENTSIZE_UNKNOWN, it means this field is not available. 0 means "empty" */ +- unsigned long long windowSize; /* can be very large, up to <= frameContentSize */ +- unsigned blockSizeMax; +- ZSTD_frameType_e frameType; /* if == ZSTD_skippableFrame, frameContentSize is the size of skippable content */ +- unsigned headerSize; +- unsigned dictID; +- unsigned checksumFlag; +-} ZSTD_frameHeader; + +-/*! ZSTD_getFrameHeader() : +- * decode Frame Header, or requires larger `srcSize`. +- * @return : 0, `zfhPtr` is correctly filled, +- * >0, `srcSize` is too small, value is wanted `srcSize` amount, +- * or an error code, which can be tested using ZSTD_isError() */ +-ZSTDLIB_STATIC_API size_t ZSTD_getFrameHeader(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize); /*< doesn't consume input */ +-/*! 
ZSTD_getFrameHeader_advanced() : +- * same as ZSTD_getFrameHeader(), +- * with added capability to select a format (like ZSTD_f_zstd1_magicless) */ +-ZSTDLIB_STATIC_API size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format); + ZSTDLIB_STATIC_API size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long long frameContentSize); /*< when frame content size is not known, pass in frameContentSize == ZSTD_CONTENTSIZE_UNKNOWN */ + + ZSTDLIB_STATIC_API size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx); +@@ -2502,6 +2993,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_nextSrcSizeToDecompress(ZSTD_DCtx* dctx); + ZSTDLIB_STATIC_API size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); + + /* misc */ ++ZSTD_DEPRECATED("This function will likely be removed in the next minor release. It is misleading and has very limited utility.") + ZSTDLIB_STATIC_API void ZSTD_copyDCtx(ZSTD_DCtx* dctx, const ZSTD_DCtx* preparedDCtx); + typedef enum { ZSTDnit_frameHeader, ZSTDnit_blockHeader, ZSTDnit_block, ZSTDnit_lastBlock, ZSTDnit_checksum, ZSTDnit_skippableFrame } ZSTD_nextInputType_e; + ZSTDLIB_STATIC_API ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx); +@@ -2509,11 +3001,23 @@ ZSTDLIB_STATIC_API ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx); + + + +-/* ============================ */ +-/* Block level API */ +-/* ============================ */ ++/* ========================================= */ ++/* Block level API (DEPRECATED) */ ++/* ========================================= */ + + /*! ++ ++ This API is deprecated in favor of the regular compression API. ++ You can get the frame header down to 2 bytes by setting: ++ - ZSTD_c_format = ZSTD_f_zstd1_magicless ++ - ZSTD_c_contentSizeFlag = 0 ++ - ZSTD_c_checksumFlag = 0 ++ - ZSTD_c_dictIDFlag = 0 ++ ++ This API is not as well tested as our normal API, so we recommend not using it. 
++ We will be removing it in a future version. If the normal API doesn't provide ++ the functionality you need, please open a GitHub issue. ++ + Block functions produce and decode raw zstd blocks, without frame metadata. + Frame metadata cost is typically ~12 bytes, which can be non-negligible for very small blocks (< 100 bytes). + But users will have to take in charge needed metadata to regenerate data, such as compressed and content sizes. +@@ -2524,7 +3028,6 @@ ZSTDLIB_STATIC_API ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx); + - It is necessary to init context before starting + + compression : any ZSTD_compressBegin*() variant, including with dictionary + + decompression : any ZSTD_decompressBegin*() variant, including with dictionary +- + copyCCtx() and copyDCtx() can be used too + - Block size is limited, it must be <= ZSTD_getBlockSize() <= ZSTD_BLOCKSIZE_MAX == 128 KB + + If input is larger than a block size, it's necessary to split input data into multiple blocks + + For inputs larger than a single block, consider using regular ZSTD_compress() instead. +@@ -2541,11 +3044,14 @@ ZSTDLIB_STATIC_API ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx); + */ + + /*===== Raw zstd block functions =====*/ ++ZSTD_DEPRECATED("The block API is deprecated in favor of the normal compression API. See docs.") + ZSTDLIB_STATIC_API size_t ZSTD_getBlockSize (const ZSTD_CCtx* cctx); ++ZSTD_DEPRECATED("The block API is deprecated in favor of the normal compression API. See docs.") + ZSTDLIB_STATIC_API size_t ZSTD_compressBlock (ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); ++ZSTD_DEPRECATED("The block API is deprecated in favor of the normal compression API. See docs.") + ZSTDLIB_STATIC_API size_t ZSTD_decompressBlock(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); ++ZSTD_DEPRECATED("The block API is deprecated in favor of the normal compression API. 
See docs.") + ZSTDLIB_STATIC_API size_t ZSTD_insertBlock (ZSTD_DCtx* dctx, const void* blockStart, size_t blockSize); /*< insert uncompressed block into `dctx` history. Useful for multi-blocks decompression. */ + +- + #endif /* ZSTD_H_ZSTD_STATIC_LINKING_ONLY */ + +diff --git a/lib/zstd/Makefile b/lib/zstd/Makefile +index 20f08c644b71..464c410b2768 100644 +--- a/lib/zstd/Makefile ++++ b/lib/zstd/Makefile +@@ -1,6 +1,6 @@ + # SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + # ################################################################ +-# Copyright (c) Facebook, Inc. ++# Copyright (c) Meta Platforms, Inc. and affiliates. + # All rights reserved. + # + # This source code is licensed under both the BSD-style license (found in the +diff --git a/lib/zstd/common/allocations.h b/lib/zstd/common/allocations.h +new file mode 100644 +index 000000000000..16c3d08e8d1a +--- /dev/null ++++ b/lib/zstd/common/allocations.h +@@ -0,0 +1,56 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ ++/* ++ * Copyright (c) Meta Platforms, Inc. and affiliates. ++ * All rights reserved. ++ * ++ * This source code is licensed under both the BSD-style license (found in the ++ * LICENSE file in the root directory of this source tree) and the GPLv2 (found ++ * in the COPYING file in the root directory of this source tree). ++ * You may select, at your option, one of the above-listed licenses. 
++ */ ++ ++/* This file provides custom allocation primitives ++ */ ++ ++#define ZSTD_DEPS_NEED_MALLOC ++#include "zstd_deps.h" /* ZSTD_malloc, ZSTD_calloc, ZSTD_free, ZSTD_memset */ ++ ++#include "compiler.h" /* MEM_STATIC */ ++#define ZSTD_STATIC_LINKING_ONLY ++#include <linux/zstd.h> /* ZSTD_customMem */ ++ ++#ifndef ZSTD_ALLOCATIONS_H ++#define ZSTD_ALLOCATIONS_H ++ ++/* custom memory allocation functions */ ++ ++MEM_STATIC void* ZSTD_customMalloc(size_t size, ZSTD_customMem customMem) ++{ ++ if (customMem.customAlloc) ++ return customMem.customAlloc(customMem.opaque, size); ++ return ZSTD_malloc(size); ++} ++ ++MEM_STATIC void* ZSTD_customCalloc(size_t size, ZSTD_customMem customMem) ++{ ++ if (customMem.customAlloc) { ++ /* calloc implemented as malloc+memset: next best guess for a custom malloc. ++ * The memset is skipped when the custom allocator fails and returns NULL. */ ++ void* const ptr = customMem.customAlloc(customMem.opaque, size); ++ if (ptr != NULL) ZSTD_memset(ptr, 0, size); ++ return ptr; ++ } ++ return ZSTD_calloc(1, size); ++} ++ ++MEM_STATIC void ZSTD_customFree(void* ptr, ZSTD_customMem customMem) ++{ ++ if (ptr!=NULL) { ++ if (customMem.customFree) ++ customMem.customFree(customMem.opaque, ptr); ++ else ++ ZSTD_free(ptr); ++ } ++} ++ ++#endif /* ZSTD_ALLOCATIONS_H */ +diff --git a/lib/zstd/common/bits.h b/lib/zstd/common/bits.h +new file mode 100644 +index 000000000000..aa3487ec4b6a +--- /dev/null ++++ b/lib/zstd/common/bits.h +@@ -0,0 +1,149 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ ++/* ++ * Copyright (c) Meta Platforms, Inc. and affiliates. ++ * All rights reserved. ++ * ++ * This source code is licensed under both the BSD-style license (found in the ++ * LICENSE file in the root directory of this source tree) and the GPLv2 (found ++ * in the COPYING file in the root directory of this source tree). ++ * You may select, at your option, one of the above-listed licenses.
++ */ ++ ++#ifndef ZSTD_BITS_H ++#define ZSTD_BITS_H ++ ++#include "mem.h" ++ ++MEM_STATIC unsigned ZSTD_countTrailingZeros32_fallback(U32 val) ++{ ++ assert(val != 0); ++ { ++ static const U32 DeBruijnBytePos[32] = {0, 1, 28, 2, 29, 14, 24, 3, ++ 30, 22, 20, 15, 25, 17, 4, 8, ++ 31, 27, 13, 23, 21, 19, 16, 7, ++ 26, 12, 18, 6, 11, 5, 10, 9}; ++ return DeBruijnBytePos[((U32) ((val & -(S32) val) * 0x077CB531U)) >> 27]; ++ } ++} ++ ++MEM_STATIC unsigned ZSTD_countTrailingZeros32(U32 val) ++{ ++ assert(val != 0); ++# if (__GNUC__ >= 4) ++ return (unsigned)__builtin_ctz(val); ++# else ++ return ZSTD_countTrailingZeros32_fallback(val); ++# endif ++} ++ ++MEM_STATIC unsigned ZSTD_countLeadingZeros32_fallback(U32 val) { ++ assert(val != 0); ++ { ++ static const U32 DeBruijnClz[32] = {0, 9, 1, 10, 13, 21, 2, 29, ++ 11, 14, 16, 18, 22, 25, 3, 30, ++ 8, 12, 20, 28, 15, 17, 24, 7, ++ 19, 27, 23, 6, 26, 5, 4, 31}; ++ val |= val >> 1; ++ val |= val >> 2; ++ val |= val >> 4; ++ val |= val >> 8; ++ val |= val >> 16; ++ return 31 - DeBruijnClz[(val * 0x07C4ACDDU) >> 27]; ++ } ++} ++ ++MEM_STATIC unsigned ZSTD_countLeadingZeros32(U32 val) ++{ ++ assert(val != 0); ++# if (__GNUC__ >= 4) ++ return (unsigned)__builtin_clz(val); ++# else ++ return ZSTD_countLeadingZeros32_fallback(val); ++# endif ++} ++ ++MEM_STATIC unsigned ZSTD_countTrailingZeros64(U64 val) ++{ ++ assert(val != 0); ++# if (__GNUC__ >= 4) && defined(__LP64__) ++ return (unsigned)__builtin_ctzll(val); ++# else ++ { ++ U32 mostSignificantWord = (U32)(val >> 32); ++ U32 leastSignificantWord = (U32)val; ++ if (leastSignificantWord == 0) { ++ return 32 + ZSTD_countTrailingZeros32(mostSignificantWord); ++ } else { ++ return ZSTD_countTrailingZeros32(leastSignificantWord); ++ } ++ } ++# endif ++} ++ ++MEM_STATIC unsigned ZSTD_countLeadingZeros64(U64 val) ++{ ++ assert(val != 0); ++# if (__GNUC__ >= 4) ++ return (unsigned)(__builtin_clzll(val)); ++# else ++ { ++ U32 mostSignificantWord = (U32)(val >> 32); ++ U32 
leastSignificantWord = (U32)val; ++ if (mostSignificantWord == 0) { ++ return 32 + ZSTD_countLeadingZeros32(leastSignificantWord); ++ } else { ++ return ZSTD_countLeadingZeros32(mostSignificantWord); ++ } ++ } ++# endif ++} ++ ++MEM_STATIC unsigned ZSTD_NbCommonBytes(size_t val) ++{ ++ if (MEM_isLittleEndian()) { ++ if (MEM_64bits()) { ++ return ZSTD_countTrailingZeros64((U64)val) >> 3; ++ } else { ++ return ZSTD_countTrailingZeros32((U32)val) >> 3; ++ } ++ } else { /* Big Endian CPU */ ++ if (MEM_64bits()) { ++ return ZSTD_countLeadingZeros64((U64)val) >> 3; ++ } else { ++ return ZSTD_countLeadingZeros32((U32)val) >> 3; ++ } ++ } ++} ++ ++MEM_STATIC unsigned ZSTD_highbit32(U32 val) /* compress, dictBuilder, decodeCorpus */ ++{ ++ assert(val != 0); ++ return 31 - ZSTD_countLeadingZeros32(val); ++} ++ ++/* ZSTD_rotateRight_*(): ++ * Rotates a bitfield to the right by "count" bits. ++ * https://en.wikipedia.org/w/index.php?title=Circular_shift&oldid=991635599#Implementing_circular_shifts ++ */ ++MEM_STATIC ++U64 ZSTD_rotateRight_U64(U64 const value, U32 count) { ++ assert(count < 64); ++ count &= 0x3F; /* for fickle pattern recognition */ ++ return (value >> count) | (U64)(value << ((0U - count) & 0x3F)); ++} ++ ++MEM_STATIC ++U32 ZSTD_rotateRight_U32(U32 const value, U32 count) { ++ assert(count < 32); ++ count &= 0x1F; /* for fickle pattern recognition */ ++ return (value >> count) | (U32)(value << ((0U - count) & 0x1F)); ++} ++ ++MEM_STATIC ++U16 ZSTD_rotateRight_U16(U16 const value, U32 count) { ++ assert(count < 16); ++ count &= 0x0F; /* for fickle pattern recognition */ ++ return (value >> count) | (U16)(value << ((0U - count) & 0x0F)); ++} ++ ++#endif /* ZSTD_BITS_H */ +diff --git a/lib/zstd/common/bitstream.h b/lib/zstd/common/bitstream.h +index feef3a1b1d60..6a13f1f0f1e8 100644 +--- a/lib/zstd/common/bitstream.h ++++ b/lib/zstd/common/bitstream.h +@@ -1,7 +1,8 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* 
****************************************************************** + * bitstream + * Part of FSE library +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * You can contact the author at : + * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy +@@ -27,6 +28,7 @@ + #include "compiler.h" /* UNLIKELY() */ + #include "debug.h" /* assert(), DEBUGLOG(), RAWLOG() */ + #include "error_private.h" /* error codes and messages */ ++#include "bits.h" /* ZSTD_highbit32 */ + + + /*========================================= +@@ -79,19 +81,20 @@ MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC); + /*-******************************************** + * bitStream decoding API (read backward) + **********************************************/ ++typedef size_t BitContainerType; + typedef struct { +- size_t bitContainer; ++ BitContainerType bitContainer; + unsigned bitsConsumed; + const char* ptr; + const char* start; + const char* limitPtr; + } BIT_DStream_t; + +-typedef enum { BIT_DStream_unfinished = 0, +- BIT_DStream_endOfBuffer = 1, +- BIT_DStream_completed = 2, +- BIT_DStream_overflow = 3 } BIT_DStream_status; /* result of BIT_reloadDStream() */ +- /* 1,2,4,8 would be better for bitmap combinations, but slows down performance a bit ... :( */ ++typedef enum { BIT_DStream_unfinished = 0, /* fully refilled */ ++ BIT_DStream_endOfBuffer = 1, /* still some bits left in bitstream */ ++ BIT_DStream_completed = 2, /* bitstream entirely consumed, bit-exact */ ++ BIT_DStream_overflow = 3 /* user requested more bits than present in bitstream */ ++ } BIT_DStream_status; /* result of BIT_reloadDStream() */ + + MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, size_t srcSize); + MEM_STATIC size_t BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits); +@@ -101,7 +104,7 @@ MEM_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t* bitD); + + /* Start by invoking BIT_initDStream(). 
+ * A chunk of the bitStream is then stored into a local register. +-* Local register size is 64-bits on 64-bits systems, 32-bits on 32-bits systems (size_t). ++* Local register size is 64-bits on 64-bits systems, 32-bits on 32-bits systems (BitContainerType). + * You can then retrieve bitFields stored into the local register, **in reverse order**. + * Local register is explicitly reloaded from memory by the BIT_reloadDStream() method. + * A reload guarantee a minimum of ((8*sizeof(bitD->bitContainer))-7) bits when its result is BIT_DStream_unfinished. +@@ -122,33 +125,6 @@ MEM_STATIC void BIT_flushBitsFast(BIT_CStream_t* bitC); + MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits); + /* faster, but works only if nbBits >= 1 */ + +- +- +-/*-************************************************************** +-* Internal functions +-****************************************************************/ +-MEM_STATIC unsigned BIT_highbit32 (U32 val) +-{ +- assert(val != 0); +- { +-# if (__GNUC__ >= 3) /* Use GCC Intrinsic */ +- return __builtin_clz (val) ^ 31; +-# else /* Software version */ +- static const unsigned DeBruijnClz[32] = { 0, 9, 1, 10, 13, 21, 2, 29, +- 11, 14, 16, 18, 22, 25, 3, 30, +- 8, 12, 20, 28, 15, 17, 24, 7, +- 19, 27, 23, 6, 26, 5, 4, 31 }; +- U32 v = val; +- v |= v >> 1; +- v |= v >> 2; +- v |= v >> 4; +- v |= v >> 8; +- v |= v >> 16; +- return DeBruijnClz[ (U32) (v * 0x07C4ACDDU) >> 27]; +-# endif +- } +-} +- + /*===== Local Constants =====*/ + static const unsigned BIT_mask[] = { + 0, 1, 3, 7, 0xF, 0x1F, +@@ -178,6 +154,12 @@ MEM_STATIC size_t BIT_initCStream(BIT_CStream_t* bitC, + return 0; + } + ++FORCE_INLINE_TEMPLATE size_t BIT_getLowerBits(size_t bitContainer, U32 const nbBits) ++{ ++ assert(nbBits < BIT_MASK_SIZE); ++ return bitContainer & BIT_mask[nbBits]; ++} ++ + /*! BIT_addBits() : + * can add up to 31 bits into `bitC`. + * Note : does not check for register overflow ! 
*/ +@@ -187,7 +169,7 @@ MEM_STATIC void BIT_addBits(BIT_CStream_t* bitC, + DEBUG_STATIC_ASSERT(BIT_MASK_SIZE == 32); + assert(nbBits < BIT_MASK_SIZE); + assert(nbBits + bitC->bitPos < sizeof(bitC->bitContainer) * 8); +- bitC->bitContainer |= (value & BIT_mask[nbBits]) << bitC->bitPos; ++ bitC->bitContainer |= BIT_getLowerBits(value, nbBits) << bitC->bitPos; + bitC->bitPos += nbBits; + } + +@@ -266,35 +248,35 @@ MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, si + bitD->ptr = (const char*)srcBuffer + srcSize - sizeof(bitD->bitContainer); + bitD->bitContainer = MEM_readLEST(bitD->ptr); + { BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1]; +- bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; /* ensures bitsConsumed is always set */ ++ bitD->bitsConsumed = lastByte ? 8 - ZSTD_highbit32(lastByte) : 0; /* ensures bitsConsumed is always set */ + if (lastByte == 0) return ERROR(GENERIC); /* endMark not present */ } + } else { + bitD->ptr = bitD->start; + bitD->bitContainer = *(const BYTE*)(bitD->start); + switch(srcSize) + { +- case 7: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[6]) << (sizeof(bitD->bitContainer)*8 - 16); ++ case 7: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[6]) << (sizeof(bitD->bitContainer)*8 - 16); + ZSTD_FALLTHROUGH; + +- case 6: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[5]) << (sizeof(bitD->bitContainer)*8 - 24); ++ case 6: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[5]) << (sizeof(bitD->bitContainer)*8 - 24); + ZSTD_FALLTHROUGH; + +- case 5: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[4]) << (sizeof(bitD->bitContainer)*8 - 32); ++ case 5: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[4]) << (sizeof(bitD->bitContainer)*8 - 32); + ZSTD_FALLTHROUGH; + +- case 4: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[3]) << 24; ++ case 4: bitD->bitContainer += 
(BitContainerType)(((const BYTE*)(srcBuffer))[3]) << 24; + ZSTD_FALLTHROUGH; + +- case 3: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[2]) << 16; ++ case 3: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[2]) << 16; + ZSTD_FALLTHROUGH; + +- case 2: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[1]) << 8; ++ case 2: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[1]) << 8; + ZSTD_FALLTHROUGH; + + default: break; + } + { BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1]; +- bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; ++ bitD->bitsConsumed = lastByte ? 8 - ZSTD_highbit32(lastByte) : 0; + if (lastByte == 0) return ERROR(corruption_detected); /* endMark not present */ + } + bitD->bitsConsumed += (U32)(sizeof(bitD->bitContainer) - srcSize)*8; +@@ -303,12 +285,12 @@ MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, si + return srcSize; + } + +-MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getUpperBits(size_t bitContainer, U32 const start) ++FORCE_INLINE_TEMPLATE size_t BIT_getUpperBits(BitContainerType bitContainer, U32 const start) + { + return bitContainer >> start; + } + +-MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getMiddleBits(size_t bitContainer, U32 const start, U32 const nbBits) ++FORCE_INLINE_TEMPLATE size_t BIT_getMiddleBits(BitContainerType bitContainer, U32 const start, U32 const nbBits) + { + U32 const regMask = sizeof(bitContainer)*8 - 1; + /* if start > regMask, bitstream is corrupted, and result is undefined */ +@@ -325,19 +307,13 @@ MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getMiddleBits(size_t bitContainer, U32 c + #endif + } + +-MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getLowerBits(size_t bitContainer, U32 const nbBits) +-{ +- assert(nbBits < BIT_MASK_SIZE); +- return bitContainer & BIT_mask[nbBits]; +-} +- + /*! BIT_lookBits() : + * Provides next n bits from local register. + * local register is not modified. + * On 32-bits, maxNbBits==24. 
+ * On 64-bits, maxNbBits==56. + * @return : value extracted */ +-MEM_STATIC FORCE_INLINE_ATTR size_t BIT_lookBits(const BIT_DStream_t* bitD, U32 nbBits) ++FORCE_INLINE_TEMPLATE size_t BIT_lookBits(const BIT_DStream_t* bitD, U32 nbBits) + { + /* arbitrate between double-shift and shift+mask */ + #if 1 +@@ -360,7 +336,7 @@ MEM_STATIC size_t BIT_lookBitsFast(const BIT_DStream_t* bitD, U32 nbBits) + return (bitD->bitContainer << (bitD->bitsConsumed & regMask)) >> (((regMask+1)-nbBits) & regMask); + } + +-MEM_STATIC FORCE_INLINE_ATTR void BIT_skipBits(BIT_DStream_t* bitD, U32 nbBits) ++FORCE_INLINE_TEMPLATE void BIT_skipBits(BIT_DStream_t* bitD, U32 nbBits) + { + bitD->bitsConsumed += nbBits; + } +@@ -369,7 +345,7 @@ MEM_STATIC FORCE_INLINE_ATTR void BIT_skipBits(BIT_DStream_t* bitD, U32 nbBits) + * Read (consume) next n bits from local register and update. + * Pay attention to not read more than nbBits contained into local register. + * @return : extracted value. */ +-MEM_STATIC FORCE_INLINE_ATTR size_t BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits) ++FORCE_INLINE_TEMPLATE size_t BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits) + { + size_t const value = BIT_lookBits(bitD, nbBits); + BIT_skipBits(bitD, nbBits); +@@ -377,7 +353,7 @@ MEM_STATIC FORCE_INLINE_ATTR size_t BIT_readBits(BIT_DStream_t* bitD, unsigned n + } + + /*! BIT_readBitsFast() : +- * unsafe version; only works only if nbBits >= 1 */ ++ * unsafe version; only works if nbBits >= 1 */ + MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits) + { + size_t const value = BIT_lookBitsFast(bitD, nbBits); +@@ -386,6 +362,21 @@ MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits) + return value; + } + ++/*! BIT_reloadDStream_internal() : ++ * Simple variant of BIT_reloadDStream(), with two conditions: ++ * 1. bitstream is valid : bitsConsumed <= sizeof(bitD->bitContainer)*8 ++ * 2. 
look window is valid after shifted down : bitD->ptr >= bitD->start ++ */ ++MEM_STATIC BIT_DStream_status BIT_reloadDStream_internal(BIT_DStream_t* bitD) ++{ ++ assert(bitD->bitsConsumed <= sizeof(bitD->bitContainer)*8); ++ bitD->ptr -= bitD->bitsConsumed >> 3; ++ assert(bitD->ptr >= bitD->start); ++ bitD->bitsConsumed &= 7; ++ bitD->bitContainer = MEM_readLEST(bitD->ptr); ++ return BIT_DStream_unfinished; ++} ++ + /*! BIT_reloadDStreamFast() : + * Similar to BIT_reloadDStream(), but with two differences: + * 1. bitsConsumed <= sizeof(bitD->bitContainer)*8 must hold! +@@ -396,31 +387,35 @@ MEM_STATIC BIT_DStream_status BIT_reloadDStreamFast(BIT_DStream_t* bitD) + { + if (UNLIKELY(bitD->ptr < bitD->limitPtr)) + return BIT_DStream_overflow; +- assert(bitD->bitsConsumed <= sizeof(bitD->bitContainer)*8); +- bitD->ptr -= bitD->bitsConsumed >> 3; +- bitD->bitsConsumed &= 7; +- bitD->bitContainer = MEM_readLEST(bitD->ptr); +- return BIT_DStream_unfinished; ++ return BIT_reloadDStream_internal(bitD); + } + + /*! BIT_reloadDStream() : + * Refill `bitD` from buffer previously set in BIT_initDStream() . +- * This function is safe, it guarantees it will not read beyond src buffer. ++ * This function is safe, it guarantees it will never read beyond src buffer. + * @return : status of `BIT_DStream_t` internal register. 
+ * when status == BIT_DStream_unfinished, internal register is filled with at least 25 or 57 bits */ +-MEM_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD) ++FORCE_INLINE_TEMPLATE BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD) + { +- if (bitD->bitsConsumed > (sizeof(bitD->bitContainer)*8)) /* overflow detected, like end of stream */ ++ /* note : once in overflow mode, a bitstream remains in this mode until it's reset */ ++ if (UNLIKELY(bitD->bitsConsumed > (sizeof(bitD->bitContainer)*8))) { ++ static const BitContainerType zeroFilled = 0; ++ bitD->ptr = (const char*)&zeroFilled; /* aliasing is allowed for char */ ++ /* overflow detected, erroneous scenario or end of stream: no update */ + return BIT_DStream_overflow; ++ } ++ ++ assert(bitD->ptr >= bitD->start); + + if (bitD->ptr >= bitD->limitPtr) { +- return BIT_reloadDStreamFast(bitD); ++ return BIT_reloadDStream_internal(bitD); + } + if (bitD->ptr == bitD->start) { ++ /* reached end of bitStream => no update */ + if (bitD->bitsConsumed < sizeof(bitD->bitContainer)*8) return BIT_DStream_endOfBuffer; + return BIT_DStream_completed; + } +- /* start < ptr < limitPtr */ ++ /* start < ptr < limitPtr => cautious update */ + { U32 nbBytes = bitD->bitsConsumed >> 3; + BIT_DStream_status result = BIT_DStream_unfinished; + if (bitD->ptr - nbBytes < bitD->start) { +diff --git a/lib/zstd/common/compiler.h b/lib/zstd/common/compiler.h +index c42d39faf9bd..508ee25537bb 100644 +--- a/lib/zstd/common/compiler.h ++++ b/lib/zstd/common/compiler.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -11,6 +12,8 @@ + #ifndef ZSTD_COMPILER_H + #define ZSTD_COMPILER_H + ++#include <linux/types.h> ++ + #include "portability_macros.h" + + /*-******************************************************* +@@ -41,12 +44,15 @@ + */ + #define WIN_CDECL + ++/* UNUSED_ATTR tells the compiler it is okay if the function is unused. */ ++#define UNUSED_ATTR __attribute__((unused)) ++ + /* + * FORCE_INLINE_TEMPLATE is used to define C "templates", which take constant + * parameters. They must be inlined for the compiler to eliminate the constant + * branches. + */ +-#define FORCE_INLINE_TEMPLATE static INLINE_KEYWORD FORCE_INLINE_ATTR ++#define FORCE_INLINE_TEMPLATE static INLINE_KEYWORD FORCE_INLINE_ATTR UNUSED_ATTR + /* + * HINT_INLINE is used to help the compiler generate better code. It is *not* + * used for "templates", so it can be tweaked based on the compilers +@@ -61,11 +67,21 @@ + #if !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 4 && __GNUC_MINOR__ >= 8 && __GNUC__ < 5 + # define HINT_INLINE static INLINE_KEYWORD + #else +-# define HINT_INLINE static INLINE_KEYWORD FORCE_INLINE_ATTR ++# define HINT_INLINE FORCE_INLINE_TEMPLATE + #endif + +-/* UNUSED_ATTR tells the compiler it is okay if the function is unused. */ +-#define UNUSED_ATTR __attribute__((unused)) ++/* "soft" inline : ++ * The compiler is free to select if it's a good idea to inline or not. ++ * The main objective is to silence compiler warnings ++ * when a defined function is included but not used. ++ * ++ * Note : this macro is prefixed `MEM_` because it used to be provided by `mem.h` unit. ++ * Updating the prefix is probably preferable, but requires a fairly large codemod, ++ * since this name is used everywhere. 
++ */ ++#ifndef MEM_STATIC /* already defined in Linux Kernel mem.h */ ++#define MEM_STATIC static __inline UNUSED_ATTR ++#endif + + /* force no inlining */ + #define FORCE_NOINLINE static __attribute__((__noinline__)) +@@ -86,23 +102,24 @@ + # define PREFETCH_L1(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */) + # define PREFETCH_L2(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 2 /* locality */) + #elif defined(__aarch64__) +-# define PREFETCH_L1(ptr) __asm__ __volatile__("prfm pldl1keep, %0" ::"Q"(*(ptr))) +-# define PREFETCH_L2(ptr) __asm__ __volatile__("prfm pldl2keep, %0" ::"Q"(*(ptr))) ++# define PREFETCH_L1(ptr) do { __asm__ __volatile__("prfm pldl1keep, %0" ::"Q"(*(ptr))); } while (0) ++# define PREFETCH_L2(ptr) do { __asm__ __volatile__("prfm pldl2keep, %0" ::"Q"(*(ptr))); } while (0) + #else +-# define PREFETCH_L1(ptr) (void)(ptr) /* disabled */ +-# define PREFETCH_L2(ptr) (void)(ptr) /* disabled */ ++# define PREFETCH_L1(ptr) do { (void)(ptr); } while (0) /* disabled */ ++# define PREFETCH_L2(ptr) do { (void)(ptr); } while (0) /* disabled */ + #endif /* NO_PREFETCH */ + + #define CACHELINE_SIZE 64 + +-#define PREFETCH_AREA(p, s) { \ +- const char* const _ptr = (const char*)(p); \ +- size_t const _size = (size_t)(s); \ +- size_t _pos; \ +- for (_pos=0; _pos<_size; _pos+=CACHELINE_SIZE) { \ +- PREFETCH_L2(_ptr + _pos); \ +- } \ +-} ++#define PREFETCH_AREA(p, s) \ ++ do { \ ++ const char* const _ptr = (const char*)(p); \ ++ size_t const _size = (size_t)(s); \ ++ size_t _pos; \ ++ for (_pos=0; _pos<_size; _pos+=CACHELINE_SIZE) { \ ++ PREFETCH_L2(_ptr + _pos); \ ++ } \ ++ } while (0) + + /* vectorization + * older GCC (pre gcc-4.3 picked as the cutoff) uses a different syntax, +@@ -126,9 +143,9 @@ + #define UNLIKELY(x) (__builtin_expect((x), 0)) + + #if __has_builtin(__builtin_unreachable) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 5))) +-# define ZSTD_UNREACHABLE { assert(0), __builtin_unreachable(); } 
++# define ZSTD_UNREACHABLE do { assert(0), __builtin_unreachable(); } while (0) + #else +-# define ZSTD_UNREACHABLE { assert(0); } ++# define ZSTD_UNREACHABLE do { assert(0); } while (0) + #endif + + /* disable warnings */ +@@ -179,6 +196,85 @@ + * Sanitizer + *****************************************************************/ + ++/* ++ * Zstd relies on pointer overflow in its decompressor. ++ * We add this attribute to functions that rely on pointer overflow. ++ */ ++#ifndef ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++# if __has_attribute(no_sanitize) ++# if !defined(__clang__) && defined(__GNUC__) && __GNUC__ < 8 ++ /* gcc < 8 only has signed-integer-overflow which triggers on pointer overflow */ ++# define ZSTD_ALLOW_POINTER_OVERFLOW_ATTR __attribute__((no_sanitize("signed-integer-overflow"))) ++# else ++ /* older versions of clang [3.7, 5.0) will warn that pointer-overflow is ignored. */ ++# define ZSTD_ALLOW_POINTER_OVERFLOW_ATTR __attribute__((no_sanitize("pointer-overflow"))) ++# endif ++# else ++# define ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++# endif ++#endif ++ ++/* ++ * Helper function to perform a wrapped pointer difference without triggering ++ * UBSAN. ++ * ++ * @returns lhs - rhs with wrapping ++ */ ++MEM_STATIC ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++ptrdiff_t ZSTD_wrappedPtrDiff(unsigned char const* lhs, unsigned char const* rhs) ++{ ++ return lhs - rhs; ++} ++ ++/* ++ * Helper function to perform a wrapped pointer add without triggering UBSAN. ++ * ++ * @return ptr + add with wrapping ++ */ ++MEM_STATIC ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++unsigned char const* ZSTD_wrappedPtrAdd(unsigned char const* ptr, ptrdiff_t add) ++{ ++ return ptr + add; ++} ++ ++/* ++ * Helper function to perform a wrapped pointer subtraction without triggering ++ * UBSAN. 
++ * ++ * @return ptr - sub with wrapping ++ */ ++MEM_STATIC ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++unsigned char const* ZSTD_wrappedPtrSub(unsigned char const* ptr, ptrdiff_t sub) ++{ ++ return ptr - sub; ++} ++ ++/* ++ * Helper function to add to a pointer that works around C's undefined behavior ++ * of adding 0 to NULL. ++ * ++ * @returns `ptr + add` except it defines `NULL + 0 == NULL`. ++ */ ++MEM_STATIC ++unsigned char* ZSTD_maybeNullPtrAdd(unsigned char* ptr, ptrdiff_t add) ++{ ++ return add > 0 ? ptr + add : ptr; ++} ++ ++/* Issue #3240 reports an ASAN failure on an llvm-mingw build. Out of an ++ * abundance of caution, disable our custom poisoning on mingw. */ ++#ifdef __MINGW32__ ++#ifndef ZSTD_ASAN_DONT_POISON_WORKSPACE ++#define ZSTD_ASAN_DONT_POISON_WORKSPACE 1 ++#endif ++#ifndef ZSTD_MSAN_DONT_POISON_WORKSPACE ++#define ZSTD_MSAN_DONT_POISON_WORKSPACE 1 ++#endif ++#endif ++ + + + #endif /* ZSTD_COMPILER_H */ +diff --git a/lib/zstd/common/cpu.h b/lib/zstd/common/cpu.h +index 0db7b42407ee..d8319a2bef4c 100644 +--- a/lib/zstd/common/cpu.h ++++ b/lib/zstd/common/cpu.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +diff --git a/lib/zstd/common/debug.c b/lib/zstd/common/debug.c +index bb863c9ea616..8eb6aa9a3b20 100644 +--- a/lib/zstd/common/debug.c ++++ b/lib/zstd/common/debug.c +@@ -1,7 +1,8 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* ****************************************************************** + * debug + * Part of FSE library +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * + * You can contact the author at : + * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy +@@ -21,4 +22,10 @@ + + #include "debug.h" + ++#if (DEBUGLEVEL>=2) ++/* We only use this when DEBUGLEVEL>=2, but we get -Werror=pedantic errors if a ++ * translation unit is empty. So remove this from Linux kernel builds, but ++ * otherwise just leave it in. ++ */ + int g_debuglevel = DEBUGLEVEL; ++#endif +diff --git a/lib/zstd/common/debug.h b/lib/zstd/common/debug.h +index 6dd88d1fbd02..226ba3c57ec3 100644 +--- a/lib/zstd/common/debug.h ++++ b/lib/zstd/common/debug.h +@@ -1,7 +1,8 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* ****************************************************************** + * debug + * Part of FSE library +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * You can contact the author at : + * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy +@@ -82,18 +83,27 @@ extern int g_debuglevel; /* the variable is only declared, + It's useful when enabling very verbose levels + on selective conditions (such as position in src) */ + +-# define RAWLOG(l, ...) { \ +- if (l<=g_debuglevel) { \ +- ZSTD_DEBUG_PRINT(__VA_ARGS__); \ +- } } +-# define DEBUGLOG(l, ...) { \ +- if (l<=g_debuglevel) { \ +- ZSTD_DEBUG_PRINT(__FILE__ ": " __VA_ARGS__); \ +- ZSTD_DEBUG_PRINT(" \n"); \ +- } } ++# define RAWLOG(l, ...) \ ++ do { \ ++ if (l<=g_debuglevel) { \ ++ ZSTD_DEBUG_PRINT(__VA_ARGS__); \ ++ } \ ++ } while (0) ++ ++#define STRINGIFY(x) #x ++#define TOSTRING(x) STRINGIFY(x) ++#define LINE_AS_STRING TOSTRING(__LINE__) ++ ++# define DEBUGLOG(l, ...) \ ++ do { \ ++ if (l<=g_debuglevel) { \ ++ ZSTD_DEBUG_PRINT(__FILE__ ":" LINE_AS_STRING ": " __VA_ARGS__); \ ++ ZSTD_DEBUG_PRINT(" \n"); \ ++ } \ ++ } while (0) + #else +-# define RAWLOG(l, ...) {} /* disabled */ +-# define DEBUGLOG(l, ...) {} /* disabled */ ++# define RAWLOG(l, ...) 
do { } while (0) /* disabled */ ++# define DEBUGLOG(l, ...) do { } while (0) /* disabled */ + #endif + + +diff --git a/lib/zstd/common/entropy_common.c b/lib/zstd/common/entropy_common.c +index fef67056f052..6cdd82233fb5 100644 +--- a/lib/zstd/common/entropy_common.c ++++ b/lib/zstd/common/entropy_common.c +@@ -1,6 +1,7 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* ****************************************************************** + * Common functions of New Generation Entropy library +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * You can contact the author at : + * - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy +@@ -19,8 +20,8 @@ + #include "error_private.h" /* ERR_*, ERROR */ + #define FSE_STATIC_LINKING_ONLY /* FSE_MIN_TABLELOG */ + #include "fse.h" +-#define HUF_STATIC_LINKING_ONLY /* HUF_TABLELOG_ABSOLUTEMAX */ + #include "huf.h" ++#include "bits.h" /* ZSTD_highbit32, ZSTD_countTrailingZeros32 */ + + + /*=== Version ===*/ +@@ -38,23 +39,6 @@ const char* HUF_getErrorName(size_t code) { return ERR_getErrorName(code); } + /*-************************************************************** + * FSE NCount encoding-decoding + ****************************************************************/ +-static U32 FSE_ctz(U32 val) +-{ +- assert(val != 0); +- { +-# if (__GNUC__ >= 3) /* GCC Intrinsic */ +- return __builtin_ctz(val); +-# else /* Software version */ +- U32 count = 0; +- while ((val & 1) == 0) { +- val >>= 1; +- ++count; +- } +- return count; +-# endif +- } +-} +- + FORCE_INLINE_TEMPLATE + size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, + const void* headerBuffer, size_t hbSize) +@@ -102,7 +86,7 @@ size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigne + * repeat. + * Avoid UB by setting the high bit to 1. 
+ */ +- int repeats = FSE_ctz(~bitStream | 0x80000000) >> 1; ++ int repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1; + while (repeats >= 12) { + charnum += 3 * 12; + if (LIKELY(ip <= iend-7)) { +@@ -113,7 +97,7 @@ size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigne + ip = iend - 4; + } + bitStream = MEM_readLE32(ip) >> bitCount; +- repeats = FSE_ctz(~bitStream | 0x80000000) >> 1; ++ repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1; + } + charnum += 3 * repeats; + bitStream >>= 2 * repeats; +@@ -178,7 +162,7 @@ size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigne + * know that threshold > 1. + */ + if (remaining <= 1) break; +- nbBits = BIT_highbit32(remaining) + 1; ++ nbBits = ZSTD_highbit32(remaining) + 1; + threshold = 1 << (nbBits - 1); + } + if (charnum >= maxSV1) break; +@@ -253,7 +237,7 @@ size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats, + const void* src, size_t srcSize) + { + U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32]; +- return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* bmi2 */ 0); ++ return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* flags */ 0); + } + + FORCE_INLINE_TEMPLATE size_t +@@ -301,14 +285,14 @@ HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats, + if (weightTotal == 0) return ERROR(corruption_detected); + + /* get last non-null symbol weight (implied, total must be 2^n) */ +- { U32 const tableLog = BIT_highbit32(weightTotal) + 1; ++ { U32 const tableLog = ZSTD_highbit32(weightTotal) + 1; + if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected); + *tableLogPtr = tableLog; + /* determine last weight */ + { U32 const total = 1 << tableLog; + U32 const rest = total - weightTotal; +- U32 const verif = 1 << BIT_highbit32(rest); +- U32 const lastWeight = BIT_highbit32(rest) 
+ 1; ++ U32 const verif = 1 << ZSTD_highbit32(rest); ++ U32 const lastWeight = ZSTD_highbit32(rest) + 1; + if (verif != rest) return ERROR(corruption_detected); /* last value must be a clean power of 2 */ + huffWeight[oSize] = (BYTE)lastWeight; + rankStats[lastWeight]++; +@@ -345,13 +329,13 @@ size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats, + U32* nbSymbolsPtr, U32* tableLogPtr, + const void* src, size_t srcSize, + void* workSpace, size_t wkspSize, +- int bmi2) ++ int flags) + { + #if DYNAMIC_BMI2 +- if (bmi2) { ++ if (flags & HUF_flags_bmi2) { + return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); + } + #endif +- (void)bmi2; ++ (void)flags; + return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); + } +diff --git a/lib/zstd/common/error_private.c b/lib/zstd/common/error_private.c +index 6d1135f8c373..a4062d30d170 100644 +--- a/lib/zstd/common/error_private.c ++++ b/lib/zstd/common/error_private.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -27,9 +28,11 @@ const char* ERR_getErrorString(ERR_enum code) + case PREFIX(version_unsupported): return "Version not supported"; + case PREFIX(frameParameter_unsupported): return "Unsupported frame parameter"; + case PREFIX(frameParameter_windowTooLarge): return "Frame requires too much memory for decoding"; +- case PREFIX(corruption_detected): return "Corrupted block detected"; ++ case PREFIX(corruption_detected): return "Data corruption detected"; + case PREFIX(checksum_wrong): return "Restored data doesn't match checksum"; ++ case PREFIX(literals_headerWrong): return "Header of Literals' block doesn't respect format specification"; + case PREFIX(parameter_unsupported): return "Unsupported parameter"; ++ case PREFIX(parameter_combination_unsupported): return "Unsupported combination of parameters"; + case PREFIX(parameter_outOfBound): return "Parameter is out of bound"; + case PREFIX(init_missing): return "Context should be init first"; + case PREFIX(memory_allocation): return "Allocation error : not enough memory"; +@@ -38,17 +41,22 @@ const char* ERR_getErrorString(ERR_enum code) + case PREFIX(tableLog_tooLarge): return "tableLog requires too much memory : unsupported"; + case PREFIX(maxSymbolValue_tooLarge): return "Unsupported max Symbol Value : too large"; + case PREFIX(maxSymbolValue_tooSmall): return "Specified maxSymbolValue is too small"; ++ case PREFIX(stabilityCondition_notRespected): return "pledged buffer stability condition is not respected"; + case PREFIX(dictionary_corrupted): return "Dictionary is corrupted"; + case PREFIX(dictionary_wrong): return "Dictionary mismatch"; + case PREFIX(dictionaryCreation_failed): return "Cannot create Dictionary from provided samples"; + case PREFIX(dstSize_tooSmall): return "Destination buffer is too small"; + case PREFIX(srcSize_wrong): return "Src size is incorrect"; + case PREFIX(dstBuffer_null): return "Operation on NULL 
destination buffer"; ++ case PREFIX(noForwardProgress_destFull): return "Operation made no progress over multiple calls, due to output buffer being full"; ++ case PREFIX(noForwardProgress_inputEmpty): return "Operation made no progress over multiple calls, due to input being empty"; + /* following error codes are not stable and may be removed or changed in a future version */ + case PREFIX(frameIndex_tooLarge): return "Frame index is too large"; + case PREFIX(seekableIO): return "An I/O error occurred when reading/seeking"; + case PREFIX(dstBuffer_wrong): return "Destination buffer is wrong"; + case PREFIX(srcBuffer_wrong): return "Source buffer is wrong"; ++ case PREFIX(sequenceProducer_failed): return "Block-level external sequence producer returned an error code"; ++ case PREFIX(externalSequences_invalid): return "External sequences are not valid"; + case PREFIX(maxCode): + default: return notErrorCode; + } +diff --git a/lib/zstd/common/error_private.h b/lib/zstd/common/error_private.h +index ca5101e542fa..0410ca415b54 100644 +--- a/lib/zstd/common/error_private.h ++++ b/lib/zstd/common/error_private.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -49,8 +50,13 @@ ERR_STATIC unsigned ERR_isError(size_t code) { return (code > ERROR(maxCode)); } + ERR_STATIC ERR_enum ERR_getErrorCode(size_t code) { if (!ERR_isError(code)) return (ERR_enum)0; return (ERR_enum) (0-code); } + + /* check and forward error code */ +-#define CHECK_V_F(e, f) size_t const e = f; if (ERR_isError(e)) return e +-#define CHECK_F(f) { CHECK_V_F(_var_err__, f); } ++#define CHECK_V_F(e, f) \ ++ size_t const e = f; \ ++ do { \ ++ if (ERR_isError(e)) \ ++ return e; \ ++ } while (0) ++#define CHECK_F(f) do { CHECK_V_F(_var_err__, f); } while (0) + + + /*-**************************************** +@@ -84,10 +90,12 @@ void _force_has_format_string(const char *format, ...) { + * We want to force this function invocation to be syntactically correct, but + * we don't want to force runtime evaluation of its arguments. + */ +-#define _FORCE_HAS_FORMAT_STRING(...) \ +- if (0) { \ +- _force_has_format_string(__VA_ARGS__); \ +- } ++#define _FORCE_HAS_FORMAT_STRING(...) \ ++ do { \ ++ if (0) { \ ++ _force_has_format_string(__VA_ARGS__); \ ++ } \ ++ } while (0) + + #define ERR_QUOTE(str) #str + +@@ -98,48 +106,50 @@ void _force_has_format_string(const char *format, ...) { + * In order to do that (particularly, printing the conditional that failed), + * this can't just wrap RETURN_ERROR(). + */ +-#define RETURN_ERROR_IF(cond, err, ...) \ +- if (cond) { \ +- RAWLOG(3, "%s:%d: ERROR!: check %s failed, returning %s", \ +- __FILE__, __LINE__, ERR_QUOTE(cond), ERR_QUOTE(ERROR(err))); \ +- _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ +- RAWLOG(3, ": " __VA_ARGS__); \ +- RAWLOG(3, "\n"); \ +- return ERROR(err); \ +- } ++#define RETURN_ERROR_IF(cond, err, ...) 
\ ++ do { \ ++ if (cond) { \ ++ RAWLOG(3, "%s:%d: ERROR!: check %s failed, returning %s", \ ++ __FILE__, __LINE__, ERR_QUOTE(cond), ERR_QUOTE(ERROR(err))); \ ++ _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ ++ RAWLOG(3, ": " __VA_ARGS__); \ ++ RAWLOG(3, "\n"); \ ++ return ERROR(err); \ ++ } \ ++ } while (0) + + /* + * Unconditionally return the specified error. + * + * In debug modes, prints additional information. + */ +-#define RETURN_ERROR(err, ...) \ +- do { \ +- RAWLOG(3, "%s:%d: ERROR!: unconditional check failed, returning %s", \ +- __FILE__, __LINE__, ERR_QUOTE(ERROR(err))); \ +- _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ +- RAWLOG(3, ": " __VA_ARGS__); \ +- RAWLOG(3, "\n"); \ +- return ERROR(err); \ +- } while(0); ++#define RETURN_ERROR(err, ...) \ ++ do { \ ++ RAWLOG(3, "%s:%d: ERROR!: unconditional check failed, returning %s", \ ++ __FILE__, __LINE__, ERR_QUOTE(ERROR(err))); \ ++ _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ ++ RAWLOG(3, ": " __VA_ARGS__); \ ++ RAWLOG(3, "\n"); \ ++ return ERROR(err); \ ++ } while(0) + + /* + * If the provided expression evaluates to an error code, returns that error code. + * + * In debug modes, prints additional information. + */ +-#define FORWARD_IF_ERROR(err, ...) \ +- do { \ +- size_t const err_code = (err); \ +- if (ERR_isError(err_code)) { \ +- RAWLOG(3, "%s:%d: ERROR!: forwarding error in %s: %s", \ +- __FILE__, __LINE__, ERR_QUOTE(err), ERR_getErrorName(err_code)); \ +- _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ +- RAWLOG(3, ": " __VA_ARGS__); \ +- RAWLOG(3, "\n"); \ +- return err_code; \ +- } \ +- } while(0); ++#define FORWARD_IF_ERROR(err, ...) 
\ ++ do { \ ++ size_t const err_code = (err); \ ++ if (ERR_isError(err_code)) { \ ++ RAWLOG(3, "%s:%d: ERROR!: forwarding error in %s: %s", \ ++ __FILE__, __LINE__, ERR_QUOTE(err), ERR_getErrorName(err_code)); \ ++ _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ ++ RAWLOG(3, ": " __VA_ARGS__); \ ++ RAWLOG(3, "\n"); \ ++ return err_code; \ ++ } \ ++ } while(0) + + + #endif /* ERROR_H_MODULE */ +diff --git a/lib/zstd/common/fse.h b/lib/zstd/common/fse.h +index 4507043b2287..2185a578617d 100644 +--- a/lib/zstd/common/fse.h ++++ b/lib/zstd/common/fse.h +@@ -1,7 +1,8 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* ****************************************************************** + * FSE : Finite State Entropy codec + * Public Prototypes declaration +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * You can contact the author at : + * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy +@@ -50,34 +51,6 @@ + FSE_PUBLIC_API unsigned FSE_versionNumber(void); /*< library version number; to be used when checking dll version */ + + +-/*-**************************************** +-* FSE simple functions +-******************************************/ +-/*! FSE_compress() : +- Compress content of buffer 'src', of size 'srcSize', into destination buffer 'dst'. +- 'dst' buffer must be already allocated. Compression runs faster is dstCapacity >= FSE_compressBound(srcSize). +- @return : size of compressed data (<= dstCapacity). +- Special values : if return == 0, srcData is not compressible => Nothing is stored within dst !!! +- if return == 1, srcData is a single byte symbol * srcSize times. Use RLE compression instead. +- if FSE_isError(return), compression failed (more details using FSE_getErrorName()) +-*/ +-FSE_PUBLIC_API size_t FSE_compress(void* dst, size_t dstCapacity, +- const void* src, size_t srcSize); +- +-/*! 
FSE_decompress(): +- Decompress FSE data from buffer 'cSrc', of size 'cSrcSize', +- into already allocated destination buffer 'dst', of size 'dstCapacity'. +- @return : size of regenerated data (<= maxDstSize), +- or an error code, which can be tested using FSE_isError() . +- +- ** Important ** : FSE_decompress() does not decompress non-compressible nor RLE data !!! +- Why ? : making this distinction requires a header. +- Header management is intentionally delegated to the user layer, which can better manage special cases. +-*/ +-FSE_PUBLIC_API size_t FSE_decompress(void* dst, size_t dstCapacity, +- const void* cSrc, size_t cSrcSize); +- +- + /*-***************************************** + * Tool functions + ******************************************/ +@@ -88,20 +61,6 @@ FSE_PUBLIC_API unsigned FSE_isError(size_t code); /* tells if a return + FSE_PUBLIC_API const char* FSE_getErrorName(size_t code); /* provides error code string (useful for debugging) */ + + +-/*-***************************************** +-* FSE advanced functions +-******************************************/ +-/*! FSE_compress2() : +- Same as FSE_compress(), but allows the selection of 'maxSymbolValue' and 'tableLog' +- Both parameters can be defined as '0' to mean : use default value +- @return : size of compressed data +- Special values : if return == 0, srcData is not compressible => Nothing is stored within cSrc !!! +- if return == 1, srcData is a single byte symbol * srcSize times. Use RLE compression. +- if FSE_isError(return), it's an error code. +-*/ +-FSE_PUBLIC_API size_t FSE_compress2 (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog); +- +- + /*-***************************************** + * FSE detailed API + ******************************************/ +@@ -161,8 +120,6 @@ FSE_PUBLIC_API size_t FSE_writeNCount (void* buffer, size_t bufferSize, + /*! Constructor and Destructor of FSE_CTable. 
+ Note that FSE_CTable size depends on 'tableLog' and 'maxSymbolValue' */ + typedef unsigned FSE_CTable; /* don't allocate that. It's only meant to be more restrictive than void* */ +-FSE_PUBLIC_API FSE_CTable* FSE_createCTable (unsigned maxSymbolValue, unsigned tableLog); +-FSE_PUBLIC_API void FSE_freeCTable (FSE_CTable* ct); + + /*! FSE_buildCTable(): + Builds `ct`, which must be already allocated, using FSE_createCTable(). +@@ -238,23 +195,7 @@ FSE_PUBLIC_API size_t FSE_readNCount_bmi2(short* normalizedCounter, + unsigned* maxSymbolValuePtr, unsigned* tableLogPtr, + const void* rBuffer, size_t rBuffSize, int bmi2); + +-/*! Constructor and Destructor of FSE_DTable. +- Note that its size depends on 'tableLog' */ + typedef unsigned FSE_DTable; /* don't allocate that. It's just a way to be more restrictive than void* */ +-FSE_PUBLIC_API FSE_DTable* FSE_createDTable(unsigned tableLog); +-FSE_PUBLIC_API void FSE_freeDTable(FSE_DTable* dt); +- +-/*! FSE_buildDTable(): +- Builds 'dt', which must be already allocated, using FSE_createDTable(). +- return : 0, or an errorCode, which can be tested using FSE_isError() */ +-FSE_PUBLIC_API size_t FSE_buildDTable (FSE_DTable* dt, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog); +- +-/*! FSE_decompress_usingDTable(): +- Decompress compressed source `cSrc` of size `cSrcSize` using `dt` +- into `dst` which must be already allocated. +- @return : size of regenerated data (necessarily <= `dstCapacity`), +- or an errorCode, which can be tested using FSE_isError() */ +-FSE_PUBLIC_API size_t FSE_decompress_usingDTable(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, const FSE_DTable* dt); + + /*! 
+ Tutorial : +@@ -286,6 +227,7 @@ If there is an error, the function will return an error code, which can be teste + + #endif /* FSE_H */ + ++ + #if !defined(FSE_H_FSE_STATIC_LINKING_ONLY) + #define FSE_H_FSE_STATIC_LINKING_ONLY + +@@ -317,16 +259,6 @@ If there is an error, the function will return an error code, which can be teste + unsigned FSE_optimalTableLog_internal(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, unsigned minus); + /*< same as FSE_optimalTableLog(), which used `minus==2` */ + +-/* FSE_compress_wksp() : +- * Same as FSE_compress2(), but using an externally allocated scratch buffer (`workSpace`). +- * FSE_COMPRESS_WKSP_SIZE_U32() provides the minimum size required for `workSpace` as a table of FSE_CTable. +- */ +-#define FSE_COMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) ( FSE_CTABLE_SIZE_U32(maxTableLog, maxSymbolValue) + ((maxTableLog > 12) ? (1 << (maxTableLog - 2)) : 1024) ) +-size_t FSE_compress_wksp (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize); +- +-size_t FSE_buildCTable_raw (FSE_CTable* ct, unsigned nbBits); +-/*< build a fake FSE_CTable, designed for a flat distribution, where each symbol uses nbBits */ +- + size_t FSE_buildCTable_rle (FSE_CTable* ct, unsigned char symbolValue); + /*< build a fake FSE_CTable, designed to compress always the same symbolValue */ + +@@ -344,19 +276,11 @@ size_t FSE_buildCTable_wksp(FSE_CTable* ct, const short* normalizedCounter, unsi + FSE_PUBLIC_API size_t FSE_buildDTable_wksp(FSE_DTable* dt, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize); + /*< Same as FSE_buildDTable(), using an externally allocated `workspace` produced with `FSE_BUILD_DTABLE_WKSP_SIZE_U32(maxSymbolValue)` */ + +-size_t FSE_buildDTable_raw (FSE_DTable* dt, unsigned nbBits); +-/*< build a fake FSE_DTable, designed to read a flat distribution where each symbol uses 
nbBits */ +- +-size_t FSE_buildDTable_rle (FSE_DTable* dt, unsigned char symbolValue); +-/*< build a fake FSE_DTable, designed to always generate the same symbolValue */ +- +-#define FSE_DECOMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) (FSE_DTABLE_SIZE_U32(maxTableLog) + FSE_BUILD_DTABLE_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) + (FSE_MAX_SYMBOL_VALUE + 1) / 2 + 1) ++#define FSE_DECOMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) (FSE_DTABLE_SIZE_U32(maxTableLog) + 1 + FSE_BUILD_DTABLE_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) + (FSE_MAX_SYMBOL_VALUE + 1) / 2 + 1) + #define FSE_DECOMPRESS_WKSP_SIZE(maxTableLog, maxSymbolValue) (FSE_DECOMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) * sizeof(unsigned)) +-size_t FSE_decompress_wksp(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, unsigned maxLog, void* workSpace, size_t wkspSize); +-/*< same as FSE_decompress(), using an externally allocated `workSpace` produced with `FSE_DECOMPRESS_WKSP_SIZE_U32(maxLog, maxSymbolValue)` */ +- + size_t FSE_decompress_wksp_bmi2(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, unsigned maxLog, void* workSpace, size_t wkspSize, int bmi2); +-/*< Same as FSE_decompress_wksp() but with dynamic BMI2 support. Pass 1 if your CPU supports BMI2 or 0 if it doesn't. */ ++/*< same as FSE_decompress(), using an externally allocated `workSpace` produced with `FSE_DECOMPRESS_WKSP_SIZE_U32(maxLog, maxSymbolValue)`. 
++ * Set bmi2 to 1 if your CPU supports BMI2 or 0 if it doesn't */ + + typedef enum { + FSE_repeat_none, /*< Cannot use the previous table */ +@@ -539,20 +463,20 @@ MEM_STATIC void FSE_encodeSymbol(BIT_CStream_t* bitC, FSE_CState_t* statePtr, un + FSE_symbolCompressionTransform const symbolTT = ((const FSE_symbolCompressionTransform*)(statePtr->symbolTT))[symbol]; + const U16* const stateTable = (const U16*)(statePtr->stateTable); + U32 const nbBitsOut = (U32)((statePtr->value + symbolTT.deltaNbBits) >> 16); +- BIT_addBits(bitC, statePtr->value, nbBitsOut); ++ BIT_addBits(bitC, (size_t)statePtr->value, nbBitsOut); + statePtr->value = stateTable[ (statePtr->value >> nbBitsOut) + symbolTT.deltaFindState]; + } + + MEM_STATIC void FSE_flushCState(BIT_CStream_t* bitC, const FSE_CState_t* statePtr) + { +- BIT_addBits(bitC, statePtr->value, statePtr->stateLog); ++ BIT_addBits(bitC, (size_t)statePtr->value, statePtr->stateLog); + BIT_flushBits(bitC); + } + + + /* FSE_getMaxNbBits() : + * Approximate maximum cost of a symbol, in bits. +- * Fractional get rounded up (i.e : a symbol with a normalized frequency of 3 gives the same result as a frequency of 2) ++ * Fractional get rounded up (i.e. a symbol with a normalized frequency of 3 gives the same result as a frequency of 2) + * note 1 : assume symbolValue is valid (<= maxSymbolValue) + * note 2 : if freq[symbolValue]==0, @return a fake cost of tableLog+1 bits */ + MEM_STATIC U32 FSE_getMaxNbBits(const void* symbolTTPtr, U32 symbolValue) +diff --git a/lib/zstd/common/fse_decompress.c b/lib/zstd/common/fse_decompress.c +index 8dcb8ca39767..3a17e84f27bf 100644 +--- a/lib/zstd/common/fse_decompress.c ++++ b/lib/zstd/common/fse_decompress.c +@@ -1,6 +1,7 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* ****************************************************************** + * FSE : Finite State Entropy decoder +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * + * You can contact the author at : + * - FSE source repository : https://github.com/Cyan4973/FiniteStateEntropy +@@ -22,8 +23,8 @@ + #define FSE_STATIC_LINKING_ONLY + #include "fse.h" + #include "error_private.h" +-#define ZSTD_DEPS_NEED_MALLOC +-#include "zstd_deps.h" ++#include "zstd_deps.h" /* ZSTD_memcpy */ ++#include "bits.h" /* ZSTD_highbit32 */ + + + /* ************************************************************** +@@ -55,19 +56,6 @@ + #define FSE_FUNCTION_NAME(X,Y) FSE_CAT(X,Y) + #define FSE_TYPE_NAME(X,Y) FSE_CAT(X,Y) + +- +-/* Function templates */ +-FSE_DTable* FSE_createDTable (unsigned tableLog) +-{ +- if (tableLog > FSE_TABLELOG_ABSOLUTE_MAX) tableLog = FSE_TABLELOG_ABSOLUTE_MAX; +- return (FSE_DTable*)ZSTD_malloc( FSE_DTABLE_SIZE_U32(tableLog) * sizeof (U32) ); +-} +- +-void FSE_freeDTable (FSE_DTable* dt) +-{ +- ZSTD_free(dt); +-} +- + static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize) + { + void* const tdPtr = dt+1; /* because *dt is unsigned, 32-bits aligned on 32-bits */ +@@ -96,7 +84,7 @@ static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCo + symbolNext[s] = 1; + } else { + if (normalizedCounter[s] >= largeLimit) DTableH.fastMode=0; +- symbolNext[s] = normalizedCounter[s]; ++ symbolNext[s] = (U16)normalizedCounter[s]; + } } } + ZSTD_memcpy(dt, &DTableH, sizeof(DTableH)); + } +@@ -111,8 +99,7 @@ static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCo + * all symbols have counts <= 8. We ensure we have 8 bytes at the end of + * our buffer to handle the over-write. 
+ */ +- { +- U64 const add = 0x0101010101010101ull; ++ { U64 const add = 0x0101010101010101ull; + size_t pos = 0; + U64 sv = 0; + U32 s; +@@ -123,14 +110,13 @@ static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCo + for (i = 8; i < n; i += 8) { + MEM_write64(spread + pos + i, sv); + } +- pos += n; +- } +- } ++ pos += (size_t)n; ++ } } + /* Now we spread those positions across the table. +- * The benefit of doing it in two stages is that we avoid the the ++ * The benefit of doing it in two stages is that we avoid the + * variable size inner loop, which caused lots of branch misses. + * Now we can run through all the positions without any branch misses. +- * We unroll the loop twice, since that is what emperically worked best. ++ * We unroll the loop twice, since that is what empirically worked best. + */ + { + size_t position = 0; +@@ -166,7 +152,7 @@ static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCo + for (u=0; utableLog = 0; +- DTableH->fastMode = 0; +- +- cell->newState = 0; +- cell->symbol = symbolValue; +- cell->nbBits = 0; +- +- return 0; +-} +- +- +-size_t FSE_buildDTable_raw (FSE_DTable* dt, unsigned nbBits) +-{ +- void* ptr = dt; +- FSE_DTableHeader* const DTableH = (FSE_DTableHeader*)ptr; +- void* dPtr = dt + 1; +- FSE_decode_t* const dinfo = (FSE_decode_t*)dPtr; +- const unsigned tableSize = 1 << nbBits; +- const unsigned tableMask = tableSize - 1; +- const unsigned maxSV1 = tableMask+1; +- unsigned s; +- +- /* Sanity checks */ +- if (nbBits < 1) return ERROR(GENERIC); /* min size */ +- +- /* Build Decoding Table */ +- DTableH->tableLog = (U16)nbBits; +- DTableH->fastMode = 1; +- for (s=0; sfastMode; +- +- /* select fast mode (static) */ +- if (fastMode) return FSE_decompress_usingDTable_generic(dst, originalSize, cSrc, cSrcSize, dt, 1); +- return FSE_decompress_usingDTable_generic(dst, originalSize, cSrc, cSrcSize, dt, 0); +-} +- +- +-size_t FSE_decompress_wksp(void* dst, size_t dstCapacity, const 
void* cSrc, size_t cSrcSize, unsigned maxLog, void* workSpace, size_t wkspSize) +-{ +- return FSE_decompress_wksp_bmi2(dst, dstCapacity, cSrc, cSrcSize, maxLog, workSpace, wkspSize, /* bmi2 */ 0); ++ assert(op >= ostart); ++ return (size_t)(op-ostart); + } + + typedef struct { + short ncount[FSE_MAX_SYMBOL_VALUE + 1]; +- FSE_DTable dtable[]; /* Dynamically sized */ + } FSE_DecompressWksp; + + +@@ -327,13 +250,18 @@ FORCE_INLINE_TEMPLATE size_t FSE_decompress_wksp_body( + unsigned tableLog; + unsigned maxSymbolValue = FSE_MAX_SYMBOL_VALUE; + FSE_DecompressWksp* const wksp = (FSE_DecompressWksp*)workSpace; ++ size_t const dtablePos = sizeof(FSE_DecompressWksp) / sizeof(FSE_DTable); ++ FSE_DTable* const dtable = (FSE_DTable*)workSpace + dtablePos; + +- DEBUG_STATIC_ASSERT((FSE_MAX_SYMBOL_VALUE + 1) % 2 == 0); ++ FSE_STATIC_ASSERT((FSE_MAX_SYMBOL_VALUE + 1) % 2 == 0); + if (wkspSize < sizeof(*wksp)) return ERROR(GENERIC); + ++ /* correct offset to dtable depends on this property */ ++ FSE_STATIC_ASSERT(sizeof(FSE_DecompressWksp) % sizeof(FSE_DTable) == 0); ++ + /* normal FSE decoding mode */ +- { +- size_t const NCountLength = FSE_readNCount_bmi2(wksp->ncount, &maxSymbolValue, &tableLog, istart, cSrcSize, bmi2); ++ { size_t const NCountLength = ++ FSE_readNCount_bmi2(wksp->ncount, &maxSymbolValue, &tableLog, istart, cSrcSize, bmi2); + if (FSE_isError(NCountLength)) return NCountLength; + if (tableLog > maxLog) return ERROR(tableLog_tooLarge); + assert(NCountLength <= cSrcSize); +@@ -342,19 +270,20 @@ FORCE_INLINE_TEMPLATE size_t FSE_decompress_wksp_body( + } + + if (FSE_DECOMPRESS_WKSP_SIZE(tableLog, maxSymbolValue) > wkspSize) return ERROR(tableLog_tooLarge); +- workSpace = wksp->dtable + FSE_DTABLE_SIZE_U32(tableLog); ++ assert(sizeof(*wksp) + FSE_DTABLE_SIZE(tableLog) <= wkspSize); ++ workSpace = (BYTE*)workSpace + sizeof(*wksp) + FSE_DTABLE_SIZE(tableLog); + wkspSize -= sizeof(*wksp) + FSE_DTABLE_SIZE(tableLog); + +- CHECK_F( FSE_buildDTable_internal(wksp->dtable, 
wksp->ncount, maxSymbolValue, tableLog, workSpace, wkspSize) ); ++ CHECK_F( FSE_buildDTable_internal(dtable, wksp->ncount, maxSymbolValue, tableLog, workSpace, wkspSize) ); + + { +- const void* ptr = wksp->dtable; ++ const void* ptr = dtable; + const FSE_DTableHeader* DTableH = (const FSE_DTableHeader*)ptr; + const U32 fastMode = DTableH->fastMode; + + /* select fast mode (static) */ +- if (fastMode) return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, wksp->dtable, 1); +- return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, wksp->dtable, 0); ++ if (fastMode) return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, dtable, 1); ++ return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, dtable, 0); + } + } + +@@ -382,9 +311,4 @@ size_t FSE_decompress_wksp_bmi2(void* dst, size_t dstCapacity, const void* cSrc, + return FSE_decompress_wksp_body_default(dst, dstCapacity, cSrc, cSrcSize, maxLog, workSpace, wkspSize); + } + +- +-typedef FSE_DTable DTable_max_t[FSE_DTABLE_SIZE_U32(FSE_MAX_TABLELOG)]; +- +- +- + #endif /* FSE_COMMONDEFS_ONLY */ +diff --git a/lib/zstd/common/huf.h b/lib/zstd/common/huf.h +index 5042ff870308..57462466e188 100644 +--- a/lib/zstd/common/huf.h ++++ b/lib/zstd/common/huf.h +@@ -1,7 +1,8 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* ****************************************************************** + * huff0 huffman codec, + * part of Finite State Entropy library +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * + * You can contact the author at : + * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy +@@ -18,99 +19,22 @@ + + /* *** Dependencies *** */ + #include "zstd_deps.h" /* size_t */ +- +- +-/* *** library symbols visibility *** */ +-/* Note : when linking with -fvisibility=hidden on gcc, or by default on Visual, +- * HUF symbols remain "private" (internal symbols for library only). +- * Set macro FSE_DLL_EXPORT to 1 if you want HUF symbols visible on DLL interface */ +-#if defined(FSE_DLL_EXPORT) && (FSE_DLL_EXPORT==1) && defined(__GNUC__) && (__GNUC__ >= 4) +-# define HUF_PUBLIC_API __attribute__ ((visibility ("default"))) +-#elif defined(FSE_DLL_EXPORT) && (FSE_DLL_EXPORT==1) /* Visual expected */ +-# define HUF_PUBLIC_API __declspec(dllexport) +-#elif defined(FSE_DLL_IMPORT) && (FSE_DLL_IMPORT==1) +-# define HUF_PUBLIC_API __declspec(dllimport) /* not required, just to generate faster code (saves a function pointer load from IAT and an indirect jump) */ +-#else +-# define HUF_PUBLIC_API +-#endif +- +- +-/* ========================== */ +-/* *** simple functions *** */ +-/* ========================== */ +- +-/* HUF_compress() : +- * Compress content from buffer 'src', of size 'srcSize', into buffer 'dst'. +- * 'dst' buffer must be already allocated. +- * Compression runs faster if `dstCapacity` >= HUF_compressBound(srcSize). +- * `srcSize` must be <= `HUF_BLOCKSIZE_MAX` == 128 KB. +- * @return : size of compressed data (<= `dstCapacity`). +- * Special values : if return == 0, srcData is not compressible => Nothing is stored within dst !!! +- * if HUF_isError(return), compression failed (more details using HUF_getErrorName()) +- */ +-HUF_PUBLIC_API size_t HUF_compress(void* dst, size_t dstCapacity, +- const void* src, size_t srcSize); +- +-/* HUF_decompress() : +- * Decompress HUF data from buffer 'cSrc', of size 'cSrcSize', +- * into already allocated buffer 'dst', of minimum size 'dstSize'. 
+- * `originalSize` : **must** be the ***exact*** size of original (uncompressed) data. +- * Note : in contrast with FSE, HUF_decompress can regenerate +- * RLE (cSrcSize==1) and uncompressed (cSrcSize==dstSize) data, +- * because it knows size to regenerate (originalSize). +- * @return : size of regenerated data (== originalSize), +- * or an error code, which can be tested using HUF_isError() +- */ +-HUF_PUBLIC_API size_t HUF_decompress(void* dst, size_t originalSize, +- const void* cSrc, size_t cSrcSize); ++#include "mem.h" /* U32 */ ++#define FSE_STATIC_LINKING_ONLY ++#include "fse.h" + + + /* *** Tool functions *** */ +-#define HUF_BLOCKSIZE_MAX (128 * 1024) /*< maximum input size for a single block compressed with HUF_compress */ +-HUF_PUBLIC_API size_t HUF_compressBound(size_t size); /*< maximum compressed size (worst case) */ ++#define HUF_BLOCKSIZE_MAX (128 * 1024) /*< maximum input size for a single block compressed with HUF_compress */ ++size_t HUF_compressBound(size_t size); /*< maximum compressed size (worst case) */ + + /* Error Management */ +-HUF_PUBLIC_API unsigned HUF_isError(size_t code); /*< tells if a return value is an error code */ +-HUF_PUBLIC_API const char* HUF_getErrorName(size_t code); /*< provides error code string (useful for debugging) */ +- ++unsigned HUF_isError(size_t code); /*< tells if a return value is an error code */ ++const char* HUF_getErrorName(size_t code); /*< provides error code string (useful for debugging) */ + +-/* *** Advanced function *** */ + +-/* HUF_compress2() : +- * Same as HUF_compress(), but offers control over `maxSymbolValue` and `tableLog`. +- * `maxSymbolValue` must be <= HUF_SYMBOLVALUE_MAX . +- * `tableLog` must be `<= HUF_TABLELOG_MAX` . */ +-HUF_PUBLIC_API size_t HUF_compress2 (void* dst, size_t dstCapacity, +- const void* src, size_t srcSize, +- unsigned maxSymbolValue, unsigned tableLog); +- +-/* HUF_compress4X_wksp() : +- * Same as HUF_compress2(), but uses externally allocated `workSpace`. 
+- * `workspace` must be at least as large as HUF_WORKSPACE_SIZE */ + #define HUF_WORKSPACE_SIZE ((8 << 10) + 512 /* sorting scratch space */) + #define HUF_WORKSPACE_SIZE_U64 (HUF_WORKSPACE_SIZE / sizeof(U64)) +-HUF_PUBLIC_API size_t HUF_compress4X_wksp (void* dst, size_t dstCapacity, +- const void* src, size_t srcSize, +- unsigned maxSymbolValue, unsigned tableLog, +- void* workSpace, size_t wkspSize); +- +-#endif /* HUF_H_298734234 */ +- +-/* ****************************************************************** +- * WARNING !! +- * The following section contains advanced and experimental definitions +- * which shall never be used in the context of a dynamic library, +- * because they are not guaranteed to remain stable in the future. +- * Only consider them in association with static linking. +- * *****************************************************************/ +-#if !defined(HUF_H_HUF_STATIC_LINKING_ONLY) +-#define HUF_H_HUF_STATIC_LINKING_ONLY +- +-/* *** Dependencies *** */ +-#include "mem.h" /* U32 */ +-#define FSE_STATIC_LINKING_ONLY +-#include "fse.h" +- + + /* *** Constants *** */ + #define HUF_TABLELOG_MAX 12 /* max runtime value of tableLog (due to static allocation); can be modified up to HUF_TABLELOG_ABSOLUTEMAX */ +@@ -151,25 +75,49 @@ typedef U32 HUF_DTable; + /* **************************************** + * Advanced decompression functions + ******************************************/ +-size_t HUF_decompress4X1 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< single-symbol decoder */ +-#ifndef HUF_FORCE_DECOMPRESS_X1 +-size_t HUF_decompress4X2 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< double-symbols decoder */ +-#endif + +-size_t HUF_decompress4X_DCtx (HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< decodes RLE and uncompressed */ +-size_t HUF_decompress4X_hufOnly(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< considers RLE and 
uncompressed as errors */ +-size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< considers RLE and uncompressed as errors */ +-size_t HUF_decompress4X1_DCtx(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< single-symbol decoder */ +-size_t HUF_decompress4X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< single-symbol decoder */ +-#ifndef HUF_FORCE_DECOMPRESS_X1 +-size_t HUF_decompress4X2_DCtx(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< double-symbols decoder */ +-size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< double-symbols decoder */ +-#endif ++/* ++ * Huffman flags bitset. ++ * For all flags, 0 is the default value. ++ */ ++typedef enum { ++ /* ++ * If compiled with DYNAMIC_BMI2: Set flag only if the CPU supports BMI2 at runtime. ++ * Otherwise: Ignored. ++ */ ++ HUF_flags_bmi2 = (1 << 0), ++ /* ++ * If set: Test possible table depths to find the one that produces the smallest header + encoded size. ++ * If unset: Use heuristic to find the table depth. ++ */ ++ HUF_flags_optimalDepth = (1 << 1), ++ /* ++ * If set: If the previous table can encode the input, always reuse the previous table. ++ * If unset: If the previous table can encode the input, reuse the previous table if it results in a smaller output. ++ */ ++ HUF_flags_preferRepeat = (1 << 2), ++ /* ++ * If set: Sample the input and check if the sample is uncompressible, if it is then don't attempt to compress. ++ * If unset: Always histogram the entire input. 
++ */ ++ HUF_flags_suspectUncompressible = (1 << 3), ++ /* ++ * If set: Don't use assembly implementations ++ * If unset: Allow using assembly implementations ++ */ ++ HUF_flags_disableAsm = (1 << 4), ++ /* ++ * If set: Don't use the fast decoding loop, always use the fallback decoding loop. ++ * If unset: Use the fast decoding loop when possible. ++ */ ++ HUF_flags_disableFast = (1 << 5) ++} HUF_flags_e; + + + /* **************************************** + * HUF detailed API + * ****************************************/ ++#define HUF_OPTIMAL_DEPTH_THRESHOLD ZSTD_btultra + + /*! HUF_compress() does the following: + * 1. count symbol occurrence from source[] into table count[] using FSE_count() (exposed within "fse.h") +@@ -182,12 +130,12 @@ size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, + * For example, it's possible to compress several blocks using the same 'CTable', + * or to save and regenerate 'CTable' using external methods. + */ +-unsigned HUF_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue); +-size_t HUF_buildCTable (HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue, unsigned maxNbBits); /* @return : maxNbBits; CTable and count can overlap. 
In which case, CTable will overwrite count content */ +-size_t HUF_writeCTable (void* dst, size_t maxDstSize, const HUF_CElt* CTable, unsigned maxSymbolValue, unsigned huffLog); ++unsigned HUF_minTableLog(unsigned symbolCardinality); ++unsigned HUF_cardinality(const unsigned* count, unsigned maxSymbolValue); ++unsigned HUF_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, void* workSpace, ++ size_t wkspSize, HUF_CElt* table, const unsigned* count, int flags); /* table is used as scratch space for building and testing tables, not a return value */ + size_t HUF_writeCTable_wksp(void* dst, size_t maxDstSize, const HUF_CElt* CTable, unsigned maxSymbolValue, unsigned huffLog, void* workspace, size_t workspaceSize); +-size_t HUF_compress4X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable); +-size_t HUF_compress4X_usingCTable_bmi2(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int bmi2); ++size_t HUF_compress4X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags); + size_t HUF_estimateCompressedSize(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue); + int HUF_validateCTable(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue); + +@@ -196,6 +144,7 @@ typedef enum { + HUF_repeat_check, /*< Can use the previous table but it must be checked. Note : The previous table must have been constructed by HUF_compress{1, 4}X_repeat */ + HUF_repeat_valid /*< Can use the previous table and it is assumed to be valid */ + } HUF_repeat; ++ + /* HUF_compress4X_repeat() : + * Same as HUF_compress4X_wksp(), but considers using hufTable if *repeat != HUF_repeat_none. + * If it uses hufTable it does not modify hufTable or repeat. 
+@@ -206,13 +155,13 @@ size_t HUF_compress4X_repeat(void* dst, size_t dstSize, + const void* src, size_t srcSize, + unsigned maxSymbolValue, unsigned tableLog, + void* workSpace, size_t wkspSize, /*< `workSpace` must be aligned on 4-bytes boundaries, `wkspSize` must be >= HUF_WORKSPACE_SIZE */ +- HUF_CElt* hufTable, HUF_repeat* repeat, int preferRepeat, int bmi2, unsigned suspectUncompressible); ++ HUF_CElt* hufTable, HUF_repeat* repeat, int flags); + + /* HUF_buildCTable_wksp() : + * Same as HUF_buildCTable(), but using externally allocated scratch buffer. + * `workSpace` must be aligned on 4-bytes boundaries, and its size must be >= HUF_CTABLE_WORKSPACE_SIZE. + */ +-#define HUF_CTABLE_WORKSPACE_SIZE_U32 (2*HUF_SYMBOLVALUE_MAX +1 +1) ++#define HUF_CTABLE_WORKSPACE_SIZE_U32 ((4 * (HUF_SYMBOLVALUE_MAX + 1)) + 192) + #define HUF_CTABLE_WORKSPACE_SIZE (HUF_CTABLE_WORKSPACE_SIZE_U32 * sizeof(unsigned)) + size_t HUF_buildCTable_wksp (HUF_CElt* tree, + const unsigned* count, U32 maxSymbolValue, U32 maxNbBits, +@@ -238,7 +187,7 @@ size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, + U32* rankStats, U32* nbSymbolsPtr, U32* tableLogPtr, + const void* src, size_t srcSize, + void* workspace, size_t wkspSize, +- int bmi2); ++ int flags); + + /* HUF_readCTable() : + * Loading a CTable saved with HUF_writeCTable() */ +@@ -246,9 +195,22 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void + + /* HUF_getNbBitsFromCTable() : + * Read nbBits from CTable symbolTable, for symbol `symbolValue` presumed <= HUF_SYMBOLVALUE_MAX +- * Note 1 : is not inlined, as HUF_CElt definition is private */ ++ * Note 1 : If symbolValue > HUF_readCTableHeader(symbolTable).maxSymbolValue, returns 0 ++ * Note 2 : is not inlined, as HUF_CElt definition is private ++ */ + U32 HUF_getNbBitsFromCTable(const HUF_CElt* symbolTable, U32 symbolValue); + ++typedef struct { ++ BYTE tableLog; ++ BYTE maxSymbolValue; ++ BYTE unused[sizeof(size_t) - 2]; ++} HUF_CTableHeader; ++ 
++/* HUF_readCTableHeader() : ++ * @returns The header from the CTable specifying the tableLog and the maxSymbolValue. ++ */ ++HUF_CTableHeader HUF_readCTableHeader(HUF_CElt const* ctable); ++ + /* + * HUF_decompress() does the following: + * 1. select the decompression algorithm (X1, X2) based on pre-computed heuristics +@@ -276,32 +238,12 @@ U32 HUF_selectDecoder (size_t dstSize, size_t cSrcSize); + #define HUF_DECOMPRESS_WORKSPACE_SIZE ((2 << 10) + (1 << 9)) + #define HUF_DECOMPRESS_WORKSPACE_SIZE_U32 (HUF_DECOMPRESS_WORKSPACE_SIZE / sizeof(U32)) + +-#ifndef HUF_FORCE_DECOMPRESS_X2 +-size_t HUF_readDTableX1 (HUF_DTable* DTable, const void* src, size_t srcSize); +-size_t HUF_readDTableX1_wksp (HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize); +-#endif +-#ifndef HUF_FORCE_DECOMPRESS_X1 +-size_t HUF_readDTableX2 (HUF_DTable* DTable, const void* src, size_t srcSize); +-size_t HUF_readDTableX2_wksp (HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize); +-#endif +- +-size_t HUF_decompress4X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); +-#ifndef HUF_FORCE_DECOMPRESS_X2 +-size_t HUF_decompress4X1_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); +-#endif +-#ifndef HUF_FORCE_DECOMPRESS_X1 +-size_t HUF_decompress4X2_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); +-#endif +- + + /* ====================== */ + /* single stream variants */ + /* ====================== */ + +-size_t HUF_compress1X (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog); +-size_t HUF_compress1X_wksp (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize); /*< `workSpace` must be a table of at least HUF_WORKSPACE_SIZE_U64 U64 */ 
+-size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable); +-size_t HUF_compress1X_usingCTable_bmi2(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int bmi2); ++size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags); + /* HUF_compress1X_repeat() : + * Same as HUF_compress1X_wksp(), but considers using hufTable if *repeat != HUF_repeat_none. + * If it uses hufTable it does not modify hufTable or repeat. +@@ -312,47 +254,28 @@ size_t HUF_compress1X_repeat(void* dst, size_t dstSize, + const void* src, size_t srcSize, + unsigned maxSymbolValue, unsigned tableLog, + void* workSpace, size_t wkspSize, /*< `workSpace` must be aligned on 4-bytes boundaries, `wkspSize` must be >= HUF_WORKSPACE_SIZE */ +- HUF_CElt* hufTable, HUF_repeat* repeat, int preferRepeat, int bmi2, unsigned suspectUncompressible); ++ HUF_CElt* hufTable, HUF_repeat* repeat, int flags); + +-size_t HUF_decompress1X1 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /* single-symbol decoder */ +-#ifndef HUF_FORCE_DECOMPRESS_X1 +-size_t HUF_decompress1X2 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /* double-symbol decoder */ +-#endif +- +-size_t HUF_decompress1X_DCtx (HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); +-size_t HUF_decompress1X_DCtx_wksp (HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); +-#ifndef HUF_FORCE_DECOMPRESS_X2 +-size_t HUF_decompress1X1_DCtx(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< single-symbol decoder */ +-size_t HUF_decompress1X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< single-symbol decoder */ +-#endif ++size_t HUF_decompress1X_DCtx_wksp(HUF_DTable* 
dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); + #ifndef HUF_FORCE_DECOMPRESS_X1 +-size_t HUF_decompress1X2_DCtx(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< double-symbols decoder */ +-size_t HUF_decompress1X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< double-symbols decoder */ +-#endif +- +-size_t HUF_decompress1X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); /*< automatic selection of sing or double symbol decoder, based on DTable */ +-#ifndef HUF_FORCE_DECOMPRESS_X2 +-size_t HUF_decompress1X1_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); +-#endif +-#ifndef HUF_FORCE_DECOMPRESS_X1 +-size_t HUF_decompress1X2_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); ++size_t HUF_decompress1X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); /*< double-symbols decoder */ + #endif + + /* BMI2 variants. + * If the CPU has BMI2 support, pass bmi2=1, otherwise pass bmi2=0. 
+ */ +-size_t HUF_decompress1X_usingDTable_bmi2(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int bmi2); ++size_t HUF_decompress1X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags); + #ifndef HUF_FORCE_DECOMPRESS_X2 +-size_t HUF_decompress1X1_DCtx_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int bmi2); ++size_t HUF_decompress1X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); + #endif +-size_t HUF_decompress4X_usingDTable_bmi2(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int bmi2); +-size_t HUF_decompress4X_hufOnly_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int bmi2); ++size_t HUF_decompress4X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags); ++size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); + #ifndef HUF_FORCE_DECOMPRESS_X2 +-size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int bmi2); ++size_t HUF_readDTableX1_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int flags); + #endif + #ifndef HUF_FORCE_DECOMPRESS_X1 +-size_t HUF_readDTableX2_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int bmi2); ++size_t HUF_readDTableX2_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int flags); + #endif + +-#endif /* HUF_STATIC_LINKING_ONLY */ ++#endif /* HUF_H_298734234 */ + +diff --git 
a/lib/zstd/common/mem.h b/lib/zstd/common/mem.h +index 1d9cc03924ca..2e91e7780c1f 100644 +--- a/lib/zstd/common/mem.h ++++ b/lib/zstd/common/mem.h +@@ -1,6 +1,6 @@ + /* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -24,6 +24,7 @@ + /*-**************************************** + * Compiler specifics + ******************************************/ ++#undef MEM_STATIC /* may be already defined from common/compiler.h */ + #define MEM_STATIC static inline + + /*-************************************************************** +diff --git a/lib/zstd/common/portability_macros.h b/lib/zstd/common/portability_macros.h +index 0e3b2c0a527d..f08638cced6c 100644 +--- a/lib/zstd/common/portability_macros.h ++++ b/lib/zstd/common/portability_macros.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -12,7 +13,7 @@ + #define ZSTD_PORTABILITY_MACROS_H + + /* +- * This header file contains macro defintions to support portability. ++ * This header file contains macro definitions to support portability. + * This header is shared between C and ASM code, so it MUST only + * contain macro definitions. It MUST not contain any C code. 
+ * +@@ -45,6 +46,8 @@ + /* Mark the internal assembly functions as hidden */ + #ifdef __ELF__ + # define ZSTD_HIDE_ASM_FUNCTION(func) .hidden func ++#elif defined(__APPLE__) ++# define ZSTD_HIDE_ASM_FUNCTION(func) .private_extern func + #else + # define ZSTD_HIDE_ASM_FUNCTION(func) + #endif +@@ -65,7 +68,7 @@ + #endif + + /* +- * Only enable assembly for GNUC comptabile compilers, ++ * Only enable assembly for GNUC compatible compilers, + * because other platforms may not support GAS assembly syntax. + * + * Only enable assembly for Linux / MacOS, other platforms may +@@ -90,4 +93,23 @@ + */ + #define ZSTD_ENABLE_ASM_X86_64_BMI2 0 + ++/* ++ * For x86 ELF targets, add .note.gnu.property section for Intel CET in ++ * assembly sources when CET is enabled. ++ * ++ * Additionally, any function that may be called indirectly must begin ++ * with ZSTD_CET_ENDBRANCH. ++ */ ++#if defined(__ELF__) && (defined(__x86_64__) || defined(__i386__)) \ ++ && defined(__has_include) ++# if __has_include() ++# include ++# define ZSTD_CET_ENDBRANCH _CET_ENDBR ++# endif ++#endif ++ ++#ifndef ZSTD_CET_ENDBRANCH ++# define ZSTD_CET_ENDBRANCH ++#endif ++ + #endif /* ZSTD_PORTABILITY_MACROS_H */ +diff --git a/lib/zstd/common/zstd_common.c b/lib/zstd/common/zstd_common.c +index 3d7e35b309b5..44b95b25344a 100644 +--- a/lib/zstd/common/zstd_common.c ++++ b/lib/zstd/common/zstd_common.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -14,7 +15,6 @@ + * Dependencies + ***************************************/ + #define ZSTD_DEPS_NEED_MALLOC +-#include "zstd_deps.h" /* ZSTD_malloc, ZSTD_calloc, ZSTD_free, ZSTD_memset */ + #include "error_private.h" + #include "zstd_internal.h" + +@@ -47,37 +47,3 @@ ZSTD_ErrorCode ZSTD_getErrorCode(size_t code) { return ERR_getErrorCode(code); } + /*! ZSTD_getErrorString() : + * provides error code string from enum */ + const char* ZSTD_getErrorString(ZSTD_ErrorCode code) { return ERR_getErrorString(code); } +- +- +- +-/*=************************************************************** +-* Custom allocator +-****************************************************************/ +-void* ZSTD_customMalloc(size_t size, ZSTD_customMem customMem) +-{ +- if (customMem.customAlloc) +- return customMem.customAlloc(customMem.opaque, size); +- return ZSTD_malloc(size); +-} +- +-void* ZSTD_customCalloc(size_t size, ZSTD_customMem customMem) +-{ +- if (customMem.customAlloc) { +- /* calloc implemented as malloc+memset; +- * not as efficient as calloc, but next best guess for custom malloc */ +- void* const ptr = customMem.customAlloc(customMem.opaque, size); +- ZSTD_memset(ptr, 0, size); +- return ptr; +- } +- return ZSTD_calloc(1, size); +-} +- +-void ZSTD_customFree(void* ptr, ZSTD_customMem customMem) +-{ +- if (ptr!=NULL) { +- if (customMem.customFree) +- customMem.customFree(customMem.opaque, ptr); +- else +- ZSTD_free(ptr); +- } +-} +diff --git a/lib/zstd/common/zstd_deps.h b/lib/zstd/common/zstd_deps.h +index 2c34e8a33a1c..f931f7d0e294 100644 +--- a/lib/zstd/common/zstd_deps.h ++++ b/lib/zstd/common/zstd_deps.h +@@ -1,6 +1,6 @@ + /* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -105,3 +105,17 @@ static uint64_t ZSTD_div64(uint64_t dividend, uint32_t divisor) { + + #endif /* ZSTD_DEPS_IO */ + #endif /* ZSTD_DEPS_NEED_IO */ ++ ++/* ++ * Only requested when MSAN is enabled. ++ * Need: ++ * intptr_t ++ */ ++#ifdef ZSTD_DEPS_NEED_STDINT ++#ifndef ZSTD_DEPS_STDINT ++#define ZSTD_DEPS_STDINT ++ ++/* intptr_t already provided by ZSTD_DEPS_COMMON */ ++ ++#endif /* ZSTD_DEPS_STDINT */ ++#endif /* ZSTD_DEPS_NEED_STDINT */ +diff --git a/lib/zstd/common/zstd_internal.h b/lib/zstd/common/zstd_internal.h +index 93305d9b41bb..11da1233e890 100644 +--- a/lib/zstd/common/zstd_internal.h ++++ b/lib/zstd/common/zstd_internal.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -28,7 +29,6 @@ + #include + #define FSE_STATIC_LINKING_ONLY + #include "fse.h" +-#define HUF_STATIC_LINKING_ONLY + #include "huf.h" + #include /* XXH_reset, update, digest */ + #define ZSTD_TRACE 0 +@@ -83,9 +83,9 @@ typedef enum { bt_raw, bt_rle, bt_compressed, bt_reserved } blockType_e; + #define ZSTD_FRAMECHECKSUMSIZE 4 + + #define MIN_SEQUENCES_SIZE 1 /* nbSeq==0 */ +-#define MIN_CBLOCK_SIZE (1 /*litCSize*/ + 1 /* RLE or RAW */ + MIN_SEQUENCES_SIZE /* nbSeq==0 */) /* for a non-null block */ ++#define MIN_CBLOCK_SIZE (1 /*litCSize*/ + 1 /* RLE or RAW */) /* for a non-null block */ ++#define MIN_LITERALS_FOR_4_STREAMS 6 + +-#define HufLog 12 + typedef enum { set_basic, set_rle, set_compressed, set_repeat } symbolEncodingType_e; + + #define LONGNBSEQ 0x7F00 +@@ -93,6 +93,7 @@ typedef enum { set_basic, set_rle, set_compressed, set_repeat } symbolEncodingTy + #define MINMATCH 3 + + #define Litbits 8 ++#define LitHufLog 11 + #define MaxLit ((1<= WILDCOPY_VECLEN 
|| diff <= -WILDCOPY_VECLEN); +@@ -225,12 +228,6 @@ void ZSTD_wildcopy(void* dst, const void* src, ptrdiff_t length, ZSTD_overlap_e + * one COPY16() in the first call. Then, do two calls per loop since + * at that point it is more likely to have a high trip count. + */ +-#ifdef __aarch64__ +- do { +- COPY16(op, ip); +- } +- while (op < oend); +-#else + ZSTD_copy16(op, ip); + if (16 >= length) return; + op += 16; +@@ -240,7 +237,6 @@ void ZSTD_wildcopy(void* dst, const void* src, ptrdiff_t length, ZSTD_overlap_e + COPY16(op, ip); + } + while (op < oend); +-#endif + } + } + +@@ -289,11 +285,11 @@ typedef enum { + typedef struct { + seqDef* sequencesStart; + seqDef* sequences; /* ptr to end of sequences */ +- BYTE* litStart; +- BYTE* lit; /* ptr to end of literals */ +- BYTE* llCode; +- BYTE* mlCode; +- BYTE* ofCode; ++ BYTE* litStart; ++ BYTE* lit; /* ptr to end of literals */ ++ BYTE* llCode; ++ BYTE* mlCode; ++ BYTE* ofCode; + size_t maxNbSeq; + size_t maxNbLit; + +@@ -301,8 +297,8 @@ typedef struct { + * in the seqStore that has a value larger than U16 (if it exists). To do so, we increment + * the existing value of the litLength or matchLength by 0x10000. 
+ */ +- ZSTD_longLengthType_e longLengthType; +- U32 longLengthPos; /* Index of the sequence to apply long length modification to */ ++ ZSTD_longLengthType_e longLengthType; ++ U32 longLengthPos; /* Index of the sequence to apply long length modification to */ + } seqStore_t; + + typedef struct { +@@ -321,10 +317,10 @@ MEM_STATIC ZSTD_sequenceLength ZSTD_getSequenceLength(seqStore_t const* seqStore + seqLen.matchLength = seq->mlBase + MINMATCH; + if (seqStore->longLengthPos == (U32)(seq - seqStore->sequencesStart)) { + if (seqStore->longLengthType == ZSTD_llt_literalLength) { +- seqLen.litLength += 0xFFFF; ++ seqLen.litLength += 0x10000; + } + if (seqStore->longLengthType == ZSTD_llt_matchLength) { +- seqLen.matchLength += 0xFFFF; ++ seqLen.matchLength += 0x10000; + } + } + return seqLen; +@@ -337,72 +333,13 @@ MEM_STATIC ZSTD_sequenceLength ZSTD_getSequenceLength(seqStore_t const* seqStore + * `decompressedBound != ZSTD_CONTENTSIZE_ERROR` + */ + typedef struct { ++ size_t nbBlocks; + size_t compressedSize; + unsigned long long decompressedBound; + } ZSTD_frameSizeInfo; /* decompress & legacy */ + + const seqStore_t* ZSTD_getSeqStore(const ZSTD_CCtx* ctx); /* compress & dictBuilder */ +-void ZSTD_seqToCodes(const seqStore_t* seqStorePtr); /* compress, dictBuilder, decodeCorpus (shouldn't get its definition from here) */ +- +-/* custom memory allocation functions */ +-void* ZSTD_customMalloc(size_t size, ZSTD_customMem customMem); +-void* ZSTD_customCalloc(size_t size, ZSTD_customMem customMem); +-void ZSTD_customFree(void* ptr, ZSTD_customMem customMem); +- +- +-MEM_STATIC U32 ZSTD_highbit32(U32 val) /* compress, dictBuilder, decodeCorpus */ +-{ +- assert(val != 0); +- { +-# if (__GNUC__ >= 3) /* GCC Intrinsic */ +- return __builtin_clz (val) ^ 31; +-# else /* Software version */ +- static const U32 DeBruijnClz[32] = { 0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30, 8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31 }; +- U32 v = val; +- v |= v >> 
1; +- v |= v >> 2; +- v |= v >> 4; +- v |= v >> 8; +- v |= v >> 16; +- return DeBruijnClz[(v * 0x07C4ACDDU) >> 27]; +-# endif +- } +-} +- +-/* +- * Counts the number of trailing zeros of a `size_t`. +- * Most compilers should support CTZ as a builtin. A backup +- * implementation is provided if the builtin isn't supported, but +- * it may not be terribly efficient. +- */ +-MEM_STATIC unsigned ZSTD_countTrailingZeros(size_t val) +-{ +- if (MEM_64bits()) { +-# if (__GNUC__ >= 4) +- return __builtin_ctzll((U64)val); +-# else +- static const int DeBruijnBytePos[64] = { 0, 1, 2, 7, 3, 13, 8, 19, +- 4, 25, 14, 28, 9, 34, 20, 56, +- 5, 17, 26, 54, 15, 41, 29, 43, +- 10, 31, 38, 35, 21, 45, 49, 57, +- 63, 6, 12, 18, 24, 27, 33, 55, +- 16, 53, 40, 42, 30, 37, 44, 48, +- 62, 11, 23, 32, 52, 39, 36, 47, +- 61, 22, 51, 46, 60, 50, 59, 58 }; +- return DeBruijnBytePos[((U64)((val & -(long long)val) * 0x0218A392CDABBD3FULL)) >> 58]; +-# endif +- } else { /* 32 bits */ +-# if (__GNUC__ >= 3) +- return __builtin_ctz((U32)val); +-# else +- static const int DeBruijnBytePos[32] = { 0, 1, 28, 2, 29, 14, 24, 3, +- 30, 22, 20, 15, 25, 17, 4, 8, +- 31, 27, 13, 23, 21, 19, 16, 7, +- 26, 12, 18, 6, 11, 5, 10, 9 }; +- return DeBruijnBytePos[((U32)((val & -(S32)val) * 0x077CB531U)) >> 27]; +-# endif +- } +-} ++int ZSTD_seqToCodes(const seqStore_t* seqStorePtr); /* compress, dictBuilder, decodeCorpus (shouldn't get its definition from here) */ + + + /* ZSTD_invalidateRepCodes() : +@@ -420,13 +357,13 @@ typedef struct { + + /*! ZSTD_getcBlockSize() : + * Provides the size of compressed block from block header `src` */ +-/* Used by: decompress, fullbench (does not get its definition from here) */ ++/* Used by: decompress, fullbench */ + size_t ZSTD_getcBlockSize(const void* src, size_t srcSize, + blockProperties_t* bpPtr); + + /*! 
ZSTD_decodeSeqHeaders() : + * decode sequence header from src */ +-/* Used by: decompress, fullbench (does not get its definition from here) */ ++/* Used by: zstd_decompress_block, fullbench */ + size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr, + const void* src, size_t srcSize); + +diff --git a/lib/zstd/compress/clevels.h b/lib/zstd/compress/clevels.h +index d9a76112ec3a..6ab8be6532ef 100644 +--- a/lib/zstd/compress/clevels.h ++++ b/lib/zstd/compress/clevels.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +diff --git a/lib/zstd/compress/fse_compress.c b/lib/zstd/compress/fse_compress.c +index ec5b1ca6d71a..44a3c10becf2 100644 +--- a/lib/zstd/compress/fse_compress.c ++++ b/lib/zstd/compress/fse_compress.c +@@ -1,6 +1,7 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* ****************************************************************** + * FSE : Finite State Entropy encoder +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * + * You can contact the author at : + * - FSE source repository : https://github.com/Cyan4973/FiniteStateEntropy +@@ -25,7 +26,8 @@ + #include "../common/error_private.h" + #define ZSTD_DEPS_NEED_MALLOC + #define ZSTD_DEPS_NEED_MATH64 +-#include "../common/zstd_deps.h" /* ZSTD_malloc, ZSTD_free, ZSTD_memcpy, ZSTD_memset */ ++#include "../common/zstd_deps.h" /* ZSTD_memset */ ++#include "../common/bits.h" /* ZSTD_highbit32 */ + + + /* ************************************************************** +@@ -90,7 +92,7 @@ size_t FSE_buildCTable_wksp(FSE_CTable* ct, + assert(tableLog < 16); /* required for threshold strategy to work */ + + /* For explanations on how to distribute symbol values over the table : +- * http://fastcompression.blogspot.fr/2014/02/fse-distributing-symbol-values.html */ ++ * https://fastcompression.blogspot.fr/2014/02/fse-distributing-symbol-values.html */ + + #ifdef __clang_analyzer__ + ZSTD_memset(tableSymbol, 0, sizeof(*tableSymbol) * tableSize); /* useless initialization, just to keep scan-build happy */ +@@ -191,7 +193,7 @@ size_t FSE_buildCTable_wksp(FSE_CTable* ct, + break; + default : + assert(normalizedCounter[s] > 1); +- { U32 const maxBitsOut = tableLog - BIT_highbit32 ((U32)normalizedCounter[s]-1); ++ { U32 const maxBitsOut = tableLog - ZSTD_highbit32 ((U32)normalizedCounter[s]-1); + U32 const minStatePlus = (U32)normalizedCounter[s] << maxBitsOut; + symbolTT[s].deltaNbBits = (maxBitsOut << 16) - minStatePlus; + symbolTT[s].deltaFindState = (int)(total - (unsigned)normalizedCounter[s]); +@@ -224,8 +226,8 @@ size_t FSE_NCountWriteBound(unsigned maxSymbolValue, unsigned tableLog) + size_t const maxHeaderSize = (((maxSymbolValue+1) * tableLog + + 4 /* bitCount initialized at 4 */ + + 2 /* first two symbols may use one additional bit each */) / 8) +- + 1 /* round up to whole nb bytes */ +- + 2 /* additional two bytes for bitstream flush */; ++ + 1 /* round up to whole nb bytes */ ++ + 2 /* additional two bytes for bitstream flush */; + 
return maxSymbolValue ? maxHeaderSize : FSE_NCOUNTBOUND; /* maxSymbolValue==0 ? use default */ + } + +@@ -254,7 +256,7 @@ FSE_writeNCount_generic (void* header, size_t headerBufferSize, + /* Init */ + remaining = tableSize+1; /* +1 for extra accuracy */ + threshold = tableSize; +- nbBits = tableLog+1; ++ nbBits = (int)tableLog+1; + + while ((symbol < alphabetSize) && (remaining>1)) { /* stops at 1 */ + if (previousIs0) { +@@ -273,7 +275,7 @@ FSE_writeNCount_generic (void* header, size_t headerBufferSize, + } + while (symbol >= start+3) { + start+=3; +- bitStream += 3 << bitCount; ++ bitStream += 3U << bitCount; + bitCount += 2; + } + bitStream += (symbol-start) << bitCount; +@@ -293,7 +295,7 @@ FSE_writeNCount_generic (void* header, size_t headerBufferSize, + count++; /* +1 for extra accuracy */ + if (count>=threshold) + count += max; /* [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[ */ +- bitStream += count << bitCount; ++ bitStream += (U32)count << bitCount; + bitCount += nbBits; + bitCount -= (count>8); + out+= (bitCount+7) /8; + +- return (out-ostart); ++ assert(out >= ostart); ++ return (size_t)(out-ostart); + } + + +@@ -342,21 +345,11 @@ size_t FSE_writeNCount (void* buffer, size_t bufferSize, + * FSE Compression Code + ****************************************************************/ + +-FSE_CTable* FSE_createCTable (unsigned maxSymbolValue, unsigned tableLog) +-{ +- size_t size; +- if (tableLog > FSE_TABLELOG_ABSOLUTE_MAX) tableLog = FSE_TABLELOG_ABSOLUTE_MAX; +- size = FSE_CTABLE_SIZE_U32 (tableLog, maxSymbolValue) * sizeof(U32); +- return (FSE_CTable*)ZSTD_malloc(size); +-} +- +-void FSE_freeCTable (FSE_CTable* ct) { ZSTD_free(ct); } +- + /* provides the minimum logSize to safely represent a distribution */ + static unsigned FSE_minTableLog(size_t srcSize, unsigned maxSymbolValue) + { +- U32 minBitsSrc = BIT_highbit32((U32)(srcSize)) + 1; +- U32 minBitsSymbols = BIT_highbit32(maxSymbolValue) + 2; ++ U32 minBitsSrc = 
ZSTD_highbit32((U32)(srcSize)) + 1; ++ U32 minBitsSymbols = ZSTD_highbit32(maxSymbolValue) + 2; + U32 minBits = minBitsSrc < minBitsSymbols ? minBitsSrc : minBitsSymbols; + assert(srcSize > 1); /* Not supported, RLE should be used instead */ + return minBits; +@@ -364,7 +357,7 @@ static unsigned FSE_minTableLog(size_t srcSize, unsigned maxSymbolValue) + + unsigned FSE_optimalTableLog_internal(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, unsigned minus) + { +- U32 maxBitsSrc = BIT_highbit32((U32)(srcSize - 1)) - minus; ++ U32 maxBitsSrc = ZSTD_highbit32((U32)(srcSize - 1)) - minus; + U32 tableLog = maxTableLog; + U32 minBits = FSE_minTableLog(srcSize, maxSymbolValue); + assert(srcSize > 1); /* Not supported, RLE should be used instead */ +@@ -532,40 +525,6 @@ size_t FSE_normalizeCount (short* normalizedCounter, unsigned tableLog, + return tableLog; + } + +- +-/* fake FSE_CTable, for raw (uncompressed) input */ +-size_t FSE_buildCTable_raw (FSE_CTable* ct, unsigned nbBits) +-{ +- const unsigned tableSize = 1 << nbBits; +- const unsigned tableMask = tableSize - 1; +- const unsigned maxSymbolValue = tableMask; +- void* const ptr = ct; +- U16* const tableU16 = ( (U16*) ptr) + 2; +- void* const FSCT = ((U32*)ptr) + 1 /* header */ + (tableSize>>1); /* assumption : tableLog >= 1 */ +- FSE_symbolCompressionTransform* const symbolTT = (FSE_symbolCompressionTransform*) (FSCT); +- unsigned s; +- +- /* Sanity checks */ +- if (nbBits < 1) return ERROR(GENERIC); /* min size */ +- +- /* header */ +- tableU16[-2] = (U16) nbBits; +- tableU16[-1] = (U16) maxSymbolValue; +- +- /* Build table */ +- for (s=0; s= 2 ++ ++static size_t showU32(const U32* arr, size_t size) + { +- return FSE_optimalTableLog_internal(maxTableLog, srcSize, maxSymbolValue, 1); ++ size_t u; ++ for (u=0; u= sizeof(HUF_WriteCTableWksp)); ++ ++ assert(HUF_readCTableHeader(CTable).maxSymbolValue == maxSymbolValue); ++ assert(HUF_readCTableHeader(CTable).tableLog == huffLog); ++ + /* check 
conditions */ + if (workspaceSize < sizeof(HUF_WriteCTableWksp)) return ERROR(GENERIC); + if (maxSymbolValue > HUF_SYMBOLVALUE_MAX) return ERROR(maxSymbolValue_tooLarge); +@@ -204,16 +286,6 @@ size_t HUF_writeCTable_wksp(void* dst, size_t maxDstSize, + return ((maxSymbolValue+1)/2) + 1; + } + +-/*! HUF_writeCTable() : +- `CTable` : Huffman tree to save, using huf representation. +- @return : size of saved CTable */ +-size_t HUF_writeCTable (void* dst, size_t maxDstSize, +- const HUF_CElt* CTable, unsigned maxSymbolValue, unsigned huffLog) +-{ +- HUF_WriteCTableWksp wksp; +- return HUF_writeCTable_wksp(dst, maxDstSize, CTable, maxSymbolValue, huffLog, &wksp, sizeof(wksp)); +-} +- + + size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize, unsigned* hasZeroWeights) + { +@@ -231,7 +303,9 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void + if (tableLog > HUF_TABLELOG_MAX) return ERROR(tableLog_tooLarge); + if (nbSymbols > *maxSymbolValuePtr+1) return ERROR(maxSymbolValue_tooSmall); + +- CTable[0] = tableLog; ++ *maxSymbolValuePtr = nbSymbols - 1; ++ ++ HUF_writeCTableHeader(CTable, tableLog, *maxSymbolValuePtr); + + /* Prepare base value per rank */ + { U32 n, nextRankStart = 0; +@@ -263,74 +337,71 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void + { U32 n; for (n=0; n HUF_readCTableHeader(CTable).maxSymbolValue) ++ return 0; + return (U32)HUF_getNbBits(ct[symbolValue]); + } + + +-typedef struct nodeElt_s { +- U32 count; +- U16 parent; +- BYTE byte; +- BYTE nbBits; +-} nodeElt; +- + /* + * HUF_setMaxHeight(): +- * Enforces maxNbBits on the Huffman tree described in huffNode. ++ * Try to enforce @targetNbBits on the Huffman tree described in @huffNode. + * +- * It sets all nodes with nbBits > maxNbBits to be maxNbBits. Then it adjusts +- * the tree to so that it is a valid canonical Huffman tree. 
++ * It attempts to convert all nodes with nbBits > @targetNbBits ++ * to employ @targetNbBits instead. Then it adjusts the tree ++ * so that it remains a valid canonical Huffman tree. + * + * @pre The sum of the ranks of each symbol == 2^largestBits, + * where largestBits == huffNode[lastNonNull].nbBits. + * @post The sum of the ranks of each symbol == 2^largestBits, +- * where largestBits is the return value <= maxNbBits. ++ * where largestBits is the return value (expected <= targetNbBits). + * +- * @param huffNode The Huffman tree modified in place to enforce maxNbBits. ++ * @param huffNode The Huffman tree modified in place to enforce targetNbBits. ++ * It's presumed sorted, from most frequent to rarest symbol. + * @param lastNonNull The symbol with the lowest count in the Huffman tree. +- * @param maxNbBits The maximum allowed number of bits, which the Huffman tree ++ * @param targetNbBits The allowed number of bits, which the Huffman tree + * may not respect. After this function the Huffman tree will +- * respect maxNbBits. +- * @return The maximum number of bits of the Huffman tree after adjustment, +- * necessarily no more than maxNbBits. ++ * respect targetNbBits. ++ * @return The maximum number of bits of the Huffman tree after adjustment. + */ +-static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) ++static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 targetNbBits) + { + const U32 largestBits = huffNode[lastNonNull].nbBits; +- /* early exit : no elt > maxNbBits, so the tree is already valid. */ +- if (largestBits <= maxNbBits) return largestBits; ++ /* early exit : no elt > targetNbBits, so the tree is already valid. 
*/ ++ if (largestBits <= targetNbBits) return largestBits; ++ ++ DEBUGLOG(5, "HUF_setMaxHeight (targetNbBits = %u)", targetNbBits); + + /* there are several too large elements (at least >= 2) */ + { int totalCost = 0; +- const U32 baseCost = 1 << (largestBits - maxNbBits); ++ const U32 baseCost = 1 << (largestBits - targetNbBits); + int n = (int)lastNonNull; + +- /* Adjust any ranks > maxNbBits to maxNbBits. ++ /* Adjust any ranks > targetNbBits to targetNbBits. + * Compute totalCost, which is how far the sum of the ranks is + * we are over 2^largestBits after adjust the offending ranks. + */ +- while (huffNode[n].nbBits > maxNbBits) { ++ while (huffNode[n].nbBits > targetNbBits) { + totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits)); +- huffNode[n].nbBits = (BYTE)maxNbBits; ++ huffNode[n].nbBits = (BYTE)targetNbBits; + n--; + } +- /* n stops at huffNode[n].nbBits <= maxNbBits */ +- assert(huffNode[n].nbBits <= maxNbBits); +- /* n end at index of smallest symbol using < maxNbBits */ +- while (huffNode[n].nbBits == maxNbBits) --n; ++ /* n stops at huffNode[n].nbBits <= targetNbBits */ ++ assert(huffNode[n].nbBits <= targetNbBits); ++ /* n end at index of smallest symbol using < targetNbBits */ ++ while (huffNode[n].nbBits == targetNbBits) --n; + +- /* renorm totalCost from 2^largestBits to 2^maxNbBits ++ /* renorm totalCost from 2^largestBits to 2^targetNbBits + * note : totalCost is necessarily a multiple of baseCost */ +- assert((totalCost & (baseCost - 1)) == 0); +- totalCost >>= (largestBits - maxNbBits); ++ assert(((U32)totalCost & (baseCost - 1)) == 0); ++ totalCost >>= (largestBits - targetNbBits); + assert(totalCost > 0); + + /* repay normalized cost */ +@@ -339,19 +410,19 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) + + /* Get pos of last (smallest = lowest cum. 
count) symbol per rank */ + ZSTD_memset(rankLast, 0xF0, sizeof(rankLast)); +- { U32 currentNbBits = maxNbBits; ++ { U32 currentNbBits = targetNbBits; + int pos; + for (pos=n ; pos >= 0; pos--) { + if (huffNode[pos].nbBits >= currentNbBits) continue; +- currentNbBits = huffNode[pos].nbBits; /* < maxNbBits */ +- rankLast[maxNbBits-currentNbBits] = (U32)pos; ++ currentNbBits = huffNode[pos].nbBits; /* < targetNbBits */ ++ rankLast[targetNbBits-currentNbBits] = (U32)pos; + } } + + while (totalCost > 0) { + /* Try to reduce the next power of 2 above totalCost because we + * gain back half the rank. + */ +- U32 nBitsToDecrease = BIT_highbit32((U32)totalCost) + 1; ++ U32 nBitsToDecrease = ZSTD_highbit32((U32)totalCost) + 1; + for ( ; nBitsToDecrease > 1; nBitsToDecrease--) { + U32 const highPos = rankLast[nBitsToDecrease]; + U32 const lowPos = rankLast[nBitsToDecrease-1]; +@@ -391,7 +462,7 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) + rankLast[nBitsToDecrease] = noSymbol; + else { + rankLast[nBitsToDecrease]--; +- if (huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease) ++ if (huffNode[rankLast[nBitsToDecrease]].nbBits != targetNbBits-nBitsToDecrease) + rankLast[nBitsToDecrease] = noSymbol; /* this rank is now empty */ + } + } /* while (totalCost > 0) */ +@@ -403,11 +474,11 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) + * TODO. + */ + while (totalCost < 0) { /* Sometimes, cost correction overshoot */ +- /* special case : no rank 1 symbol (using maxNbBits-1); +- * let's create one from largest rank 0 (using maxNbBits). ++ /* special case : no rank 1 symbol (using targetNbBits-1); ++ * let's create one from largest rank 0 (using targetNbBits). 
+ */ + if (rankLast[1] == noSymbol) { +- while (huffNode[n].nbBits == maxNbBits) n--; ++ while (huffNode[n].nbBits == targetNbBits) n--; + huffNode[n+1].nbBits--; + assert(n >= 0); + rankLast[1] = (U32)(n+1); +@@ -421,7 +492,7 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) + } /* repay normalized cost */ + } /* there are several too large elements (at least >= 2) */ + +- return maxNbBits; ++ return targetNbBits; + } + + typedef struct { +@@ -429,7 +500,7 @@ typedef struct { + U16 curr; + } rankPos; + +-typedef nodeElt huffNodeTable[HUF_CTABLE_WORKSPACE_SIZE_U32]; ++typedef nodeElt huffNodeTable[2 * (HUF_SYMBOLVALUE_MAX + 1)]; + + /* Number of buckets available for HUF_sort() */ + #define RANK_POSITION_TABLE_SIZE 192 +@@ -448,8 +519,8 @@ typedef struct { + * Let buckets 166 to 192 represent all remaining counts up to RANK_POSITION_MAX_COUNT_LOG using log2 bucketing. + */ + #define RANK_POSITION_MAX_COUNT_LOG 32 +-#define RANK_POSITION_LOG_BUCKETS_BEGIN (RANK_POSITION_TABLE_SIZE - 1) - RANK_POSITION_MAX_COUNT_LOG - 1 /* == 158 */ +-#define RANK_POSITION_DISTINCT_COUNT_CUTOFF RANK_POSITION_LOG_BUCKETS_BEGIN + BIT_highbit32(RANK_POSITION_LOG_BUCKETS_BEGIN) /* == 166 */ ++#define RANK_POSITION_LOG_BUCKETS_BEGIN ((RANK_POSITION_TABLE_SIZE - 1) - RANK_POSITION_MAX_COUNT_LOG - 1 /* == 158 */) ++#define RANK_POSITION_DISTINCT_COUNT_CUTOFF (RANK_POSITION_LOG_BUCKETS_BEGIN + ZSTD_highbit32(RANK_POSITION_LOG_BUCKETS_BEGIN) /* == 166 */) + + /* Return the appropriate bucket index for a given count. See definition of + * RANK_POSITION_DISTINCT_COUNT_CUTOFF for explanation of bucketing strategy. +@@ -457,7 +528,7 @@ typedef struct { + static U32 HUF_getIndex(U32 const count) { + return (count < RANK_POSITION_DISTINCT_COUNT_CUTOFF) + ? 
count +- : BIT_highbit32(count) + RANK_POSITION_LOG_BUCKETS_BEGIN; ++ : ZSTD_highbit32(count) + RANK_POSITION_LOG_BUCKETS_BEGIN; + } + + /* Helper swap function for HUF_quickSortPartition() */ +@@ -580,7 +651,7 @@ static void HUF_sort(nodeElt huffNode[], const unsigned count[], U32 const maxSy + + /* Sort each bucket. */ + for (n = RANK_POSITION_DISTINCT_COUNT_CUTOFF; n < RANK_POSITION_TABLE_SIZE - 1; ++n) { +- U32 const bucketSize = rankPosition[n].curr-rankPosition[n].base; ++ int const bucketSize = rankPosition[n].curr - rankPosition[n].base; + U32 const bucketStartIdx = rankPosition[n].base; + if (bucketSize > 1) { + assert(bucketStartIdx < maxSymbolValue1); +@@ -591,6 +662,7 @@ static void HUF_sort(nodeElt huffNode[], const unsigned count[], U32 const maxSy + assert(HUF_isSorted(huffNode, maxSymbolValue1)); + } + ++ + /* HUF_buildCTable_wksp() : + * Same as HUF_buildCTable(), but using externally allocated scratch buffer. + * `workSpace` must be aligned on 4-bytes boundaries, and be at least as large as sizeof(HUF_buildCTable_wksp_tables). 
+@@ -611,6 +683,7 @@ static int HUF_buildTree(nodeElt* huffNode, U32 maxSymbolValue) + int lowS, lowN; + int nodeNb = STARTNODE; + int n, nodeRoot; ++ DEBUGLOG(5, "HUF_buildTree (alphabet size = %u)", maxSymbolValue + 1); + /* init for parents */ + nonNullRank = (int)maxSymbolValue; + while(huffNode[nonNullRank].count == 0) nonNullRank--; +@@ -637,6 +710,8 @@ static int HUF_buildTree(nodeElt* huffNode, U32 maxSymbolValue) + for (n=0; n<=nonNullRank; n++) + huffNode[n].nbBits = huffNode[ huffNode[n].parent ].nbBits + 1; + ++ DEBUGLOG(6, "Initial distribution of bits completed (%zu sorted symbols)", showHNodeBits(huffNode, maxSymbolValue+1)); ++ + return nonNullRank; + } + +@@ -671,31 +746,40 @@ static void HUF_buildCTableFromTree(HUF_CElt* CTable, nodeElt const* huffNode, i + HUF_setNbBits(ct + huffNode[n].byte, huffNode[n].nbBits); /* push nbBits per symbol, symbol order */ + for (n=0; nhuffNodeTbl; + nodeElt* const huffNode = huffNode0+1; + int nonNullRank; + ++ HUF_STATIC_ASSERT(HUF_CTABLE_WORKSPACE_SIZE == sizeof(HUF_buildCTable_wksp_tables)); ++ ++ DEBUGLOG(5, "HUF_buildCTable_wksp (alphabet size = %u)", maxSymbolValue+1); ++ + /* safety checks */ + if (wkspSize < sizeof(HUF_buildCTable_wksp_tables)) +- return ERROR(workSpace_tooSmall); ++ return ERROR(workSpace_tooSmall); + if (maxNbBits == 0) maxNbBits = HUF_TABLELOG_DEFAULT; + if (maxSymbolValue > HUF_SYMBOLVALUE_MAX) +- return ERROR(maxSymbolValue_tooLarge); ++ return ERROR(maxSymbolValue_tooLarge); + ZSTD_memset(huffNode0, 0, sizeof(huffNodeTable)); + + /* sort, decreasing order */ + HUF_sort(huffNode, count, maxSymbolValue, wksp_tables->rankPosition); ++ DEBUGLOG(6, "sorted symbols completed (%zu symbols)", showHNodeSymbols(huffNode, maxSymbolValue+1)); + + /* build tree */ + nonNullRank = HUF_buildTree(huffNode, maxSymbolValue); + +- /* enforce maxTableLog */ ++ /* determine and enforce maxTableLog */ + maxNbBits = HUF_setMaxHeight(huffNode, (U32)nonNullRank, maxNbBits); + if (maxNbBits > 
HUF_TABLELOG_MAX) return ERROR(GENERIC); /* check fit into table */ + +@@ -716,13 +800,20 @@ size_t HUF_estimateCompressedSize(const HUF_CElt* CTable, const unsigned* count, + } + + int HUF_validateCTable(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue) { +- HUF_CElt const* ct = CTable + 1; +- int bad = 0; +- int s; +- for (s = 0; s <= (int)maxSymbolValue; ++s) { +- bad |= (count[s] != 0) & (HUF_getNbBits(ct[s]) == 0); +- } +- return !bad; ++ HUF_CTableHeader header = HUF_readCTableHeader(CTable); ++ HUF_CElt const* ct = CTable + 1; ++ int bad = 0; ++ int s; ++ ++ assert(header.tableLog <= HUF_TABLELOG_ABSOLUTEMAX); ++ ++ if (header.maxSymbolValue < maxSymbolValue) ++ return 0; ++ ++ for (s = 0; s <= (int)maxSymbolValue; ++s) { ++ bad |= (count[s] != 0) & (HUF_getNbBits(ct[s]) == 0); ++ } ++ return !bad; + } + + size_t HUF_compressBound(size_t size) { return HUF_COMPRESSBOUND(size); } +@@ -804,7 +895,7 @@ FORCE_INLINE_TEMPLATE void HUF_addBits(HUF_CStream_t* bitC, HUF_CElt elt, int id + #if DEBUGLEVEL >= 1 + { + size_t const nbBits = HUF_getNbBits(elt); +- size_t const dirtyBits = nbBits == 0 ? 0 : BIT_highbit32((U32)nbBits) + 1; ++ size_t const dirtyBits = nbBits == 0 ? 0 : ZSTD_highbit32((U32)nbBits) + 1; + (void)dirtyBits; + /* Middle bits are 0. 
*/ + assert(((elt >> dirtyBits) << (dirtyBits + nbBits)) == 0); +@@ -884,7 +975,7 @@ static size_t HUF_closeCStream(HUF_CStream_t* bitC) + { + size_t const nbBits = bitC->bitPos[0] & 0xFF; + if (bitC->ptr >= bitC->endPtr) return 0; /* overflow detected */ +- return (bitC->ptr - bitC->startPtr) + (nbBits > 0); ++ return (size_t)(bitC->ptr - bitC->startPtr) + (nbBits > 0); + } + } + +@@ -964,17 +1055,17 @@ HUF_compress1X_usingCTable_internal_body(void* dst, size_t dstSize, + const void* src, size_t srcSize, + const HUF_CElt* CTable) + { +- U32 const tableLog = (U32)CTable[0]; ++ U32 const tableLog = HUF_readCTableHeader(CTable).tableLog; + HUF_CElt const* ct = CTable + 1; + const BYTE* ip = (const BYTE*) src; + BYTE* const ostart = (BYTE*)dst; + BYTE* const oend = ostart + dstSize; +- BYTE* op = ostart; + HUF_CStream_t bitC; + + /* init */ + if (dstSize < 8) return 0; /* not enough space to compress */ +- { size_t const initErr = HUF_initCStream(&bitC, op, (size_t)(oend-op)); ++ { BYTE* op = ostart; ++ size_t const initErr = HUF_initCStream(&bitC, op, (size_t)(oend-op)); + if (HUF_isError(initErr)) return 0; } + + if (dstSize < HUF_tightCompressBound(srcSize, (size_t)tableLog) || tableLog > 11) +@@ -1045,9 +1136,9 @@ HUF_compress1X_usingCTable_internal_default(void* dst, size_t dstSize, + static size_t + HUF_compress1X_usingCTable_internal(void* dst, size_t dstSize, + const void* src, size_t srcSize, +- const HUF_CElt* CTable, const int bmi2) ++ const HUF_CElt* CTable, const int flags) + { +- if (bmi2) { ++ if (flags & HUF_flags_bmi2) { + return HUF_compress1X_usingCTable_internal_bmi2(dst, dstSize, src, srcSize, CTable); + } + return HUF_compress1X_usingCTable_internal_default(dst, dstSize, src, srcSize, CTable); +@@ -1058,28 +1149,23 @@ HUF_compress1X_usingCTable_internal(void* dst, size_t dstSize, + static size_t + HUF_compress1X_usingCTable_internal(void* dst, size_t dstSize, + const void* src, size_t srcSize, +- const HUF_CElt* CTable, const int bmi2) ++ const 
HUF_CElt* CTable, const int flags) + { +- (void)bmi2; ++ (void)flags; + return HUF_compress1X_usingCTable_internal_body(dst, dstSize, src, srcSize, CTable); + } + + #endif + +-size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable) ++size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags) + { +- return HUF_compress1X_usingCTable_bmi2(dst, dstSize, src, srcSize, CTable, /* bmi2 */ 0); +-} +- +-size_t HUF_compress1X_usingCTable_bmi2(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int bmi2) +-{ +- return HUF_compress1X_usingCTable_internal(dst, dstSize, src, srcSize, CTable, bmi2); ++ return HUF_compress1X_usingCTable_internal(dst, dstSize, src, srcSize, CTable, flags); + } + + static size_t + HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, + const void* src, size_t srcSize, +- const HUF_CElt* CTable, int bmi2) ++ const HUF_CElt* CTable, int flags) + { + size_t const segmentSize = (srcSize+3)/4; /* first 3 segments */ + const BYTE* ip = (const BYTE*) src; +@@ -1093,7 +1179,7 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, + op += 6; /* jumpTable */ + + assert(op <= oend); +- { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, bmi2) ); ++ { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, flags) ); + if (cSize == 0 || cSize > 65535) return 0; + MEM_writeLE16(ostart, (U16)cSize); + op += cSize; +@@ -1101,7 +1187,7 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, + + ip += segmentSize; + assert(op <= oend); +- { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, bmi2) ); ++ { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, flags) ); + if (cSize == 0 || 
cSize > 65535) return 0; + MEM_writeLE16(ostart+2, (U16)cSize); + op += cSize; +@@ -1109,7 +1195,7 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, + + ip += segmentSize; + assert(op <= oend); +- { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, bmi2) ); ++ { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, flags) ); + if (cSize == 0 || cSize > 65535) return 0; + MEM_writeLE16(ostart+4, (U16)cSize); + op += cSize; +@@ -1118,7 +1204,7 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, + ip += segmentSize; + assert(op <= oend); + assert(ip <= iend); +- { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, (size_t)(iend-ip), CTable, bmi2) ); ++ { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, (size_t)(iend-ip), CTable, flags) ); + if (cSize == 0 || cSize > 65535) return 0; + op += cSize; + } +@@ -1126,14 +1212,9 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, + return (size_t)(op-ostart); + } + +-size_t HUF_compress4X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable) +-{ +- return HUF_compress4X_usingCTable_bmi2(dst, dstSize, src, srcSize, CTable, /* bmi2 */ 0); +-} +- +-size_t HUF_compress4X_usingCTable_bmi2(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int bmi2) ++size_t HUF_compress4X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags) + { +- return HUF_compress4X_usingCTable_internal(dst, dstSize, src, srcSize, CTable, bmi2); ++ return HUF_compress4X_usingCTable_internal(dst, dstSize, src, srcSize, CTable, flags); + } + + typedef enum { HUF_singleStream, HUF_fourStreams } HUF_nbStreams_e; +@@ -1141,11 +1222,11 @@ typedef enum { HUF_singleStream, HUF_fourStreams } HUF_nbStreams_e; + static size_t 
HUF_compressCTable_internal( + BYTE* const ostart, BYTE* op, BYTE* const oend, + const void* src, size_t srcSize, +- HUF_nbStreams_e nbStreams, const HUF_CElt* CTable, const int bmi2) ++ HUF_nbStreams_e nbStreams, const HUF_CElt* CTable, const int flags) + { + size_t const cSize = (nbStreams==HUF_singleStream) ? +- HUF_compress1X_usingCTable_internal(op, (size_t)(oend - op), src, srcSize, CTable, bmi2) : +- HUF_compress4X_usingCTable_internal(op, (size_t)(oend - op), src, srcSize, CTable, bmi2); ++ HUF_compress1X_usingCTable_internal(op, (size_t)(oend - op), src, srcSize, CTable, flags) : ++ HUF_compress4X_usingCTable_internal(op, (size_t)(oend - op), src, srcSize, CTable, flags); + if (HUF_isError(cSize)) { return cSize; } + if (cSize==0) { return 0; } /* uncompressible */ + op += cSize; +@@ -1168,6 +1249,81 @@ typedef struct { + #define SUSPECT_INCOMPRESSIBLE_SAMPLE_SIZE 4096 + #define SUSPECT_INCOMPRESSIBLE_SAMPLE_RATIO 10 /* Must be >= 2 */ + ++unsigned HUF_cardinality(const unsigned* count, unsigned maxSymbolValue) ++{ ++ unsigned cardinality = 0; ++ unsigned i; ++ ++ for (i = 0; i < maxSymbolValue + 1; i++) { ++ if (count[i] != 0) cardinality += 1; ++ } ++ ++ return cardinality; ++} ++ ++unsigned HUF_minTableLog(unsigned symbolCardinality) ++{ ++ U32 minBitsSymbols = ZSTD_highbit32(symbolCardinality) + 1; ++ return minBitsSymbols; ++} ++ ++unsigned HUF_optimalTableLog( ++ unsigned maxTableLog, ++ size_t srcSize, ++ unsigned maxSymbolValue, ++ void* workSpace, size_t wkspSize, ++ HUF_CElt* table, ++ const unsigned* count, ++ int flags) ++{ ++ assert(srcSize > 1); /* Not supported, RLE should be used instead */ ++ assert(wkspSize >= sizeof(HUF_buildCTable_wksp_tables)); ++ ++ if (!(flags & HUF_flags_optimalDepth)) { ++ /* cheap evaluation, based on FSE */ ++ return FSE_optimalTableLog_internal(maxTableLog, srcSize, maxSymbolValue, 1); ++ } ++ ++ { BYTE* dst = (BYTE*)workSpace + sizeof(HUF_WriteCTableWksp); ++ size_t dstSize = wkspSize - 
sizeof(HUF_WriteCTableWksp); ++ size_t hSize, newSize; ++ const unsigned symbolCardinality = HUF_cardinality(count, maxSymbolValue); ++ const unsigned minTableLog = HUF_minTableLog(symbolCardinality); ++ size_t optSize = ((size_t) ~0) - 1; ++ unsigned optLog = maxTableLog, optLogGuess; ++ ++ DEBUGLOG(6, "HUF_optimalTableLog: probing huf depth (srcSize=%zu)", srcSize); ++ ++ /* Search until size increases */ ++ for (optLogGuess = minTableLog; optLogGuess <= maxTableLog; optLogGuess++) { ++ DEBUGLOG(7, "checking for huffLog=%u", optLogGuess); ++ ++ { size_t maxBits = HUF_buildCTable_wksp(table, count, maxSymbolValue, optLogGuess, workSpace, wkspSize); ++ if (ERR_isError(maxBits)) continue; ++ ++ if (maxBits < optLogGuess && optLogGuess > minTableLog) break; ++ ++ hSize = HUF_writeCTable_wksp(dst, dstSize, table, maxSymbolValue, (U32)maxBits, workSpace, wkspSize); ++ } ++ ++ if (ERR_isError(hSize)) continue; ++ ++ newSize = HUF_estimateCompressedSize(table, count, maxSymbolValue) + hSize; ++ ++ if (newSize > optSize + 1) { ++ break; ++ } ++ ++ if (newSize < optSize) { ++ optSize = newSize; ++ optLog = optLogGuess; ++ } ++ } ++ assert(optLog <= HUF_TABLELOG_MAX); ++ return optLog; ++ } ++} ++ + /* HUF_compress_internal() : + * `workSpace_align4` must be aligned on 4-bytes boundaries, + * and occupies the same space as a table of HUF_WORKSPACE_SIZE_U64 unsigned */ +@@ -1177,14 +1333,14 @@ HUF_compress_internal (void* dst, size_t dstSize, + unsigned maxSymbolValue, unsigned huffLog, + HUF_nbStreams_e nbStreams, + void* workSpace, size_t wkspSize, +- HUF_CElt* oldHufTable, HUF_repeat* repeat, int preferRepeat, +- const int bmi2, unsigned suspectUncompressible) ++ HUF_CElt* oldHufTable, HUF_repeat* repeat, int flags) + { + HUF_compress_tables_t* const table = (HUF_compress_tables_t*)HUF_alignUpWorkspace(workSpace, &wkspSize, ZSTD_ALIGNOF(size_t)); + BYTE* const ostart = (BYTE*)dst; + BYTE* const oend = ostart + dstSize; + BYTE* op = ostart; + ++ DEBUGLOG(5, 
"HUF_compress_internal (srcSize=%zu)", srcSize); + HUF_STATIC_ASSERT(sizeof(*table) + HUF_WORKSPACE_MAX_ALIGNMENT <= HUF_WORKSPACE_SIZE); + + /* checks & inits */ +@@ -1198,16 +1354,17 @@ HUF_compress_internal (void* dst, size_t dstSize, + if (!huffLog) huffLog = HUF_TABLELOG_DEFAULT; + + /* Heuristic : If old table is valid, use it for small inputs */ +- if (preferRepeat && repeat && *repeat == HUF_repeat_valid) { ++ if ((flags & HUF_flags_preferRepeat) && repeat && *repeat == HUF_repeat_valid) { + return HUF_compressCTable_internal(ostart, op, oend, + src, srcSize, +- nbStreams, oldHufTable, bmi2); ++ nbStreams, oldHufTable, flags); + } + + /* If uncompressible data is suspected, do a smaller sampling first */ + DEBUG_STATIC_ASSERT(SUSPECT_INCOMPRESSIBLE_SAMPLE_RATIO >= 2); +- if (suspectUncompressible && srcSize >= (SUSPECT_INCOMPRESSIBLE_SAMPLE_SIZE * SUSPECT_INCOMPRESSIBLE_SAMPLE_RATIO)) { ++ if ((flags & HUF_flags_suspectUncompressible) && srcSize >= (SUSPECT_INCOMPRESSIBLE_SAMPLE_SIZE * SUSPECT_INCOMPRESSIBLE_SAMPLE_RATIO)) { + size_t largestTotal = 0; ++ DEBUGLOG(5, "input suspected incompressible : sampling to check"); + { unsigned maxSymbolValueBegin = maxSymbolValue; + CHECK_V_F(largestBegin, HIST_count_simple (table->count, &maxSymbolValueBegin, (const BYTE*)src, SUSPECT_INCOMPRESSIBLE_SAMPLE_SIZE) ); + largestTotal += largestBegin; +@@ -1224,6 +1381,7 @@ HUF_compress_internal (void* dst, size_t dstSize, + if (largest == srcSize) { *ostart = ((const BYTE*)src)[0]; return 1; } /* single symbol, rle */ + if (largest <= (srcSize >> 7)+4) return 0; /* heuristic : probably not compressible enough */ + } ++ DEBUGLOG(6, "histogram detail completed (%zu symbols)", showU32(table->count, maxSymbolValue+1)); + + /* Check validity of previous table */ + if ( repeat +@@ -1232,25 +1390,20 @@ HUF_compress_internal (void* dst, size_t dstSize, + *repeat = HUF_repeat_none; + } + /* Heuristic : use existing table for small inputs */ +- if (preferRepeat && repeat && 
*repeat != HUF_repeat_none) { ++ if ((flags & HUF_flags_preferRepeat) && repeat && *repeat != HUF_repeat_none) { + return HUF_compressCTable_internal(ostart, op, oend, + src, srcSize, +- nbStreams, oldHufTable, bmi2); ++ nbStreams, oldHufTable, flags); + } + + /* Build Huffman Tree */ +- huffLog = HUF_optimalTableLog(huffLog, srcSize, maxSymbolValue); ++ huffLog = HUF_optimalTableLog(huffLog, srcSize, maxSymbolValue, &table->wksps, sizeof(table->wksps), table->CTable, table->count, flags); + { size_t const maxBits = HUF_buildCTable_wksp(table->CTable, table->count, + maxSymbolValue, huffLog, + &table->wksps.buildCTable_wksp, sizeof(table->wksps.buildCTable_wksp)); + CHECK_F(maxBits); + huffLog = (U32)maxBits; +- } +- /* Zero unused symbols in CTable, so we can check it for validity */ +- { +- size_t const ctableSize = HUF_CTABLE_SIZE_ST(maxSymbolValue); +- size_t const unusedSize = sizeof(table->CTable) - ctableSize * sizeof(HUF_CElt); +- ZSTD_memset(table->CTable + ctableSize, 0, unusedSize); ++ DEBUGLOG(6, "bit distribution completed (%zu symbols)", showCTableBits(table->CTable + 1, maxSymbolValue+1)); + } + + /* Write table description header */ +@@ -1263,7 +1416,7 @@ HUF_compress_internal (void* dst, size_t dstSize, + if (oldSize <= hSize + newSize || hSize + 12 >= srcSize) { + return HUF_compressCTable_internal(ostart, op, oend, + src, srcSize, +- nbStreams, oldHufTable, bmi2); ++ nbStreams, oldHufTable, flags); + } } + + /* Use the new huffman table */ +@@ -1275,61 +1428,35 @@ HUF_compress_internal (void* dst, size_t dstSize, + } + return HUF_compressCTable_internal(ostart, op, oend, + src, srcSize, +- nbStreams, table->CTable, bmi2); +-} +- +- +-size_t HUF_compress1X_wksp (void* dst, size_t dstSize, +- const void* src, size_t srcSize, +- unsigned maxSymbolValue, unsigned huffLog, +- void* workSpace, size_t wkspSize) +-{ +- return HUF_compress_internal(dst, dstSize, src, srcSize, +- maxSymbolValue, huffLog, HUF_singleStream, +- workSpace, wkspSize, +- NULL, 
NULL, 0, 0 /*bmi2*/, 0); ++ nbStreams, table->CTable, flags); + } + + size_t HUF_compress1X_repeat (void* dst, size_t dstSize, + const void* src, size_t srcSize, + unsigned maxSymbolValue, unsigned huffLog, + void* workSpace, size_t wkspSize, +- HUF_CElt* hufTable, HUF_repeat* repeat, int preferRepeat, +- int bmi2, unsigned suspectUncompressible) ++ HUF_CElt* hufTable, HUF_repeat* repeat, int flags) + { ++ DEBUGLOG(5, "HUF_compress1X_repeat (srcSize = %zu)", srcSize); + return HUF_compress_internal(dst, dstSize, src, srcSize, + maxSymbolValue, huffLog, HUF_singleStream, + workSpace, wkspSize, hufTable, +- repeat, preferRepeat, bmi2, suspectUncompressible); +-} +- +-/* HUF_compress4X_repeat(): +- * compress input using 4 streams. +- * provide workspace to generate compression tables */ +-size_t HUF_compress4X_wksp (void* dst, size_t dstSize, +- const void* src, size_t srcSize, +- unsigned maxSymbolValue, unsigned huffLog, +- void* workSpace, size_t wkspSize) +-{ +- return HUF_compress_internal(dst, dstSize, src, srcSize, +- maxSymbolValue, huffLog, HUF_fourStreams, +- workSpace, wkspSize, +- NULL, NULL, 0, 0 /*bmi2*/, 0); ++ repeat, flags); + } + + /* HUF_compress4X_repeat(): + * compress input using 4 streams. 
+ * consider skipping quickly +- * re-use an existing huffman compression table */ ++ * reuse an existing huffman compression table */ + size_t HUF_compress4X_repeat (void* dst, size_t dstSize, + const void* src, size_t srcSize, + unsigned maxSymbolValue, unsigned huffLog, + void* workSpace, size_t wkspSize, +- HUF_CElt* hufTable, HUF_repeat* repeat, int preferRepeat, int bmi2, unsigned suspectUncompressible) ++ HUF_CElt* hufTable, HUF_repeat* repeat, int flags) + { ++ DEBUGLOG(5, "HUF_compress4X_repeat (srcSize = %zu)", srcSize); + return HUF_compress_internal(dst, dstSize, src, srcSize, + maxSymbolValue, huffLog, HUF_fourStreams, + workSpace, wkspSize, +- hufTable, repeat, preferRepeat, bmi2, suspectUncompressible); ++ hufTable, repeat, flags); + } +- +diff --git a/lib/zstd/compress/zstd_compress.c b/lib/zstd/compress/zstd_compress.c +index f620cafca633..0d139727cd39 100644 +--- a/lib/zstd/compress/zstd_compress.c ++++ b/lib/zstd/compress/zstd_compress.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -11,12 +12,12 @@ + /*-************************************* + * Dependencies + ***************************************/ ++#include "../common/allocations.h" /* ZSTD_customMalloc, ZSTD_customCalloc, ZSTD_customFree */ + #include "../common/zstd_deps.h" /* INT_MAX, ZSTD_memset, ZSTD_memcpy */ + #include "../common/mem.h" + #include "hist.h" /* HIST_countFast_wksp */ + #define FSE_STATIC_LINKING_ONLY /* FSE_encodeSymbol */ + #include "../common/fse.h" +-#define HUF_STATIC_LINKING_ONLY + #include "../common/huf.h" + #include "zstd_compress_internal.h" + #include "zstd_compress_sequences.h" +@@ -27,6 +28,7 @@ + #include "zstd_opt.h" + #include "zstd_ldm.h" + #include "zstd_compress_superblock.h" ++#include "../common/bits.h" /* ZSTD_highbit32, ZSTD_rotateRight_U64 */ + + /* *************************************************************** + * Tuning parameters +@@ -55,14 +57,17 @@ + * Helper functions + ***************************************/ + /* ZSTD_compressBound() +- * Note that the result from this function is only compatible with the "normal" +- * full-block strategy. +- * When there are a lot of small blocks due to frequent flush in streaming mode +- * the overhead of headers can make the compressed data to be larger than the +- * return value of ZSTD_compressBound(). ++ * Note that the result from this function is only valid for ++ * the one-pass compression functions. ++ * When employing the streaming mode, ++ * if flushes are frequently altering the size of blocks, ++ * the overhead from block headers can make the compressed data larger ++ * than the return value of ZSTD_compressBound(). 
+ */ + size_t ZSTD_compressBound(size_t srcSize) { +- return ZSTD_COMPRESSBOUND(srcSize); ++ size_t const r = ZSTD_COMPRESSBOUND(srcSize); ++ if (r==0) return ERROR(srcSize_wrong); ++ return r; + } + + +@@ -168,15 +173,13 @@ static void ZSTD_freeCCtxContent(ZSTD_CCtx* cctx) + + size_t ZSTD_freeCCtx(ZSTD_CCtx* cctx) + { ++ DEBUGLOG(3, "ZSTD_freeCCtx (address: %p)", (void*)cctx); + if (cctx==NULL) return 0; /* support free on NULL */ + RETURN_ERROR_IF(cctx->staticSize, memory_allocation, + "not compatible with static CCtx"); +- { +- int cctxInWorkspace = ZSTD_cwksp_owns_buffer(&cctx->workspace, cctx); ++ { int cctxInWorkspace = ZSTD_cwksp_owns_buffer(&cctx->workspace, cctx); + ZSTD_freeCCtxContent(cctx); +- if (!cctxInWorkspace) { +- ZSTD_customFree(cctx, cctx->customMem); +- } ++ if (!cctxInWorkspace) ZSTD_customFree(cctx, cctx->customMem); + } + return 0; + } +@@ -257,9 +260,9 @@ static int ZSTD_allocateChainTable(const ZSTD_strategy strategy, + return forDDSDict || ((strategy != ZSTD_fast) && !ZSTD_rowMatchFinderUsed(strategy, useRowMatchFinder)); + } + +-/* Returns 1 if compression parameters are such that we should ++/* Returns ZSTD_ps_enable if compression parameters are such that we should + * enable long distance matching (wlog >= 27, strategy >= btopt). +- * Returns 0 otherwise. ++ * Returns ZSTD_ps_disable otherwise. + */ + static ZSTD_paramSwitch_e ZSTD_resolveEnableLdm(ZSTD_paramSwitch_e mode, + const ZSTD_compressionParameters* const cParams) { +@@ -267,6 +270,34 @@ static ZSTD_paramSwitch_e ZSTD_resolveEnableLdm(ZSTD_paramSwitch_e mode, + return (cParams->strategy >= ZSTD_btopt && cParams->windowLog >= 27) ? ZSTD_ps_enable : ZSTD_ps_disable; + } + ++static int ZSTD_resolveExternalSequenceValidation(int mode) { ++ return mode; ++} ++ ++/* Resolves maxBlockSize to the default if no value is present. 
*/ ++static size_t ZSTD_resolveMaxBlockSize(size_t maxBlockSize) { ++ if (maxBlockSize == 0) { ++ return ZSTD_BLOCKSIZE_MAX; ++ } else { ++ return maxBlockSize; ++ } ++} ++ ++static ZSTD_paramSwitch_e ZSTD_resolveExternalRepcodeSearch(ZSTD_paramSwitch_e value, int cLevel) { ++ if (value != ZSTD_ps_auto) return value; ++ if (cLevel < 10) { ++ return ZSTD_ps_disable; ++ } else { ++ return ZSTD_ps_enable; ++ } ++} ++ ++/* Returns 1 if compression parameters are such that CDict hashtable and chaintable indices are tagged. ++ * If so, the tags need to be removed in ZSTD_resetCCtx_byCopyingCDict. */ ++static int ZSTD_CDictIndicesAreTagged(const ZSTD_compressionParameters* const cParams) { ++ return cParams->strategy == ZSTD_fast || cParams->strategy == ZSTD_dfast; ++} ++ + static ZSTD_CCtx_params ZSTD_makeCCtxParamsFromCParams( + ZSTD_compressionParameters cParams) + { +@@ -284,6 +315,10 @@ static ZSTD_CCtx_params ZSTD_makeCCtxParamsFromCParams( + } + cctxParams.useBlockSplitter = ZSTD_resolveBlockSplitterMode(cctxParams.useBlockSplitter, &cParams); + cctxParams.useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(cctxParams.useRowMatchFinder, &cParams); ++ cctxParams.validateSequences = ZSTD_resolveExternalSequenceValidation(cctxParams.validateSequences); ++ cctxParams.maxBlockSize = ZSTD_resolveMaxBlockSize(cctxParams.maxBlockSize); ++ cctxParams.searchForExternalRepcodes = ZSTD_resolveExternalRepcodeSearch(cctxParams.searchForExternalRepcodes, ++ cctxParams.compressionLevel); + assert(!ZSTD_checkCParams(cParams)); + return cctxParams; + } +@@ -329,10 +364,13 @@ size_t ZSTD_CCtxParams_init(ZSTD_CCtx_params* cctxParams, int compressionLevel) + #define ZSTD_NO_CLEVEL 0 + + /* +- * Initializes the cctxParams from params and compressionLevel. ++ * Initializes `cctxParams` from `params` and `compressionLevel`. + * @param compressionLevel If params are derived from a compression level then that compression level, otherwise ZSTD_NO_CLEVEL. 
+ */ +-static void ZSTD_CCtxParams_init_internal(ZSTD_CCtx_params* cctxParams, ZSTD_parameters const* params, int compressionLevel) ++static void ++ZSTD_CCtxParams_init_internal(ZSTD_CCtx_params* cctxParams, ++ const ZSTD_parameters* params, ++ int compressionLevel) + { + assert(!ZSTD_checkCParams(params->cParams)); + ZSTD_memset(cctxParams, 0, sizeof(*cctxParams)); +@@ -345,6 +383,9 @@ static void ZSTD_CCtxParams_init_internal(ZSTD_CCtx_params* cctxParams, ZSTD_par + cctxParams->useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(cctxParams->useRowMatchFinder, ¶ms->cParams); + cctxParams->useBlockSplitter = ZSTD_resolveBlockSplitterMode(cctxParams->useBlockSplitter, ¶ms->cParams); + cctxParams->ldmParams.enableLdm = ZSTD_resolveEnableLdm(cctxParams->ldmParams.enableLdm, ¶ms->cParams); ++ cctxParams->validateSequences = ZSTD_resolveExternalSequenceValidation(cctxParams->validateSequences); ++ cctxParams->maxBlockSize = ZSTD_resolveMaxBlockSize(cctxParams->maxBlockSize); ++ cctxParams->searchForExternalRepcodes = ZSTD_resolveExternalRepcodeSearch(cctxParams->searchForExternalRepcodes, compressionLevel); + DEBUGLOG(4, "ZSTD_CCtxParams_init_internal: useRowMatchFinder=%d, useBlockSplitter=%d ldm=%d", + cctxParams->useRowMatchFinder, cctxParams->useBlockSplitter, cctxParams->ldmParams.enableLdm); + } +@@ -359,7 +400,7 @@ size_t ZSTD_CCtxParams_init_advanced(ZSTD_CCtx_params* cctxParams, ZSTD_paramete + + /* + * Sets cctxParams' cParams and fParams from params, but otherwise leaves them alone. +- * @param param Validated zstd parameters. ++ * @param params Validated zstd parameters. 
+ */ + static void ZSTD_CCtxParams_setZstdParams( + ZSTD_CCtx_params* cctxParams, const ZSTD_parameters* params) +@@ -455,8 +496,8 @@ ZSTD_bounds ZSTD_cParam_getBounds(ZSTD_cParameter param) + return bounds; + + case ZSTD_c_enableLongDistanceMatching: +- bounds.lowerBound = 0; +- bounds.upperBound = 1; ++ bounds.lowerBound = (int)ZSTD_ps_auto; ++ bounds.upperBound = (int)ZSTD_ps_disable; + return bounds; + + case ZSTD_c_ldmHashLog: +@@ -549,6 +590,26 @@ ZSTD_bounds ZSTD_cParam_getBounds(ZSTD_cParameter param) + bounds.upperBound = 1; + return bounds; + ++ case ZSTD_c_prefetchCDictTables: ++ bounds.lowerBound = (int)ZSTD_ps_auto; ++ bounds.upperBound = (int)ZSTD_ps_disable; ++ return bounds; ++ ++ case ZSTD_c_enableSeqProducerFallback: ++ bounds.lowerBound = 0; ++ bounds.upperBound = 1; ++ return bounds; ++ ++ case ZSTD_c_maxBlockSize: ++ bounds.lowerBound = ZSTD_BLOCKSIZE_MAX_MIN; ++ bounds.upperBound = ZSTD_BLOCKSIZE_MAX; ++ return bounds; ++ ++ case ZSTD_c_searchForExternalRepcodes: ++ bounds.lowerBound = (int)ZSTD_ps_auto; ++ bounds.upperBound = (int)ZSTD_ps_disable; ++ return bounds; ++ + default: + bounds.error = ERROR(parameter_unsupported); + return bounds; +@@ -567,10 +628,11 @@ static size_t ZSTD_cParam_clampBounds(ZSTD_cParameter cParam, int* value) + return 0; + } + +-#define BOUNDCHECK(cParam, val) { \ +- RETURN_ERROR_IF(!ZSTD_cParam_withinBounds(cParam,val), \ +- parameter_outOfBound, "Param out of bounds"); \ +-} ++#define BOUNDCHECK(cParam, val) \ ++ do { \ ++ RETURN_ERROR_IF(!ZSTD_cParam_withinBounds(cParam,val), \ ++ parameter_outOfBound, "Param out of bounds"); \ ++ } while (0) + + + static int ZSTD_isUpdateAuthorized(ZSTD_cParameter param) +@@ -613,6 +675,10 @@ static int ZSTD_isUpdateAuthorized(ZSTD_cParameter param) + case ZSTD_c_useBlockSplitter: + case ZSTD_c_useRowMatchFinder: + case ZSTD_c_deterministicRefPrefix: ++ case ZSTD_c_prefetchCDictTables: ++ case ZSTD_c_enableSeqProducerFallback: ++ case ZSTD_c_maxBlockSize: ++ case 
ZSTD_c_searchForExternalRepcodes: + default: + return 0; + } +@@ -625,7 +691,7 @@ size_t ZSTD_CCtx_setParameter(ZSTD_CCtx* cctx, ZSTD_cParameter param, int value) + if (ZSTD_isUpdateAuthorized(param)) { + cctx->cParamsChanged = 1; + } else { +- RETURN_ERROR(stage_wrong, "can only set params in ctx init stage"); ++ RETURN_ERROR(stage_wrong, "can only set params in cctx init stage"); + } } + + switch(param) +@@ -668,6 +734,10 @@ size_t ZSTD_CCtx_setParameter(ZSTD_CCtx* cctx, ZSTD_cParameter param, int value) + case ZSTD_c_useBlockSplitter: + case ZSTD_c_useRowMatchFinder: + case ZSTD_c_deterministicRefPrefix: ++ case ZSTD_c_prefetchCDictTables: ++ case ZSTD_c_enableSeqProducerFallback: ++ case ZSTD_c_maxBlockSize: ++ case ZSTD_c_searchForExternalRepcodes: + break; + + default: RETURN_ERROR(parameter_unsupported, "unknown parameter"); +@@ -723,12 +793,12 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, + case ZSTD_c_minMatch : + if (value!=0) /* 0 => use default */ + BOUNDCHECK(ZSTD_c_minMatch, value); +- CCtxParams->cParams.minMatch = value; ++ CCtxParams->cParams.minMatch = (U32)value; + return CCtxParams->cParams.minMatch; + + case ZSTD_c_targetLength : + BOUNDCHECK(ZSTD_c_targetLength, value); +- CCtxParams->cParams.targetLength = value; ++ CCtxParams->cParams.targetLength = (U32)value; + return CCtxParams->cParams.targetLength; + + case ZSTD_c_strategy : +@@ -741,12 +811,12 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, + /* Content size written in frame header _when known_ (default:1) */ + DEBUGLOG(4, "set content size flag = %u", (value!=0)); + CCtxParams->fParams.contentSizeFlag = value != 0; +- return CCtxParams->fParams.contentSizeFlag; ++ return (size_t)CCtxParams->fParams.contentSizeFlag; + + case ZSTD_c_checksumFlag : + /* A 32-bits content checksum will be calculated and written at end of frame (default:0) */ + CCtxParams->fParams.checksumFlag = value != 0; +- return CCtxParams->fParams.checksumFlag; ++ return 
(size_t)CCtxParams->fParams.checksumFlag; + + case ZSTD_c_dictIDFlag : /* When applicable, dictionary's dictID is provided in frame header (default:1) */ + DEBUGLOG(4, "set dictIDFlag = %u", (value!=0)); +@@ -755,18 +825,18 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, + + case ZSTD_c_forceMaxWindow : + CCtxParams->forceWindow = (value != 0); +- return CCtxParams->forceWindow; ++ return (size_t)CCtxParams->forceWindow; + + case ZSTD_c_forceAttachDict : { + const ZSTD_dictAttachPref_e pref = (ZSTD_dictAttachPref_e)value; +- BOUNDCHECK(ZSTD_c_forceAttachDict, pref); ++ BOUNDCHECK(ZSTD_c_forceAttachDict, (int)pref); + CCtxParams->attachDictPref = pref; + return CCtxParams->attachDictPref; + } + + case ZSTD_c_literalCompressionMode : { + const ZSTD_paramSwitch_e lcm = (ZSTD_paramSwitch_e)value; +- BOUNDCHECK(ZSTD_c_literalCompressionMode, lcm); ++ BOUNDCHECK(ZSTD_c_literalCompressionMode, (int)lcm); + CCtxParams->literalCompressionMode = lcm; + return CCtxParams->literalCompressionMode; + } +@@ -789,47 +859,50 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, + + case ZSTD_c_enableDedicatedDictSearch : + CCtxParams->enableDedicatedDictSearch = (value!=0); +- return CCtxParams->enableDedicatedDictSearch; ++ return (size_t)CCtxParams->enableDedicatedDictSearch; + + case ZSTD_c_enableLongDistanceMatching : ++ BOUNDCHECK(ZSTD_c_enableLongDistanceMatching, value); + CCtxParams->ldmParams.enableLdm = (ZSTD_paramSwitch_e)value; + return CCtxParams->ldmParams.enableLdm; + + case ZSTD_c_ldmHashLog : + if (value!=0) /* 0 ==> auto */ + BOUNDCHECK(ZSTD_c_ldmHashLog, value); +- CCtxParams->ldmParams.hashLog = value; ++ CCtxParams->ldmParams.hashLog = (U32)value; + return CCtxParams->ldmParams.hashLog; + + case ZSTD_c_ldmMinMatch : + if (value!=0) /* 0 ==> default */ + BOUNDCHECK(ZSTD_c_ldmMinMatch, value); +- CCtxParams->ldmParams.minMatchLength = value; ++ CCtxParams->ldmParams.minMatchLength = (U32)value; + return 
CCtxParams->ldmParams.minMatchLength; + + case ZSTD_c_ldmBucketSizeLog : + if (value!=0) /* 0 ==> default */ + BOUNDCHECK(ZSTD_c_ldmBucketSizeLog, value); +- CCtxParams->ldmParams.bucketSizeLog = value; ++ CCtxParams->ldmParams.bucketSizeLog = (U32)value; + return CCtxParams->ldmParams.bucketSizeLog; + + case ZSTD_c_ldmHashRateLog : + if (value!=0) /* 0 ==> default */ + BOUNDCHECK(ZSTD_c_ldmHashRateLog, value); +- CCtxParams->ldmParams.hashRateLog = value; ++ CCtxParams->ldmParams.hashRateLog = (U32)value; + return CCtxParams->ldmParams.hashRateLog; + + case ZSTD_c_targetCBlockSize : +- if (value!=0) /* 0 ==> default */ ++ if (value!=0) { /* 0 ==> default */ ++ value = MAX(value, ZSTD_TARGETCBLOCKSIZE_MIN); + BOUNDCHECK(ZSTD_c_targetCBlockSize, value); +- CCtxParams->targetCBlockSize = value; ++ } ++ CCtxParams->targetCBlockSize = (U32)value; + return CCtxParams->targetCBlockSize; + + case ZSTD_c_srcSizeHint : + if (value!=0) /* 0 ==> default */ + BOUNDCHECK(ZSTD_c_srcSizeHint, value); + CCtxParams->srcSizeHint = value; +- return CCtxParams->srcSizeHint; ++ return (size_t)CCtxParams->srcSizeHint; + + case ZSTD_c_stableInBuffer: + BOUNDCHECK(ZSTD_c_stableInBuffer, value); +@@ -849,7 +922,7 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, + case ZSTD_c_validateSequences: + BOUNDCHECK(ZSTD_c_validateSequences, value); + CCtxParams->validateSequences = value; +- return CCtxParams->validateSequences; ++ return (size_t)CCtxParams->validateSequences; + + case ZSTD_c_useBlockSplitter: + BOUNDCHECK(ZSTD_c_useBlockSplitter, value); +@@ -864,7 +937,28 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, + case ZSTD_c_deterministicRefPrefix: + BOUNDCHECK(ZSTD_c_deterministicRefPrefix, value); + CCtxParams->deterministicRefPrefix = !!value; +- return CCtxParams->deterministicRefPrefix; ++ return (size_t)CCtxParams->deterministicRefPrefix; ++ ++ case ZSTD_c_prefetchCDictTables: ++ BOUNDCHECK(ZSTD_c_prefetchCDictTables, value); ++ 
CCtxParams->prefetchCDictTables = (ZSTD_paramSwitch_e)value; ++ return CCtxParams->prefetchCDictTables; ++ ++ case ZSTD_c_enableSeqProducerFallback: ++ BOUNDCHECK(ZSTD_c_enableSeqProducerFallback, value); ++ CCtxParams->enableMatchFinderFallback = value; ++ return (size_t)CCtxParams->enableMatchFinderFallback; ++ ++ case ZSTD_c_maxBlockSize: ++ if (value!=0) /* 0 ==> default */ ++ BOUNDCHECK(ZSTD_c_maxBlockSize, value); ++ CCtxParams->maxBlockSize = value; ++ return CCtxParams->maxBlockSize; ++ ++ case ZSTD_c_searchForExternalRepcodes: ++ BOUNDCHECK(ZSTD_c_searchForExternalRepcodes, value); ++ CCtxParams->searchForExternalRepcodes = (ZSTD_paramSwitch_e)value; ++ return CCtxParams->searchForExternalRepcodes; + + default: RETURN_ERROR(parameter_unsupported, "unknown parameter"); + } +@@ -980,6 +1074,18 @@ size_t ZSTD_CCtxParams_getParameter( + case ZSTD_c_deterministicRefPrefix: + *value = (int)CCtxParams->deterministicRefPrefix; + break; ++ case ZSTD_c_prefetchCDictTables: ++ *value = (int)CCtxParams->prefetchCDictTables; ++ break; ++ case ZSTD_c_enableSeqProducerFallback: ++ *value = CCtxParams->enableMatchFinderFallback; ++ break; ++ case ZSTD_c_maxBlockSize: ++ *value = (int)CCtxParams->maxBlockSize; ++ break; ++ case ZSTD_c_searchForExternalRepcodes: ++ *value = (int)CCtxParams->searchForExternalRepcodes; ++ break; + default: RETURN_ERROR(parameter_unsupported, "unknown parameter"); + } + return 0; +@@ -1006,9 +1112,47 @@ size_t ZSTD_CCtx_setParametersUsingCCtxParams( + return 0; + } + ++size_t ZSTD_CCtx_setCParams(ZSTD_CCtx* cctx, ZSTD_compressionParameters cparams) ++{ ++ ZSTD_STATIC_ASSERT(sizeof(cparams) == 7 * 4 /* all params are listed below */); ++ DEBUGLOG(4, "ZSTD_CCtx_setCParams"); ++ /* only update if all parameters are valid */ ++ FORWARD_IF_ERROR(ZSTD_checkCParams(cparams), ""); ++ FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_windowLog, cparams.windowLog), ""); ++ FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_chainLog, 
cparams.chainLog), ""); ++ FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_hashLog, cparams.hashLog), ""); ++ FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_searchLog, cparams.searchLog), ""); ++ FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_minMatch, cparams.minMatch), ""); ++ FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_targetLength, cparams.targetLength), ""); ++ FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_strategy, cparams.strategy), ""); ++ return 0; ++} ++ ++size_t ZSTD_CCtx_setFParams(ZSTD_CCtx* cctx, ZSTD_frameParameters fparams) ++{ ++ ZSTD_STATIC_ASSERT(sizeof(fparams) == 3 * 4 /* all params are listed below */); ++ DEBUGLOG(4, "ZSTD_CCtx_setFParams"); ++ FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_contentSizeFlag, fparams.contentSizeFlag != 0), ""); ++ FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_checksumFlag, fparams.checksumFlag != 0), ""); ++ FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_dictIDFlag, fparams.noDictIDFlag == 0), ""); ++ return 0; ++} ++ ++size_t ZSTD_CCtx_setParams(ZSTD_CCtx* cctx, ZSTD_parameters params) ++{ ++ DEBUGLOG(4, "ZSTD_CCtx_setParams"); ++ /* First check cParams, because we want to update all or none. */ ++ FORWARD_IF_ERROR(ZSTD_checkCParams(params.cParams), ""); ++ /* Next set fParams, because this could fail if the cctx isn't in init stage. */ ++ FORWARD_IF_ERROR(ZSTD_CCtx_setFParams(cctx, params.fParams), ""); ++ /* Finally set cParams, which should succeed. 
*/ ++ FORWARD_IF_ERROR(ZSTD_CCtx_setCParams(cctx, params.cParams), ""); ++ return 0; ++} ++ + size_t ZSTD_CCtx_setPledgedSrcSize(ZSTD_CCtx* cctx, unsigned long long pledgedSrcSize) + { +- DEBUGLOG(4, "ZSTD_CCtx_setPledgedSrcSize to %u bytes", (U32)pledgedSrcSize); ++ DEBUGLOG(4, "ZSTD_CCtx_setPledgedSrcSize to %llu bytes", pledgedSrcSize); + RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong, + "Can't set pledgedSrcSize when not in init stage."); + cctx->pledgedSrcSizePlusOne = pledgedSrcSize+1; +@@ -1024,9 +1168,9 @@ static void ZSTD_dedicatedDictSearch_revertCParams( + ZSTD_compressionParameters* cParams); + + /* +- * Initializes the local dict using the requested parameters. +- * NOTE: This does not use the pledged src size, because it may be used for more +- * than one compression. ++ * Initializes the local dictionary using requested parameters. ++ * NOTE: Initialization does not employ the pledged src size, ++ * because the dictionary may be used for multiple compressions. + */ + static size_t ZSTD_initLocalDict(ZSTD_CCtx* cctx) + { +@@ -1039,8 +1183,8 @@ static size_t ZSTD_initLocalDict(ZSTD_CCtx* cctx) + return 0; + } + if (dl->cdict != NULL) { +- assert(cctx->cdict == dl->cdict); + /* Local dictionary already initialized. 
*/ ++ assert(cctx->cdict == dl->cdict); + return 0; + } + assert(dl->dictSize > 0); +@@ -1060,26 +1204,30 @@ static size_t ZSTD_initLocalDict(ZSTD_CCtx* cctx) + } + + size_t ZSTD_CCtx_loadDictionary_advanced( +- ZSTD_CCtx* cctx, const void* dict, size_t dictSize, +- ZSTD_dictLoadMethod_e dictLoadMethod, ZSTD_dictContentType_e dictContentType) ++ ZSTD_CCtx* cctx, ++ const void* dict, size_t dictSize, ++ ZSTD_dictLoadMethod_e dictLoadMethod, ++ ZSTD_dictContentType_e dictContentType) + { +- RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong, +- "Can't load a dictionary when ctx is not in init stage."); + DEBUGLOG(4, "ZSTD_CCtx_loadDictionary_advanced (size: %u)", (U32)dictSize); +- ZSTD_clearAllDicts(cctx); /* in case one already exists */ +- if (dict == NULL || dictSize == 0) /* no dictionary mode */ ++ RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong, ++ "Can't load a dictionary when cctx is not in init stage."); ++ ZSTD_clearAllDicts(cctx); /* erase any previously set dictionary */ ++ if (dict == NULL || dictSize == 0) /* no dictionary */ + return 0; + if (dictLoadMethod == ZSTD_dlm_byRef) { + cctx->localDict.dict = dict; + } else { ++ /* copy dictionary content inside CCtx to own its lifetime */ + void* dictBuffer; + RETURN_ERROR_IF(cctx->staticSize, memory_allocation, +- "no malloc for static CCtx"); ++ "static CCtx can't allocate for an internal copy of dictionary"); + dictBuffer = ZSTD_customMalloc(dictSize, cctx->customMem); +- RETURN_ERROR_IF(!dictBuffer, memory_allocation, "NULL pointer!"); ++ RETURN_ERROR_IF(dictBuffer==NULL, memory_allocation, ++ "allocation failed for dictionary content"); + ZSTD_memcpy(dictBuffer, dict, dictSize); +- cctx->localDict.dictBuffer = dictBuffer; +- cctx->localDict.dict = dictBuffer; ++ cctx->localDict.dictBuffer = dictBuffer; /* owned ptr to free */ ++ cctx->localDict.dict = dictBuffer; /* read-only reference */ + } + cctx->localDict.dictSize = dictSize; + cctx->localDict.dictContentType = 
dictContentType; +@@ -1149,7 +1297,7 @@ size_t ZSTD_CCtx_reset(ZSTD_CCtx* cctx, ZSTD_ResetDirective reset) + if ( (reset == ZSTD_reset_parameters) + || (reset == ZSTD_reset_session_and_parameters) ) { + RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong, +- "Can't reset parameters only when not in init stage."); ++ "Reset parameters is only possible during init stage."); + ZSTD_clearAllDicts(cctx); + return ZSTD_CCtxParams_reset(&cctx->requestedParams); + } +@@ -1178,11 +1326,12 @@ size_t ZSTD_checkCParams(ZSTD_compressionParameters cParams) + static ZSTD_compressionParameters + ZSTD_clampCParams(ZSTD_compressionParameters cParams) + { +-# define CLAMP_TYPE(cParam, val, type) { \ +- ZSTD_bounds const bounds = ZSTD_cParam_getBounds(cParam); \ +- if ((int)valbounds.upperBound) val=(type)bounds.upperBound; \ +- } ++# define CLAMP_TYPE(cParam, val, type) \ ++ do { \ ++ ZSTD_bounds const bounds = ZSTD_cParam_getBounds(cParam); \ ++ if ((int)valbounds.upperBound) val=(type)bounds.upperBound; \ ++ } while (0) + # define CLAMP(cParam, val) CLAMP_TYPE(cParam, val, unsigned) + CLAMP(ZSTD_c_windowLog, cParams.windowLog); + CLAMP(ZSTD_c_chainLog, cParams.chainLog); +@@ -1247,12 +1396,55 @@ static ZSTD_compressionParameters + ZSTD_adjustCParams_internal(ZSTD_compressionParameters cPar, + unsigned long long srcSize, + size_t dictSize, +- ZSTD_cParamMode_e mode) ++ ZSTD_cParamMode_e mode, ++ ZSTD_paramSwitch_e useRowMatchFinder) + { + const U64 minSrcSize = 513; /* (1<<9) + 1 */ + const U64 maxWindowResize = 1ULL << (ZSTD_WINDOWLOG_MAX-1); + assert(ZSTD_checkCParams(cPar)==0); + ++ /* Cascade the selected strategy down to the next-highest one built into ++ * this binary. 
*/ ++#ifdef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR ++ if (cPar.strategy == ZSTD_btultra2) { ++ cPar.strategy = ZSTD_btultra; ++ } ++ if (cPar.strategy == ZSTD_btultra) { ++ cPar.strategy = ZSTD_btopt; ++ } ++#endif ++#ifdef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR ++ if (cPar.strategy == ZSTD_btopt) { ++ cPar.strategy = ZSTD_btlazy2; ++ } ++#endif ++#ifdef ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR ++ if (cPar.strategy == ZSTD_btlazy2) { ++ cPar.strategy = ZSTD_lazy2; ++ } ++#endif ++#ifdef ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR ++ if (cPar.strategy == ZSTD_lazy2) { ++ cPar.strategy = ZSTD_lazy; ++ } ++#endif ++#ifdef ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR ++ if (cPar.strategy == ZSTD_lazy) { ++ cPar.strategy = ZSTD_greedy; ++ } ++#endif ++#ifdef ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR ++ if (cPar.strategy == ZSTD_greedy) { ++ cPar.strategy = ZSTD_dfast; ++ } ++#endif ++#ifdef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR ++ if (cPar.strategy == ZSTD_dfast) { ++ cPar.strategy = ZSTD_fast; ++ cPar.targetLength = 0; ++ } ++#endif ++ + switch (mode) { + case ZSTD_cpm_unknown: + case ZSTD_cpm_noAttachDict: +@@ -1281,8 +1473,8 @@ ZSTD_adjustCParams_internal(ZSTD_compressionParameters cPar, + } + + /* resize windowLog if input is small enough, to use less memory */ +- if ( (srcSize < maxWindowResize) +- && (dictSize < maxWindowResize) ) { ++ if ( (srcSize <= maxWindowResize) ++ && (dictSize <= maxWindowResize) ) { + U32 const tSize = (U32)(srcSize + dictSize); + static U32 const hashSizeMin = 1 << ZSTD_HASHLOG_MIN; + U32 const srcLog = (tSize < hashSizeMin) ? 
ZSTD_HASHLOG_MIN : +@@ -1300,6 +1492,42 @@ ZSTD_adjustCParams_internal(ZSTD_compressionParameters cPar, + if (cPar.windowLog < ZSTD_WINDOWLOG_ABSOLUTEMIN) + cPar.windowLog = ZSTD_WINDOWLOG_ABSOLUTEMIN; /* minimum wlog required for valid frame header */ + ++ /* We can't use more than 32 bits of hash in total, so that means that we require: ++ * (hashLog + 8) <= 32 && (chainLog + 8) <= 32 ++ */ ++ if (mode == ZSTD_cpm_createCDict && ZSTD_CDictIndicesAreTagged(&cPar)) { ++ U32 const maxShortCacheHashLog = 32 - ZSTD_SHORT_CACHE_TAG_BITS; ++ if (cPar.hashLog > maxShortCacheHashLog) { ++ cPar.hashLog = maxShortCacheHashLog; ++ } ++ if (cPar.chainLog > maxShortCacheHashLog) { ++ cPar.chainLog = maxShortCacheHashLog; ++ } ++ } ++ ++ ++ /* At this point, we aren't 100% sure if we are using the row match finder. ++ * Unless it is explicitly disabled, conservatively assume that it is enabled. ++ * In this case it will only be disabled for small sources, so shrinking the ++ * hash log a little bit shouldn't result in any ratio loss. ++ */ ++ if (useRowMatchFinder == ZSTD_ps_auto) ++ useRowMatchFinder = ZSTD_ps_enable; ++ ++ /* We can't hash more than 32-bits in total. 
So that means that we require: ++ * (hashLog - rowLog + 8) <= 32 ++ */ ++ if (ZSTD_rowMatchFinderUsed(cPar.strategy, useRowMatchFinder)) { ++ /* Switch to 32-entry rows if searchLog is 5 (or more) */ ++ U32 const rowLog = BOUNDED(4, cPar.searchLog, 6); ++ U32 const maxRowHashLog = 32 - ZSTD_ROW_HASH_TAG_BITS; ++ U32 const maxHashLog = maxRowHashLog + rowLog; ++ assert(cPar.hashLog >= rowLog); ++ if (cPar.hashLog > maxHashLog) { ++ cPar.hashLog = maxHashLog; ++ } ++ } ++ + return cPar; + } + +@@ -1310,7 +1538,7 @@ ZSTD_adjustCParams(ZSTD_compressionParameters cPar, + { + cPar = ZSTD_clampCParams(cPar); /* resulting cPar is necessarily valid (all parameters within range) */ + if (srcSize == 0) srcSize = ZSTD_CONTENTSIZE_UNKNOWN; +- return ZSTD_adjustCParams_internal(cPar, srcSize, dictSize, ZSTD_cpm_unknown); ++ return ZSTD_adjustCParams_internal(cPar, srcSize, dictSize, ZSTD_cpm_unknown, ZSTD_ps_auto); + } + + static ZSTD_compressionParameters ZSTD_getCParams_internal(int compressionLevel, unsigned long long srcSizeHint, size_t dictSize, ZSTD_cParamMode_e mode); +@@ -1341,7 +1569,7 @@ ZSTD_compressionParameters ZSTD_getCParamsFromCCtxParams( + ZSTD_overrideCParams(&cParams, &CCtxParams->cParams); + assert(!ZSTD_checkCParams(cParams)); + /* srcSizeHint == 0 means 0 */ +- return ZSTD_adjustCParams_internal(cParams, srcSizeHint, dictSize, mode); ++ return ZSTD_adjustCParams_internal(cParams, srcSizeHint, dictSize, mode, CCtxParams->useRowMatchFinder); + } + + static size_t +@@ -1367,10 +1595,10 @@ ZSTD_sizeof_matchState(const ZSTD_compressionParameters* const cParams, + + ZSTD_cwksp_aligned_alloc_size((MaxLL+1) * sizeof(U32)) + + ZSTD_cwksp_aligned_alloc_size((MaxOff+1) * sizeof(U32)) + + ZSTD_cwksp_aligned_alloc_size((1<strategy, useRowMatchFinder) +- ? ZSTD_cwksp_aligned_alloc_size(hSize*sizeof(U16)) ++ ? ZSTD_cwksp_aligned_alloc_size(hSize) + : 0; + size_t const optSpace = (forCCtx && (cParams->strategy >= ZSTD_btopt)) + ? 
optPotentialSpace +@@ -1386,6 +1614,13 @@ ZSTD_sizeof_matchState(const ZSTD_compressionParameters* const cParams, + return tableSpace + optSpace + slackSpace + lazyAdditionalSpace; + } + ++/* Helper function for calculating memory requirements. ++ * Gives a tighter bound than ZSTD_sequenceBound() by taking minMatch into account. */ ++static size_t ZSTD_maxNbSeq(size_t blockSize, unsigned minMatch, int useSequenceProducer) { ++ U32 const divider = (minMatch==3 || useSequenceProducer) ? 3 : 4; ++ return blockSize / divider; ++} ++ + static size_t ZSTD_estimateCCtxSize_usingCCtxParams_internal( + const ZSTD_compressionParameters* cParams, + const ldmParams_t* ldmParams, +@@ -1393,12 +1628,13 @@ static size_t ZSTD_estimateCCtxSize_usingCCtxParams_internal( + const ZSTD_paramSwitch_e useRowMatchFinder, + const size_t buffInSize, + const size_t buffOutSize, +- const U64 pledgedSrcSize) ++ const U64 pledgedSrcSize, ++ int useSequenceProducer, ++ size_t maxBlockSize) + { + size_t const windowSize = (size_t) BOUNDED(1ULL, 1ULL << cParams->windowLog, pledgedSrcSize); +- size_t const blockSize = MIN(ZSTD_BLOCKSIZE_MAX, windowSize); +- U32 const divider = (cParams->minMatch==3) ? 3 : 4; +- size_t const maxNbSeq = blockSize / divider; ++ size_t const blockSize = MIN(ZSTD_resolveMaxBlockSize(maxBlockSize), windowSize); ++ size_t const maxNbSeq = ZSTD_maxNbSeq(blockSize, cParams->minMatch, useSequenceProducer); + size_t const tokenSpace = ZSTD_cwksp_alloc_size(WILDCOPY_OVERLENGTH + blockSize) + + ZSTD_cwksp_aligned_alloc_size(maxNbSeq * sizeof(seqDef)) + + 3 * ZSTD_cwksp_alloc_size(maxNbSeq * sizeof(BYTE)); +@@ -1417,6 +1653,11 @@ static size_t ZSTD_estimateCCtxSize_usingCCtxParams_internal( + + size_t const cctxSpace = isStatic ? ZSTD_cwksp_alloc_size(sizeof(ZSTD_CCtx)) : 0; + ++ size_t const maxNbExternalSeq = ZSTD_sequenceBound(blockSize); ++ size_t const externalSeqSpace = useSequenceProducer ++ ? 
ZSTD_cwksp_aligned_alloc_size(maxNbExternalSeq * sizeof(ZSTD_Sequence)) ++ : 0; ++ + size_t const neededSpace = + cctxSpace + + entropySpace + +@@ -1425,7 +1666,8 @@ static size_t ZSTD_estimateCCtxSize_usingCCtxParams_internal( + ldmSeqSpace + + matchStateSize + + tokenSpace + +- bufferSpace; ++ bufferSpace + ++ externalSeqSpace; + + DEBUGLOG(5, "estimate workspace : %u", (U32)neededSpace); + return neededSpace; +@@ -1443,7 +1685,7 @@ size_t ZSTD_estimateCCtxSize_usingCCtxParams(const ZSTD_CCtx_params* params) + * be needed. However, we still allocate two 0-sized buffers, which can + * take space under ASAN. */ + return ZSTD_estimateCCtxSize_usingCCtxParams_internal( +- &cParams, ¶ms->ldmParams, 1, useRowMatchFinder, 0, 0, ZSTD_CONTENTSIZE_UNKNOWN); ++ &cParams, ¶ms->ldmParams, 1, useRowMatchFinder, 0, 0, ZSTD_CONTENTSIZE_UNKNOWN, ZSTD_hasExtSeqProd(params), params->maxBlockSize); + } + + size_t ZSTD_estimateCCtxSize_usingCParams(ZSTD_compressionParameters cParams) +@@ -1493,7 +1735,7 @@ size_t ZSTD_estimateCStreamSize_usingCCtxParams(const ZSTD_CCtx_params* params) + RETURN_ERROR_IF(params->nbWorkers > 0, GENERIC, "Estimate CCtx size is supported for single-threaded compression only."); + { ZSTD_compressionParameters const cParams = + ZSTD_getCParamsFromCCtxParams(params, ZSTD_CONTENTSIZE_UNKNOWN, 0, ZSTD_cpm_noAttachDict); +- size_t const blockSize = MIN(ZSTD_BLOCKSIZE_MAX, (size_t)1 << cParams.windowLog); ++ size_t const blockSize = MIN(ZSTD_resolveMaxBlockSize(params->maxBlockSize), (size_t)1 << cParams.windowLog); + size_t const inBuffSize = (params->inBufferMode == ZSTD_bm_buffered) + ? 
((size_t)1 << cParams.windowLog) + blockSize + : 0; +@@ -1504,7 +1746,7 @@ size_t ZSTD_estimateCStreamSize_usingCCtxParams(const ZSTD_CCtx_params* params) + + return ZSTD_estimateCCtxSize_usingCCtxParams_internal( + &cParams, ¶ms->ldmParams, 1, useRowMatchFinder, inBuffSize, outBuffSize, +- ZSTD_CONTENTSIZE_UNKNOWN); ++ ZSTD_CONTENTSIZE_UNKNOWN, ZSTD_hasExtSeqProd(params), params->maxBlockSize); + } + } + +@@ -1637,6 +1879,19 @@ typedef enum { + ZSTD_resetTarget_CCtx + } ZSTD_resetTarget_e; + ++/* Mixes bits in a 64 bits in a value, based on XXH3_rrmxmx */ ++static U64 ZSTD_bitmix(U64 val, U64 len) { ++ val ^= ZSTD_rotateRight_U64(val, 49) ^ ZSTD_rotateRight_U64(val, 24); ++ val *= 0x9FB21C651E98DF25ULL; ++ val ^= (val >> 35) + len ; ++ val *= 0x9FB21C651E98DF25ULL; ++ return val ^ (val >> 28); ++} ++ ++/* Mixes in the hashSalt and hashSaltEntropy to create a new hashSalt */ ++static void ZSTD_advanceHashSalt(ZSTD_matchState_t* ms) { ++ ms->hashSalt = ZSTD_bitmix(ms->hashSalt, 8) ^ ZSTD_bitmix((U64) ms->hashSaltEntropy, 4); ++} + + static size_t + ZSTD_reset_matchState(ZSTD_matchState_t* ms, +@@ -1664,6 +1919,7 @@ ZSTD_reset_matchState(ZSTD_matchState_t* ms, + } + + ms->hashLog3 = hashLog3; ++ ms->lazySkipping = 0; + + ZSTD_invalidateMatchState(ms); + +@@ -1685,22 +1941,19 @@ ZSTD_reset_matchState(ZSTD_matchState_t* ms, + ZSTD_cwksp_clean_tables(ws); + } + +- /* opt parser space */ +- if ((forWho == ZSTD_resetTarget_CCtx) && (cParams->strategy >= ZSTD_btopt)) { +- DEBUGLOG(4, "reserving optimal parser space"); +- ms->opt.litFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (1<opt.litLengthFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxLL+1) * sizeof(unsigned)); +- ms->opt.matchLengthFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxML+1) * sizeof(unsigned)); +- ms->opt.offCodeFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxOff+1) * sizeof(unsigned)); +- ms->opt.matchTable = (ZSTD_match_t*)ZSTD_cwksp_reserve_aligned(ws, (ZSTD_OPT_NUM+1) * 
sizeof(ZSTD_match_t)); +- ms->opt.priceTable = (ZSTD_optimal_t*)ZSTD_cwksp_reserve_aligned(ws, (ZSTD_OPT_NUM+1) * sizeof(ZSTD_optimal_t)); +- } +- + if (ZSTD_rowMatchFinderUsed(cParams->strategy, useRowMatchFinder)) { +- { /* Row match finder needs an additional table of hashes ("tags") */ +- size_t const tagTableSize = hSize*sizeof(U16); +- ms->tagTable = (U16*)ZSTD_cwksp_reserve_aligned(ws, tagTableSize); +- if (ms->tagTable) ZSTD_memset(ms->tagTable, 0, tagTableSize); ++ /* Row match finder needs an additional table of hashes ("tags") */ ++ size_t const tagTableSize = hSize; ++ /* We want to generate a new salt in case we reset a Cctx, but we always want to use ++ * 0 when we reset a Cdict */ ++ if(forWho == ZSTD_resetTarget_CCtx) { ++ ms->tagTable = (BYTE*) ZSTD_cwksp_reserve_aligned_init_once(ws, tagTableSize); ++ ZSTD_advanceHashSalt(ms); ++ } else { ++ /* When we are not salting we want to always memset the memory */ ++ ms->tagTable = (BYTE*) ZSTD_cwksp_reserve_aligned(ws, tagTableSize); ++ ZSTD_memset(ms->tagTable, 0, tagTableSize); ++ ms->hashSalt = 0; + } + { /* Switch to 32-entry rows if searchLog is 5 (or more) */ + U32 const rowLog = BOUNDED(4, cParams->searchLog, 6); +@@ -1709,6 +1962,17 @@ ZSTD_reset_matchState(ZSTD_matchState_t* ms, + } + } + ++ /* opt parser space */ ++ if ((forWho == ZSTD_resetTarget_CCtx) && (cParams->strategy >= ZSTD_btopt)) { ++ DEBUGLOG(4, "reserving optimal parser space"); ++ ms->opt.litFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (1<opt.litLengthFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxLL+1) * sizeof(unsigned)); ++ ms->opt.matchLengthFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxML+1) * sizeof(unsigned)); ++ ms->opt.offCodeFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxOff+1) * sizeof(unsigned)); ++ ms->opt.matchTable = (ZSTD_match_t*)ZSTD_cwksp_reserve_aligned(ws, ZSTD_OPT_SIZE * sizeof(ZSTD_match_t)); ++ ms->opt.priceTable = (ZSTD_optimal_t*)ZSTD_cwksp_reserve_aligned(ws, ZSTD_OPT_SIZE * 
sizeof(ZSTD_optimal_t)); ++ } ++ + ms->cParams = *cParams; + + RETURN_ERROR_IF(ZSTD_cwksp_reserve_failed(ws), memory_allocation, +@@ -1768,6 +2032,7 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, + assert(params->useRowMatchFinder != ZSTD_ps_auto); + assert(params->useBlockSplitter != ZSTD_ps_auto); + assert(params->ldmParams.enableLdm != ZSTD_ps_auto); ++ assert(params->maxBlockSize != 0); + if (params->ldmParams.enableLdm == ZSTD_ps_enable) { + /* Adjust long distance matching parameters */ + ZSTD_ldm_adjustParameters(&zc->appliedParams.ldmParams, ¶ms->cParams); +@@ -1776,9 +2041,8 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, + } + + { size_t const windowSize = MAX(1, (size_t)MIN(((U64)1 << params->cParams.windowLog), pledgedSrcSize)); +- size_t const blockSize = MIN(ZSTD_BLOCKSIZE_MAX, windowSize); +- U32 const divider = (params->cParams.minMatch==3) ? 3 : 4; +- size_t const maxNbSeq = blockSize / divider; ++ size_t const blockSize = MIN(params->maxBlockSize, windowSize); ++ size_t const maxNbSeq = ZSTD_maxNbSeq(blockSize, params->cParams.minMatch, ZSTD_hasExtSeqProd(params)); + size_t const buffOutSize = (zbuff == ZSTDb_buffered && params->outBufferMode == ZSTD_bm_buffered) + ? 
ZSTD_compressBound(blockSize) + 1 + : 0; +@@ -1795,8 +2059,7 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, + size_t const neededSpace = + ZSTD_estimateCCtxSize_usingCCtxParams_internal( + ¶ms->cParams, ¶ms->ldmParams, zc->staticSize != 0, params->useRowMatchFinder, +- buffInSize, buffOutSize, pledgedSrcSize); +- int resizeWorkspace; ++ buffInSize, buffOutSize, pledgedSrcSize, ZSTD_hasExtSeqProd(params), params->maxBlockSize); + + FORWARD_IF_ERROR(neededSpace, "cctx size estimate failed!"); + +@@ -1805,7 +2068,7 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, + { /* Check if workspace is large enough, alloc a new one if needed */ + int const workspaceTooSmall = ZSTD_cwksp_sizeof(ws) < neededSpace; + int const workspaceWasteful = ZSTD_cwksp_check_wasteful(ws, neededSpace); +- resizeWorkspace = workspaceTooSmall || workspaceWasteful; ++ int resizeWorkspace = workspaceTooSmall || workspaceWasteful; + DEBUGLOG(4, "Need %zu B workspace", neededSpace); + DEBUGLOG(4, "windowSize: %zu - blockSize: %zu", windowSize, blockSize); + +@@ -1838,6 +2101,7 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, + + /* init params */ + zc->blockState.matchState.cParams = params->cParams; ++ zc->blockState.matchState.prefetchCDictTables = params->prefetchCDictTables == ZSTD_ps_enable; + zc->pledgedSrcSizePlusOne = pledgedSrcSize+1; + zc->consumedSrcSize = 0; + zc->producedCSize = 0; +@@ -1854,13 +2118,46 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, + + ZSTD_reset_compressedBlockState(zc->blockState.prevCBlock); + ++ FORWARD_IF_ERROR(ZSTD_reset_matchState( ++ &zc->blockState.matchState, ++ ws, ++ ¶ms->cParams, ++ params->useRowMatchFinder, ++ crp, ++ needsIndexReset, ++ ZSTD_resetTarget_CCtx), ""); ++ ++ zc->seqStore.sequencesStart = (seqDef*)ZSTD_cwksp_reserve_aligned(ws, maxNbSeq * sizeof(seqDef)); ++ ++ /* ldm hash table */ ++ if (params->ldmParams.enableLdm == ZSTD_ps_enable) { ++ /* TODO: avoid memset? 
*/ ++ size_t const ldmHSize = ((size_t)1) << params->ldmParams.hashLog; ++ zc->ldmState.hashTable = (ldmEntry_t*)ZSTD_cwksp_reserve_aligned(ws, ldmHSize * sizeof(ldmEntry_t)); ++ ZSTD_memset(zc->ldmState.hashTable, 0, ldmHSize * sizeof(ldmEntry_t)); ++ zc->ldmSequences = (rawSeq*)ZSTD_cwksp_reserve_aligned(ws, maxNbLdmSeq * sizeof(rawSeq)); ++ zc->maxNbLdmSequences = maxNbLdmSeq; ++ ++ ZSTD_window_init(&zc->ldmState.window); ++ zc->ldmState.loadedDictEnd = 0; ++ } ++ ++ /* reserve space for block-level external sequences */ ++ if (ZSTD_hasExtSeqProd(params)) { ++ size_t const maxNbExternalSeq = ZSTD_sequenceBound(blockSize); ++ zc->extSeqBufCapacity = maxNbExternalSeq; ++ zc->extSeqBuf = ++ (ZSTD_Sequence*)ZSTD_cwksp_reserve_aligned(ws, maxNbExternalSeq * sizeof(ZSTD_Sequence)); ++ } ++ ++ /* buffers */ ++ + /* ZSTD_wildcopy() is used to copy into the literals buffer, + * so we have to oversize the buffer by WILDCOPY_OVERLENGTH bytes. + */ + zc->seqStore.litStart = ZSTD_cwksp_reserve_buffer(ws, blockSize + WILDCOPY_OVERLENGTH); + zc->seqStore.maxNbLit = blockSize; + +- /* buffers */ + zc->bufferedPolicy = zbuff; + zc->inBuffSize = buffInSize; + zc->inBuff = (char*)ZSTD_cwksp_reserve_buffer(ws, buffInSize); +@@ -1883,32 +2180,9 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, + zc->seqStore.llCode = ZSTD_cwksp_reserve_buffer(ws, maxNbSeq * sizeof(BYTE)); + zc->seqStore.mlCode = ZSTD_cwksp_reserve_buffer(ws, maxNbSeq * sizeof(BYTE)); + zc->seqStore.ofCode = ZSTD_cwksp_reserve_buffer(ws, maxNbSeq * sizeof(BYTE)); +- zc->seqStore.sequencesStart = (seqDef*)ZSTD_cwksp_reserve_aligned(ws, maxNbSeq * sizeof(seqDef)); +- +- FORWARD_IF_ERROR(ZSTD_reset_matchState( +- &zc->blockState.matchState, +- ws, +- ¶ms->cParams, +- params->useRowMatchFinder, +- crp, +- needsIndexReset, +- ZSTD_resetTarget_CCtx), ""); +- +- /* ldm hash table */ +- if (params->ldmParams.enableLdm == ZSTD_ps_enable) { +- /* TODO: avoid memset? 
*/ +- size_t const ldmHSize = ((size_t)1) << params->ldmParams.hashLog; +- zc->ldmState.hashTable = (ldmEntry_t*)ZSTD_cwksp_reserve_aligned(ws, ldmHSize * sizeof(ldmEntry_t)); +- ZSTD_memset(zc->ldmState.hashTable, 0, ldmHSize * sizeof(ldmEntry_t)); +- zc->ldmSequences = (rawSeq*)ZSTD_cwksp_reserve_aligned(ws, maxNbLdmSeq * sizeof(rawSeq)); +- zc->maxNbLdmSequences = maxNbLdmSeq; +- +- ZSTD_window_init(&zc->ldmState.window); +- zc->ldmState.loadedDictEnd = 0; +- } + + DEBUGLOG(3, "wksp: finished allocating, %zd bytes remain available", ZSTD_cwksp_available_space(ws)); +- assert(ZSTD_cwksp_estimated_space_within_bounds(ws, neededSpace, resizeWorkspace)); ++ assert(ZSTD_cwksp_estimated_space_within_bounds(ws, neededSpace)); + + zc->initialized = 1; + +@@ -1980,7 +2254,8 @@ ZSTD_resetCCtx_byAttachingCDict(ZSTD_CCtx* cctx, + } + + params.cParams = ZSTD_adjustCParams_internal(adjusted_cdict_cParams, pledgedSrcSize, +- cdict->dictContentSize, ZSTD_cpm_attachDict); ++ cdict->dictContentSize, ZSTD_cpm_attachDict, ++ params.useRowMatchFinder); + params.cParams.windowLog = windowLog; + params.useRowMatchFinder = cdict->useRowMatchFinder; /* cdict overrides */ + FORWARD_IF_ERROR(ZSTD_resetCCtx_internal(cctx, ¶ms, pledgedSrcSize, +@@ -2019,6 +2294,22 @@ ZSTD_resetCCtx_byAttachingCDict(ZSTD_CCtx* cctx, + return 0; + } + ++static void ZSTD_copyCDictTableIntoCCtx(U32* dst, U32 const* src, size_t tableSize, ++ ZSTD_compressionParameters const* cParams) { ++ if (ZSTD_CDictIndicesAreTagged(cParams)){ ++ /* Remove tags from the CDict table if they are present. ++ * See docs on "short cache" in zstd_compress_internal.h for context. 
*/ ++ size_t i; ++ for (i = 0; i < tableSize; i++) { ++ U32 const taggedIndex = src[i]; ++ U32 const index = taggedIndex >> ZSTD_SHORT_CACHE_TAG_BITS; ++ dst[i] = index; ++ } ++ } else { ++ ZSTD_memcpy(dst, src, tableSize * sizeof(U32)); ++ } ++} ++ + static size_t ZSTD_resetCCtx_byCopyingCDict(ZSTD_CCtx* cctx, + const ZSTD_CDict* cdict, + ZSTD_CCtx_params params, +@@ -2054,21 +2345,23 @@ static size_t ZSTD_resetCCtx_byCopyingCDict(ZSTD_CCtx* cctx, + : 0; + size_t const hSize = (size_t)1 << cdict_cParams->hashLog; + +- ZSTD_memcpy(cctx->blockState.matchState.hashTable, +- cdict->matchState.hashTable, +- hSize * sizeof(U32)); ++ ZSTD_copyCDictTableIntoCCtx(cctx->blockState.matchState.hashTable, ++ cdict->matchState.hashTable, ++ hSize, cdict_cParams); ++ + /* Do not copy cdict's chainTable if cctx has parameters such that it would not use chainTable */ + if (ZSTD_allocateChainTable(cctx->appliedParams.cParams.strategy, cctx->appliedParams.useRowMatchFinder, 0 /* forDDSDict */)) { +- ZSTD_memcpy(cctx->blockState.matchState.chainTable, +- cdict->matchState.chainTable, +- chainSize * sizeof(U32)); ++ ZSTD_copyCDictTableIntoCCtx(cctx->blockState.matchState.chainTable, ++ cdict->matchState.chainTable, ++ chainSize, cdict_cParams); + } + /* copy tag table */ + if (ZSTD_rowMatchFinderUsed(cdict_cParams->strategy, cdict->useRowMatchFinder)) { +- size_t const tagTableSize = hSize*sizeof(U16); ++ size_t const tagTableSize = hSize; + ZSTD_memcpy(cctx->blockState.matchState.tagTable, +- cdict->matchState.tagTable, +- tagTableSize); ++ cdict->matchState.tagTable, ++ tagTableSize); ++ cctx->blockState.matchState.hashSalt = cdict->matchState.hashSalt; + } + } + +@@ -2147,6 +2440,7 @@ static size_t ZSTD_copyCCtx_internal(ZSTD_CCtx* dstCCtx, + params.useBlockSplitter = srcCCtx->appliedParams.useBlockSplitter; + params.ldmParams = srcCCtx->appliedParams.ldmParams; + params.fParams = fParams; ++ params.maxBlockSize = srcCCtx->appliedParams.maxBlockSize; + 
ZSTD_resetCCtx_internal(dstCCtx, ¶ms, pledgedSrcSize, + /* loadedDictSize */ 0, + ZSTDcrp_leaveDirty, zbuff); +@@ -2294,7 +2588,7 @@ static void ZSTD_reduceIndex (ZSTD_matchState_t* ms, ZSTD_CCtx_params const* par + + /* See doc/zstd_compression_format.md for detailed format description */ + +-void ZSTD_seqToCodes(const seqStore_t* seqStorePtr) ++int ZSTD_seqToCodes(const seqStore_t* seqStorePtr) + { + const seqDef* const sequences = seqStorePtr->sequencesStart; + BYTE* const llCodeTable = seqStorePtr->llCode; +@@ -2302,18 +2596,24 @@ void ZSTD_seqToCodes(const seqStore_t* seqStorePtr) + BYTE* const mlCodeTable = seqStorePtr->mlCode; + U32 const nbSeq = (U32)(seqStorePtr->sequences - seqStorePtr->sequencesStart); + U32 u; ++ int longOffsets = 0; + assert(nbSeq <= seqStorePtr->maxNbSeq); + for (u=0; u= STREAM_ACCUMULATOR_MIN)); ++ if (MEM_32bits() && ofCode >= STREAM_ACCUMULATOR_MIN) ++ longOffsets = 1; + } + if (seqStorePtr->longLengthType==ZSTD_llt_literalLength) + llCodeTable[seqStorePtr->longLengthPos] = MaxLL; + if (seqStorePtr->longLengthType==ZSTD_llt_matchLength) + mlCodeTable[seqStorePtr->longLengthPos] = MaxML; ++ return longOffsets; + } + + /* ZSTD_useTargetCBlockSize(): +@@ -2347,6 +2647,7 @@ typedef struct { + U32 MLtype; + size_t size; + size_t lastCountSize; /* Accounts for bug in 1.3.4. 
More detail in ZSTD_entropyCompressSeqStore_internal() */ ++ int longOffsets; + } ZSTD_symbolEncodingTypeStats_t; + + /* ZSTD_buildSequencesStatistics(): +@@ -2357,11 +2658,13 @@ typedef struct { + * entropyWkspSize must be of size at least ENTROPY_WORKSPACE_SIZE - (MaxSeq + 1)*sizeof(U32) + */ + static ZSTD_symbolEncodingTypeStats_t +-ZSTD_buildSequencesStatistics(seqStore_t* seqStorePtr, size_t nbSeq, +- const ZSTD_fseCTables_t* prevEntropy, ZSTD_fseCTables_t* nextEntropy, +- BYTE* dst, const BYTE* const dstEnd, +- ZSTD_strategy strategy, unsigned* countWorkspace, +- void* entropyWorkspace, size_t entropyWkspSize) { ++ZSTD_buildSequencesStatistics( ++ const seqStore_t* seqStorePtr, size_t nbSeq, ++ const ZSTD_fseCTables_t* prevEntropy, ZSTD_fseCTables_t* nextEntropy, ++ BYTE* dst, const BYTE* const dstEnd, ++ ZSTD_strategy strategy, unsigned* countWorkspace, ++ void* entropyWorkspace, size_t entropyWkspSize) ++{ + BYTE* const ostart = dst; + const BYTE* const oend = dstEnd; + BYTE* op = ostart; +@@ -2375,7 +2678,7 @@ ZSTD_buildSequencesStatistics(seqStore_t* seqStorePtr, size_t nbSeq, + + stats.lastCountSize = 0; + /* convert length/distances into codes */ +- ZSTD_seqToCodes(seqStorePtr); ++ stats.longOffsets = ZSTD_seqToCodes(seqStorePtr); + assert(op <= oend); + assert(nbSeq != 0); /* ZSTD_selectEncodingType() divides by nbSeq */ + /* build CTable for Literal Lengths */ +@@ -2480,22 +2783,22 @@ ZSTD_buildSequencesStatistics(seqStore_t* seqStorePtr, size_t nbSeq, + */ + #define SUSPECT_UNCOMPRESSIBLE_LITERAL_RATIO 20 + MEM_STATIC size_t +-ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, +- const ZSTD_entropyCTables_t* prevEntropy, +- ZSTD_entropyCTables_t* nextEntropy, +- const ZSTD_CCtx_params* cctxParams, +- void* dst, size_t dstCapacity, +- void* entropyWorkspace, size_t entropyWkspSize, +- const int bmi2) ++ZSTD_entropyCompressSeqStore_internal( ++ const seqStore_t* seqStorePtr, ++ const ZSTD_entropyCTables_t* prevEntropy, ++ 
ZSTD_entropyCTables_t* nextEntropy, ++ const ZSTD_CCtx_params* cctxParams, ++ void* dst, size_t dstCapacity, ++ void* entropyWorkspace, size_t entropyWkspSize, ++ const int bmi2) + { +- const int longOffsets = cctxParams->cParams.windowLog > STREAM_ACCUMULATOR_MIN; + ZSTD_strategy const strategy = cctxParams->cParams.strategy; + unsigned* count = (unsigned*)entropyWorkspace; + FSE_CTable* CTable_LitLength = nextEntropy->fse.litlengthCTable; + FSE_CTable* CTable_OffsetBits = nextEntropy->fse.offcodeCTable; + FSE_CTable* CTable_MatchLength = nextEntropy->fse.matchlengthCTable; + const seqDef* const sequences = seqStorePtr->sequencesStart; +- const size_t nbSeq = seqStorePtr->sequences - seqStorePtr->sequencesStart; ++ const size_t nbSeq = (size_t)(seqStorePtr->sequences - seqStorePtr->sequencesStart); + const BYTE* const ofCodeTable = seqStorePtr->ofCode; + const BYTE* const llCodeTable = seqStorePtr->llCode; + const BYTE* const mlCodeTable = seqStorePtr->mlCode; +@@ -2503,29 +2806,31 @@ ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, + BYTE* const oend = ostart + dstCapacity; + BYTE* op = ostart; + size_t lastCountSize; ++ int longOffsets = 0; + + entropyWorkspace = count + (MaxSeq + 1); + entropyWkspSize -= (MaxSeq + 1) * sizeof(*count); + +- DEBUGLOG(4, "ZSTD_entropyCompressSeqStore_internal (nbSeq=%zu)", nbSeq); ++ DEBUGLOG(5, "ZSTD_entropyCompressSeqStore_internal (nbSeq=%zu, dstCapacity=%zu)", nbSeq, dstCapacity); + ZSTD_STATIC_ASSERT(HUF_WORKSPACE_SIZE >= (1<= HUF_WORKSPACE_SIZE); + + /* Compress literals */ + { const BYTE* const literals = seqStorePtr->litStart; +- size_t const numSequences = seqStorePtr->sequences - seqStorePtr->sequencesStart; +- size_t const numLiterals = seqStorePtr->lit - seqStorePtr->litStart; ++ size_t const numSequences = (size_t)(seqStorePtr->sequences - seqStorePtr->sequencesStart); ++ size_t const numLiterals = (size_t)(seqStorePtr->lit - seqStorePtr->litStart); + /* Base suspicion of uncompressibility on ratio of 
literals to sequences */ + unsigned const suspectUncompressible = (numSequences == 0) || (numLiterals / numSequences >= SUSPECT_UNCOMPRESSIBLE_LITERAL_RATIO); + size_t const litSize = (size_t)(seqStorePtr->lit - literals); ++ + size_t const cSize = ZSTD_compressLiterals( +- &prevEntropy->huf, &nextEntropy->huf, +- cctxParams->cParams.strategy, +- ZSTD_literalsCompressionIsDisabled(cctxParams), + op, dstCapacity, + literals, litSize, + entropyWorkspace, entropyWkspSize, +- bmi2, suspectUncompressible); ++ &prevEntropy->huf, &nextEntropy->huf, ++ cctxParams->cParams.strategy, ++ ZSTD_literalsCompressionIsDisabled(cctxParams), ++ suspectUncompressible, bmi2); + FORWARD_IF_ERROR(cSize, "ZSTD_compressLiterals failed"); + assert(cSize <= dstCapacity); + op += cSize; +@@ -2551,11 +2856,10 @@ ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, + ZSTD_memcpy(&nextEntropy->fse, &prevEntropy->fse, sizeof(prevEntropy->fse)); + return (size_t)(op - ostart); + } +- { +- ZSTD_symbolEncodingTypeStats_t stats; +- BYTE* seqHead = op++; ++ { BYTE* const seqHead = op++; + /* build stats for sequences */ +- stats = ZSTD_buildSequencesStatistics(seqStorePtr, nbSeq, ++ const ZSTD_symbolEncodingTypeStats_t stats = ++ ZSTD_buildSequencesStatistics(seqStorePtr, nbSeq, + &prevEntropy->fse, &nextEntropy->fse, + op, oend, + strategy, count, +@@ -2564,6 +2868,7 @@ ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, + *seqHead = (BYTE)((stats.LLtype<<6) + (stats.Offtype<<4) + (stats.MLtype<<2)); + lastCountSize = stats.lastCountSize; + op += stats.size; ++ longOffsets = stats.longOffsets; + } + + { size_t const bitstreamSize = ZSTD_encodeSequences( +@@ -2598,14 +2903,15 @@ ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, + } + + MEM_STATIC size_t +-ZSTD_entropyCompressSeqStore(seqStore_t* seqStorePtr, +- const ZSTD_entropyCTables_t* prevEntropy, +- ZSTD_entropyCTables_t* nextEntropy, +- const ZSTD_CCtx_params* cctxParams, +- void* dst, size_t dstCapacity, +- 
size_t srcSize, +- void* entropyWorkspace, size_t entropyWkspSize, +- int bmi2) ++ZSTD_entropyCompressSeqStore( ++ const seqStore_t* seqStorePtr, ++ const ZSTD_entropyCTables_t* prevEntropy, ++ ZSTD_entropyCTables_t* nextEntropy, ++ const ZSTD_CCtx_params* cctxParams, ++ void* dst, size_t dstCapacity, ++ size_t srcSize, ++ void* entropyWorkspace, size_t entropyWkspSize, ++ int bmi2) + { + size_t const cSize = ZSTD_entropyCompressSeqStore_internal( + seqStorePtr, prevEntropy, nextEntropy, cctxParams, +@@ -2615,15 +2921,21 @@ ZSTD_entropyCompressSeqStore(seqStore_t* seqStorePtr, + /* When srcSize <= dstCapacity, there is enough space to write a raw uncompressed block. + * Since we ran out of space, block must be not compressible, so fall back to raw uncompressed block. + */ +- if ((cSize == ERROR(dstSize_tooSmall)) & (srcSize <= dstCapacity)) ++ if ((cSize == ERROR(dstSize_tooSmall)) & (srcSize <= dstCapacity)) { ++ DEBUGLOG(4, "not enough dstCapacity (%zu) for ZSTD_entropyCompressSeqStore_internal()=> do not compress block", dstCapacity); + return 0; /* block not compressed */ ++ } + FORWARD_IF_ERROR(cSize, "ZSTD_entropyCompressSeqStore_internal failed"); + + /* Check compressibility */ + { size_t const maxCSize = srcSize - ZSTD_minGain(srcSize, cctxParams->cParams.strategy); + if (cSize >= maxCSize) return 0; /* block not compressed */ + } +- DEBUGLOG(4, "ZSTD_entropyCompressSeqStore() cSize: %zu", cSize); ++ DEBUGLOG(5, "ZSTD_entropyCompressSeqStore() cSize: %zu", cSize); ++ /* libzstd decoder before > v1.5.4 is not compatible with compressed blocks of size ZSTD_BLOCKSIZE_MAX exactly. ++ * This restriction is indirectly already fulfilled by respecting ZSTD_minGain() condition above. 
++ */ ++ assert(cSize < ZSTD_BLOCKSIZE_MAX); + return cSize; + } + +@@ -2635,40 +2947,43 @@ ZSTD_blockCompressor ZSTD_selectBlockCompressor(ZSTD_strategy strat, ZSTD_paramS + static const ZSTD_blockCompressor blockCompressor[4][ZSTD_STRATEGY_MAX+1] = { + { ZSTD_compressBlock_fast /* default for 0 */, + ZSTD_compressBlock_fast, +- ZSTD_compressBlock_doubleFast, +- ZSTD_compressBlock_greedy, +- ZSTD_compressBlock_lazy, +- ZSTD_compressBlock_lazy2, +- ZSTD_compressBlock_btlazy2, +- ZSTD_compressBlock_btopt, +- ZSTD_compressBlock_btultra, +- ZSTD_compressBlock_btultra2 }, ++ ZSTD_COMPRESSBLOCK_DOUBLEFAST, ++ ZSTD_COMPRESSBLOCK_GREEDY, ++ ZSTD_COMPRESSBLOCK_LAZY, ++ ZSTD_COMPRESSBLOCK_LAZY2, ++ ZSTD_COMPRESSBLOCK_BTLAZY2, ++ ZSTD_COMPRESSBLOCK_BTOPT, ++ ZSTD_COMPRESSBLOCK_BTULTRA, ++ ZSTD_COMPRESSBLOCK_BTULTRA2 ++ }, + { ZSTD_compressBlock_fast_extDict /* default for 0 */, + ZSTD_compressBlock_fast_extDict, +- ZSTD_compressBlock_doubleFast_extDict, +- ZSTD_compressBlock_greedy_extDict, +- ZSTD_compressBlock_lazy_extDict, +- ZSTD_compressBlock_lazy2_extDict, +- ZSTD_compressBlock_btlazy2_extDict, +- ZSTD_compressBlock_btopt_extDict, +- ZSTD_compressBlock_btultra_extDict, +- ZSTD_compressBlock_btultra_extDict }, ++ ZSTD_COMPRESSBLOCK_DOUBLEFAST_EXTDICT, ++ ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT, ++ ZSTD_COMPRESSBLOCK_LAZY_EXTDICT, ++ ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT, ++ ZSTD_COMPRESSBLOCK_BTLAZY2_EXTDICT, ++ ZSTD_COMPRESSBLOCK_BTOPT_EXTDICT, ++ ZSTD_COMPRESSBLOCK_BTULTRA_EXTDICT, ++ ZSTD_COMPRESSBLOCK_BTULTRA_EXTDICT ++ }, + { ZSTD_compressBlock_fast_dictMatchState /* default for 0 */, + ZSTD_compressBlock_fast_dictMatchState, +- ZSTD_compressBlock_doubleFast_dictMatchState, +- ZSTD_compressBlock_greedy_dictMatchState, +- ZSTD_compressBlock_lazy_dictMatchState, +- ZSTD_compressBlock_lazy2_dictMatchState, +- ZSTD_compressBlock_btlazy2_dictMatchState, +- ZSTD_compressBlock_btopt_dictMatchState, +- ZSTD_compressBlock_btultra_dictMatchState, +- 
ZSTD_compressBlock_btultra_dictMatchState }, ++ ZSTD_COMPRESSBLOCK_DOUBLEFAST_DICTMATCHSTATE, ++ ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE, ++ ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE, ++ ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE, ++ ZSTD_COMPRESSBLOCK_BTLAZY2_DICTMATCHSTATE, ++ ZSTD_COMPRESSBLOCK_BTOPT_DICTMATCHSTATE, ++ ZSTD_COMPRESSBLOCK_BTULTRA_DICTMATCHSTATE, ++ ZSTD_COMPRESSBLOCK_BTULTRA_DICTMATCHSTATE ++ }, + { NULL /* default for 0 */, + NULL, + NULL, +- ZSTD_compressBlock_greedy_dedicatedDictSearch, +- ZSTD_compressBlock_lazy_dedicatedDictSearch, +- ZSTD_compressBlock_lazy2_dedicatedDictSearch, ++ ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH, ++ ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH, ++ ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH, + NULL, + NULL, + NULL, +@@ -2681,18 +2996,26 @@ ZSTD_blockCompressor ZSTD_selectBlockCompressor(ZSTD_strategy strat, ZSTD_paramS + DEBUGLOG(4, "Selected block compressor: dictMode=%d strat=%d rowMatchfinder=%d", (int)dictMode, (int)strat, (int)useRowMatchFinder); + if (ZSTD_rowMatchFinderUsed(strat, useRowMatchFinder)) { + static const ZSTD_blockCompressor rowBasedBlockCompressors[4][3] = { +- { ZSTD_compressBlock_greedy_row, +- ZSTD_compressBlock_lazy_row, +- ZSTD_compressBlock_lazy2_row }, +- { ZSTD_compressBlock_greedy_extDict_row, +- ZSTD_compressBlock_lazy_extDict_row, +- ZSTD_compressBlock_lazy2_extDict_row }, +- { ZSTD_compressBlock_greedy_dictMatchState_row, +- ZSTD_compressBlock_lazy_dictMatchState_row, +- ZSTD_compressBlock_lazy2_dictMatchState_row }, +- { ZSTD_compressBlock_greedy_dedicatedDictSearch_row, +- ZSTD_compressBlock_lazy_dedicatedDictSearch_row, +- ZSTD_compressBlock_lazy2_dedicatedDictSearch_row } ++ { ++ ZSTD_COMPRESSBLOCK_GREEDY_ROW, ++ ZSTD_COMPRESSBLOCK_LAZY_ROW, ++ ZSTD_COMPRESSBLOCK_LAZY2_ROW ++ }, ++ { ++ ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT_ROW, ++ ZSTD_COMPRESSBLOCK_LAZY_EXTDICT_ROW, ++ ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT_ROW ++ }, ++ { ++ ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE_ROW, ++ 
ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE_ROW, ++ ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE_ROW ++ }, ++ { ++ ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH_ROW, ++ ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH_ROW, ++ ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH_ROW ++ } + }; + DEBUGLOG(4, "Selecting a row-based matchfinder"); + assert(useRowMatchFinder != ZSTD_ps_auto); +@@ -2718,6 +3041,72 @@ void ZSTD_resetSeqStore(seqStore_t* ssPtr) + ssPtr->longLengthType = ZSTD_llt_none; + } + ++/* ZSTD_postProcessSequenceProducerResult() : ++ * Validates and post-processes sequences obtained through the external matchfinder API: ++ * - Checks whether nbExternalSeqs represents an error condition. ++ * - Appends a block delimiter to outSeqs if one is not already present. ++ * See zstd.h for context regarding block delimiters. ++ * Returns the number of sequences after post-processing, or an error code. */ ++static size_t ZSTD_postProcessSequenceProducerResult( ++ ZSTD_Sequence* outSeqs, size_t nbExternalSeqs, size_t outSeqsCapacity, size_t srcSize ++) { ++ RETURN_ERROR_IF( ++ nbExternalSeqs > outSeqsCapacity, ++ sequenceProducer_failed, ++ "External sequence producer returned error code %lu", ++ (unsigned long)nbExternalSeqs ++ ); ++ ++ RETURN_ERROR_IF( ++ nbExternalSeqs == 0 && srcSize > 0, ++ sequenceProducer_failed, ++ "Got zero sequences from external sequence producer for a non-empty src buffer!" ++ ); ++ ++ if (srcSize == 0) { ++ ZSTD_memset(&outSeqs[0], 0, sizeof(ZSTD_Sequence)); ++ return 1; ++ } ++ ++ { ++ ZSTD_Sequence const lastSeq = outSeqs[nbExternalSeqs - 1]; ++ ++ /* We can return early if lastSeq is already a block delimiter. */ ++ if (lastSeq.offset == 0 && lastSeq.matchLength == 0) { ++ return nbExternalSeqs; ++ } ++ ++ /* This error condition is only possible if the external matchfinder ++ * produced an invalid parse, by definition of ZSTD_sequenceBound(). 
*/ ++ RETURN_ERROR_IF( ++ nbExternalSeqs == outSeqsCapacity, ++ sequenceProducer_failed, ++ "nbExternalSeqs == outSeqsCapacity but lastSeq is not a block delimiter!" ++ ); ++ ++ /* lastSeq is not a block delimiter, so we need to append one. */ ++ ZSTD_memset(&outSeqs[nbExternalSeqs], 0, sizeof(ZSTD_Sequence)); ++ return nbExternalSeqs + 1; ++ } ++} ++ ++/* ZSTD_fastSequenceLengthSum() : ++ * Returns sum(litLen) + sum(matchLen) + lastLits for *seqBuf*. ++ * Similar to another function in zstd_compress.c (determine_blockSize), ++ * except it doesn't check for a block delimiter to end summation. ++ * Removing the early exit allows the compiler to auto-vectorize (https://godbolt.org/z/cY1cajz9P). ++ * This function can be deleted and replaced by determine_blockSize after we resolve issue #3456. */ ++static size_t ZSTD_fastSequenceLengthSum(ZSTD_Sequence const* seqBuf, size_t seqBufSize) { ++ size_t matchLenSum, litLenSum, i; ++ matchLenSum = 0; ++ litLenSum = 0; ++ for (i = 0; i < seqBufSize; i++) { ++ litLenSum += seqBuf[i].litLength; ++ matchLenSum += seqBuf[i].matchLength; ++ } ++ return litLenSum + matchLenSum; ++} ++ + typedef enum { ZSTDbss_compress, ZSTDbss_noCompress } ZSTD_buildSeqStore_e; + + static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) +@@ -2727,7 +3116,9 @@ static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) + assert(srcSize <= ZSTD_BLOCKSIZE_MAX); + /* Assert that we have correctly flushed the ctx params into the ms's copy */ + ZSTD_assertEqualCParams(zc->appliedParams.cParams, ms->cParams); +- if (srcSize < MIN_CBLOCK_SIZE+ZSTD_blockHeaderSize+1) { ++ /* TODO: See 3090. We reduced MIN_CBLOCK_SIZE from 3 to 2 so to compensate we are adding ++ * additional 1. 
We need to revisit and change this logic to be more consistent */ ++ if (srcSize < MIN_CBLOCK_SIZE+ZSTD_blockHeaderSize+1+1) { + if (zc->appliedParams.cParams.strategy >= ZSTD_btopt) { + ZSTD_ldm_skipRawSeqStoreBytes(&zc->externSeqStore, srcSize); + } else { +@@ -2763,6 +3154,15 @@ static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) + } + if (zc->externSeqStore.pos < zc->externSeqStore.size) { + assert(zc->appliedParams.ldmParams.enableLdm == ZSTD_ps_disable); ++ ++ /* External matchfinder + LDM is technically possible, just not implemented yet. ++ * We need to revisit soon and implement it. */ ++ RETURN_ERROR_IF( ++ ZSTD_hasExtSeqProd(&zc->appliedParams), ++ parameter_combination_unsupported, ++ "Long-distance matching with external sequence producer enabled is not currently supported." ++ ); ++ + /* Updates ldmSeqStore.pos */ + lastLLSize = + ZSTD_ldm_blockCompress(&zc->externSeqStore, +@@ -2774,6 +3174,14 @@ static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) + } else if (zc->appliedParams.ldmParams.enableLdm == ZSTD_ps_enable) { + rawSeqStore_t ldmSeqStore = kNullRawSeqStore; + ++ /* External matchfinder + LDM is technically possible, just not implemented yet. ++ * We need to revisit soon and implement it. */ ++ RETURN_ERROR_IF( ++ ZSTD_hasExtSeqProd(&zc->appliedParams), ++ parameter_combination_unsupported, ++ "Long-distance matching with external sequence producer enabled is not currently supported." 
++ ); ++ + ldmSeqStore.seq = zc->ldmSequences; + ldmSeqStore.capacity = zc->maxNbLdmSequences; + /* Updates ldmSeqStore.size */ +@@ -2788,10 +3196,74 @@ static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) + zc->appliedParams.useRowMatchFinder, + src, srcSize); + assert(ldmSeqStore.pos == ldmSeqStore.size); +- } else { /* not long range mode */ +- ZSTD_blockCompressor const blockCompressor = ZSTD_selectBlockCompressor(zc->appliedParams.cParams.strategy, +- zc->appliedParams.useRowMatchFinder, +- dictMode); ++ } else if (ZSTD_hasExtSeqProd(&zc->appliedParams)) { ++ assert( ++ zc->extSeqBufCapacity >= ZSTD_sequenceBound(srcSize) ++ ); ++ assert(zc->appliedParams.extSeqProdFunc != NULL); ++ ++ { U32 const windowSize = (U32)1 << zc->appliedParams.cParams.windowLog; ++ ++ size_t const nbExternalSeqs = (zc->appliedParams.extSeqProdFunc)( ++ zc->appliedParams.extSeqProdState, ++ zc->extSeqBuf, ++ zc->extSeqBufCapacity, ++ src, srcSize, ++ NULL, 0, /* dict and dictSize, currently not supported */ ++ zc->appliedParams.compressionLevel, ++ windowSize ++ ); ++ ++ size_t const nbPostProcessedSeqs = ZSTD_postProcessSequenceProducerResult( ++ zc->extSeqBuf, ++ nbExternalSeqs, ++ zc->extSeqBufCapacity, ++ srcSize ++ ); ++ ++ /* Return early if there is no error, since we don't need to worry about last literals */ ++ if (!ZSTD_isError(nbPostProcessedSeqs)) { ++ ZSTD_sequencePosition seqPos = {0,0,0}; ++ size_t const seqLenSum = ZSTD_fastSequenceLengthSum(zc->extSeqBuf, nbPostProcessedSeqs); ++ RETURN_ERROR_IF(seqLenSum > srcSize, externalSequences_invalid, "External sequences imply too large a block!"); ++ FORWARD_IF_ERROR( ++ ZSTD_copySequencesToSeqStoreExplicitBlockDelim( ++ zc, &seqPos, ++ zc->extSeqBuf, nbPostProcessedSeqs, ++ src, srcSize, ++ zc->appliedParams.searchForExternalRepcodes ++ ), ++ "Failed to copy external sequences to seqStore!" 
++ ); ++ ms->ldmSeqStore = NULL; ++ DEBUGLOG(5, "Copied %lu sequences from external sequence producer to internal seqStore.", (unsigned long)nbExternalSeqs); ++ return ZSTDbss_compress; ++ } ++ ++ /* Propagate the error if fallback is disabled */ ++ if (!zc->appliedParams.enableMatchFinderFallback) { ++ return nbPostProcessedSeqs; ++ } ++ ++ /* Fallback to software matchfinder */ ++ { ZSTD_blockCompressor const blockCompressor = ++ ZSTD_selectBlockCompressor( ++ zc->appliedParams.cParams.strategy, ++ zc->appliedParams.useRowMatchFinder, ++ dictMode); ++ ms->ldmSeqStore = NULL; ++ DEBUGLOG( ++ 5, ++ "External sequence producer returned error code %lu. Falling back to internal parser.", ++ (unsigned long)nbExternalSeqs ++ ); ++ lastLLSize = blockCompressor(ms, &zc->seqStore, zc->blockState.nextCBlock->rep, src, srcSize); ++ } } ++ } else { /* not long range mode and no external matchfinder */ ++ ZSTD_blockCompressor const blockCompressor = ZSTD_selectBlockCompressor( ++ zc->appliedParams.cParams.strategy, ++ zc->appliedParams.useRowMatchFinder, ++ dictMode); + ms->ldmSeqStore = NULL; + lastLLSize = blockCompressor(ms, &zc->seqStore, zc->blockState.nextCBlock->rep, src, srcSize); + } +@@ -2801,29 +3273,38 @@ static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) + return ZSTDbss_compress; + } + +-static void ZSTD_copyBlockSequences(ZSTD_CCtx* zc) ++static size_t ZSTD_copyBlockSequences(SeqCollector* seqCollector, const seqStore_t* seqStore, const U32 prevRepcodes[ZSTD_REP_NUM]) + { +- const seqStore_t* seqStore = ZSTD_getSeqStore(zc); +- const seqDef* seqStoreSeqs = seqStore->sequencesStart; +- size_t seqStoreSeqSize = seqStore->sequences - seqStoreSeqs; +- size_t seqStoreLiteralsSize = (size_t)(seqStore->lit - seqStore->litStart); +- size_t literalsRead = 0; +- size_t lastLLSize; ++ const seqDef* inSeqs = seqStore->sequencesStart; ++ const size_t nbInSequences = seqStore->sequences - inSeqs; ++ const size_t nbInLiterals = 
(size_t)(seqStore->lit - seqStore->litStart); + +- ZSTD_Sequence* outSeqs = &zc->seqCollector.seqStart[zc->seqCollector.seqIndex]; ++ ZSTD_Sequence* outSeqs = seqCollector->seqIndex == 0 ? seqCollector->seqStart : seqCollector->seqStart + seqCollector->seqIndex; ++ const size_t nbOutSequences = nbInSequences + 1; ++ size_t nbOutLiterals = 0; ++ repcodes_t repcodes; + size_t i; +- repcodes_t updatedRepcodes; + +- assert(zc->seqCollector.seqIndex + 1 < zc->seqCollector.maxSequences); +- /* Ensure we have enough space for last literals "sequence" */ +- assert(zc->seqCollector.maxSequences >= seqStoreSeqSize + 1); +- ZSTD_memcpy(updatedRepcodes.rep, zc->blockState.prevCBlock->rep, sizeof(repcodes_t)); +- for (i = 0; i < seqStoreSeqSize; ++i) { +- U32 rawOffset = seqStoreSeqs[i].offBase - ZSTD_REP_NUM; +- outSeqs[i].litLength = seqStoreSeqs[i].litLength; +- outSeqs[i].matchLength = seqStoreSeqs[i].mlBase + MINMATCH; ++ /* Bounds check that we have enough space for every input sequence ++ * and the block delimiter ++ */ ++ assert(seqCollector->seqIndex <= seqCollector->maxSequences); ++ RETURN_ERROR_IF( ++ nbOutSequences > (size_t)(seqCollector->maxSequences - seqCollector->seqIndex), ++ dstSize_tooSmall, ++ "Not enough space to copy sequences"); ++ ++ ZSTD_memcpy(&repcodes, prevRepcodes, sizeof(repcodes)); ++ for (i = 0; i < nbInSequences; ++i) { ++ U32 rawOffset; ++ outSeqs[i].litLength = inSeqs[i].litLength; ++ outSeqs[i].matchLength = inSeqs[i].mlBase + MINMATCH; + outSeqs[i].rep = 0; + ++ /* Handle the possible single length >= 64K ++ * There can only be one because we add MINMATCH to every match length, ++ * and blocks are at most 128K. 
++ */ + if (i == seqStore->longLengthPos) { + if (seqStore->longLengthType == ZSTD_llt_literalLength) { + outSeqs[i].litLength += 0x10000; +@@ -2832,37 +3313,55 @@ static void ZSTD_copyBlockSequences(ZSTD_CCtx* zc) + } + } + +- if (seqStoreSeqs[i].offBase <= ZSTD_REP_NUM) { +- /* Derive the correct offset corresponding to a repcode */ +- outSeqs[i].rep = seqStoreSeqs[i].offBase; ++ /* Determine the raw offset given the offBase, which may be a repcode. */ ++ if (OFFBASE_IS_REPCODE(inSeqs[i].offBase)) { ++ const U32 repcode = OFFBASE_TO_REPCODE(inSeqs[i].offBase); ++ assert(repcode > 0); ++ outSeqs[i].rep = repcode; + if (outSeqs[i].litLength != 0) { +- rawOffset = updatedRepcodes.rep[outSeqs[i].rep - 1]; ++ rawOffset = repcodes.rep[repcode - 1]; + } else { +- if (outSeqs[i].rep == 3) { +- rawOffset = updatedRepcodes.rep[0] - 1; ++ if (repcode == 3) { ++ assert(repcodes.rep[0] > 1); ++ rawOffset = repcodes.rep[0] - 1; + } else { +- rawOffset = updatedRepcodes.rep[outSeqs[i].rep]; ++ rawOffset = repcodes.rep[repcode]; + } + } ++ } else { ++ rawOffset = OFFBASE_TO_OFFSET(inSeqs[i].offBase); + } + outSeqs[i].offset = rawOffset; +- /* seqStoreSeqs[i].offset == offCode+1, and ZSTD_updateRep() expects offCode +- so we provide seqStoreSeqs[i].offset - 1 */ +- ZSTD_updateRep(updatedRepcodes.rep, +- seqStoreSeqs[i].offBase - 1, +- seqStoreSeqs[i].litLength == 0); +- literalsRead += outSeqs[i].litLength; ++ ++ /* Update repcode history for the sequence */ ++ ZSTD_updateRep(repcodes.rep, ++ inSeqs[i].offBase, ++ inSeqs[i].litLength == 0); ++ ++ nbOutLiterals += outSeqs[i].litLength; + } + /* Insert last literals (if any exist) in the block as a sequence with ml == off == 0. + * If there are no last literals, then we'll emit (of: 0, ml: 0, ll: 0), which is a marker + * for the block boundary, according to the API. 
+ */ +- assert(seqStoreLiteralsSize >= literalsRead); +- lastLLSize = seqStoreLiteralsSize - literalsRead; +- outSeqs[i].litLength = (U32)lastLLSize; +- outSeqs[i].matchLength = outSeqs[i].offset = outSeqs[i].rep = 0; +- seqStoreSeqSize++; +- zc->seqCollector.seqIndex += seqStoreSeqSize; ++ assert(nbInLiterals >= nbOutLiterals); ++ { ++ const size_t lastLLSize = nbInLiterals - nbOutLiterals; ++ outSeqs[nbInSequences].litLength = (U32)lastLLSize; ++ outSeqs[nbInSequences].matchLength = 0; ++ outSeqs[nbInSequences].offset = 0; ++ assert(nbOutSequences == nbInSequences + 1); ++ } ++ seqCollector->seqIndex += nbOutSequences; ++ assert(seqCollector->seqIndex <= seqCollector->maxSequences); ++ ++ return 0; ++} ++ ++size_t ZSTD_sequenceBound(size_t srcSize) { ++ const size_t maxNbSeq = (srcSize / ZSTD_MINMATCH_MIN) + 1; ++ const size_t maxNbDelims = (srcSize / ZSTD_BLOCKSIZE_MAX_MIN) + 1; ++ return maxNbSeq + maxNbDelims; + } + + size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* outSeqs, +@@ -2871,6 +3370,16 @@ size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* outSeqs, + const size_t dstCapacity = ZSTD_compressBound(srcSize); + void* dst = ZSTD_customMalloc(dstCapacity, ZSTD_defaultCMem); + SeqCollector seqCollector; ++ { ++ int targetCBlockSize; ++ FORWARD_IF_ERROR(ZSTD_CCtx_getParameter(zc, ZSTD_c_targetCBlockSize, &targetCBlockSize), ""); ++ RETURN_ERROR_IF(targetCBlockSize != 0, parameter_unsupported, "targetCBlockSize != 0"); ++ } ++ { ++ int nbWorkers; ++ FORWARD_IF_ERROR(ZSTD_CCtx_getParameter(zc, ZSTD_c_nbWorkers, &nbWorkers), ""); ++ RETURN_ERROR_IF(nbWorkers != 0, parameter_unsupported, "nbWorkers != 0"); ++ } + + RETURN_ERROR_IF(dst == NULL, memory_allocation, "NULL pointer!"); + +@@ -2880,8 +3389,12 @@ size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* outSeqs, + seqCollector.maxSequences = outSeqsSize; + zc->seqCollector = seqCollector; + +- ZSTD_compress2(zc, dst, dstCapacity, src, srcSize); +- ZSTD_customFree(dst, 
ZSTD_defaultCMem); ++ { ++ const size_t ret = ZSTD_compress2(zc, dst, dstCapacity, src, srcSize); ++ ZSTD_customFree(dst, ZSTD_defaultCMem); ++ FORWARD_IF_ERROR(ret, "ZSTD_compress2 failed"); ++ } ++ assert(zc->seqCollector.seqIndex <= ZSTD_sequenceBound(srcSize)); + return zc->seqCollector.seqIndex; + } + +@@ -2910,19 +3423,17 @@ static int ZSTD_isRLE(const BYTE* src, size_t length) { + const size_t unrollMask = unrollSize - 1; + const size_t prefixLength = length & unrollMask; + size_t i; +- size_t u; + if (length == 1) return 1; + /* Check if prefix is RLE first before using unrolled loop */ + if (prefixLength && ZSTD_count(ip+1, ip, ip+prefixLength) != prefixLength-1) { + return 0; + } + for (i = prefixLength; i != length; i += unrollSize) { ++ size_t u; + for (u = 0; u < unrollSize; u += sizeof(size_t)) { + if (MEM_readST(ip + i + u) != valueST) { + return 0; +- } +- } +- } ++ } } } + return 1; + } + +@@ -2938,7 +3449,8 @@ static int ZSTD_maybeRLE(seqStore_t const* seqStore) + return nbSeqs < 4 && nbLits < 10; + } + +-static void ZSTD_blockState_confirmRepcodesAndEntropyTables(ZSTD_blockState_t* const bs) ++static void ++ZSTD_blockState_confirmRepcodesAndEntropyTables(ZSTD_blockState_t* const bs) + { + ZSTD_compressedBlockState_t* const tmp = bs->prevCBlock; + bs->prevCBlock = bs->nextCBlock; +@@ -2946,7 +3458,9 @@ static void ZSTD_blockState_confirmRepcodesAndEntropyTables(ZSTD_blockState_t* c + } + + /* Writes the block header */ +-static void writeBlockHeader(void* op, size_t cSize, size_t blockSize, U32 lastBlock) { ++static void ++writeBlockHeader(void* op, size_t cSize, size_t blockSize, U32 lastBlock) ++{ + U32 const cBlockHeader = cSize == 1 ? 
+ lastBlock + (((U32)bt_rle)<<1) + (U32)(blockSize << 3) : + lastBlock + (((U32)bt_compressed)<<1) + (U32)(cSize << 3); +@@ -2959,13 +3473,16 @@ static void writeBlockHeader(void* op, size_t cSize, size_t blockSize, U32 lastB + * Stores literals block type (raw, rle, compressed, repeat) and + * huffman description table to hufMetadata. + * Requires ENTROPY_WORKSPACE_SIZE workspace +- * @return : size of huffman description table or error code */ +-static size_t ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSize, +- const ZSTD_hufCTables_t* prevHuf, +- ZSTD_hufCTables_t* nextHuf, +- ZSTD_hufCTablesMetadata_t* hufMetadata, +- const int literalsCompressionIsDisabled, +- void* workspace, size_t wkspSize) ++ * @return : size of huffman description table, or an error code ++ */ ++static size_t ++ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSize, ++ const ZSTD_hufCTables_t* prevHuf, ++ ZSTD_hufCTables_t* nextHuf, ++ ZSTD_hufCTablesMetadata_t* hufMetadata, ++ const int literalsCompressionIsDisabled, ++ void* workspace, size_t wkspSize, ++ int hufFlags) + { + BYTE* const wkspStart = (BYTE*)workspace; + BYTE* const wkspEnd = wkspStart + wkspSize; +@@ -2973,9 +3490,9 @@ static size_t ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSi + unsigned* const countWksp = (unsigned*)workspace; + const size_t countWkspSize = (HUF_SYMBOLVALUE_MAX + 1) * sizeof(unsigned); + BYTE* const nodeWksp = countWkspStart + countWkspSize; +- const size_t nodeWkspSize = wkspEnd-nodeWksp; ++ const size_t nodeWkspSize = (size_t)(wkspEnd - nodeWksp); + unsigned maxSymbolValue = HUF_SYMBOLVALUE_MAX; +- unsigned huffLog = HUF_TABLELOG_DEFAULT; ++ unsigned huffLog = LitHufLog; + HUF_repeat repeat = prevHuf->repeatMode; + DEBUGLOG(5, "ZSTD_buildBlockEntropyStats_literals (srcSize=%zu)", srcSize); + +@@ -2990,73 +3507,77 @@ static size_t ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSi + + /* small ? 
don't even attempt compression (speed opt) */ + #ifndef COMPRESS_LITERALS_SIZE_MIN +-#define COMPRESS_LITERALS_SIZE_MIN 63 ++# define COMPRESS_LITERALS_SIZE_MIN 63 /* heuristic */ + #endif + { size_t const minLitSize = (prevHuf->repeatMode == HUF_repeat_valid) ? 6 : COMPRESS_LITERALS_SIZE_MIN; + if (srcSize <= minLitSize) { + DEBUGLOG(5, "set_basic - too small"); + hufMetadata->hType = set_basic; + return 0; +- } +- } ++ } } + + /* Scan input and build symbol stats */ +- { size_t const largest = HIST_count_wksp (countWksp, &maxSymbolValue, (const BYTE*)src, srcSize, workspace, wkspSize); ++ { size_t const largest = ++ HIST_count_wksp (countWksp, &maxSymbolValue, ++ (const BYTE*)src, srcSize, ++ workspace, wkspSize); + FORWARD_IF_ERROR(largest, "HIST_count_wksp failed"); + if (largest == srcSize) { ++ /* only one literal symbol */ + DEBUGLOG(5, "set_rle"); + hufMetadata->hType = set_rle; + return 0; + } + if (largest <= (srcSize >> 7)+4) { ++ /* heuristic: likely not compressible */ + DEBUGLOG(5, "set_basic - no gain"); + hufMetadata->hType = set_basic; + return 0; +- } +- } ++ } } + + /* Validate the previous Huffman table */ +- if (repeat == HUF_repeat_check && !HUF_validateCTable((HUF_CElt const*)prevHuf->CTable, countWksp, maxSymbolValue)) { ++ if (repeat == HUF_repeat_check ++ && !HUF_validateCTable((HUF_CElt const*)prevHuf->CTable, countWksp, maxSymbolValue)) { + repeat = HUF_repeat_none; + } + + /* Build Huffman Tree */ + ZSTD_memset(nextHuf->CTable, 0, sizeof(nextHuf->CTable)); +- huffLog = HUF_optimalTableLog(huffLog, srcSize, maxSymbolValue); ++ huffLog = HUF_optimalTableLog(huffLog, srcSize, maxSymbolValue, nodeWksp, nodeWkspSize, nextHuf->CTable, countWksp, hufFlags); ++ assert(huffLog <= LitHufLog); + { size_t const maxBits = HUF_buildCTable_wksp((HUF_CElt*)nextHuf->CTable, countWksp, + maxSymbolValue, huffLog, + nodeWksp, nodeWkspSize); + FORWARD_IF_ERROR(maxBits, "HUF_buildCTable_wksp"); + huffLog = (U32)maxBits; +- { /* Build and write the CTable */ 
+- size_t const newCSize = HUF_estimateCompressedSize( +- (HUF_CElt*)nextHuf->CTable, countWksp, maxSymbolValue); +- size_t const hSize = HUF_writeCTable_wksp( +- hufMetadata->hufDesBuffer, sizeof(hufMetadata->hufDesBuffer), +- (HUF_CElt*)nextHuf->CTable, maxSymbolValue, huffLog, +- nodeWksp, nodeWkspSize); +- /* Check against repeating the previous CTable */ +- if (repeat != HUF_repeat_none) { +- size_t const oldCSize = HUF_estimateCompressedSize( +- (HUF_CElt const*)prevHuf->CTable, countWksp, maxSymbolValue); +- if (oldCSize < srcSize && (oldCSize <= hSize + newCSize || hSize + 12 >= srcSize)) { +- DEBUGLOG(5, "set_repeat - smaller"); +- ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); +- hufMetadata->hType = set_repeat; +- return 0; +- } +- } +- if (newCSize + hSize >= srcSize) { +- DEBUGLOG(5, "set_basic - no gains"); ++ } ++ { /* Build and write the CTable */ ++ size_t const newCSize = HUF_estimateCompressedSize( ++ (HUF_CElt*)nextHuf->CTable, countWksp, maxSymbolValue); ++ size_t const hSize = HUF_writeCTable_wksp( ++ hufMetadata->hufDesBuffer, sizeof(hufMetadata->hufDesBuffer), ++ (HUF_CElt*)nextHuf->CTable, maxSymbolValue, huffLog, ++ nodeWksp, nodeWkspSize); ++ /* Check against repeating the previous CTable */ ++ if (repeat != HUF_repeat_none) { ++ size_t const oldCSize = HUF_estimateCompressedSize( ++ (HUF_CElt const*)prevHuf->CTable, countWksp, maxSymbolValue); ++ if (oldCSize < srcSize && (oldCSize <= hSize + newCSize || hSize + 12 >= srcSize)) { ++ DEBUGLOG(5, "set_repeat - smaller"); + ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); +- hufMetadata->hType = set_basic; ++ hufMetadata->hType = set_repeat; + return 0; +- } +- DEBUGLOG(5, "set_compressed (hSize=%u)", (U32)hSize); +- hufMetadata->hType = set_compressed; +- nextHuf->repeatMode = HUF_repeat_check; +- return hSize; ++ } } ++ if (newCSize + hSize >= srcSize) { ++ DEBUGLOG(5, "set_basic - no gains"); ++ ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); ++ hufMetadata->hType = set_basic; ++ 
return 0; + } ++ DEBUGLOG(5, "set_compressed (hSize=%u)", (U32)hSize); ++ hufMetadata->hType = set_compressed; ++ nextHuf->repeatMode = HUF_repeat_check; ++ return hSize; + } + } + +@@ -3066,8 +3587,9 @@ static size_t ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSi + * and updates nextEntropy to the appropriate repeatMode. + */ + static ZSTD_symbolEncodingTypeStats_t +-ZSTD_buildDummySequencesStatistics(ZSTD_fseCTables_t* nextEntropy) { +- ZSTD_symbolEncodingTypeStats_t stats = {set_basic, set_basic, set_basic, 0, 0}; ++ZSTD_buildDummySequencesStatistics(ZSTD_fseCTables_t* nextEntropy) ++{ ++ ZSTD_symbolEncodingTypeStats_t stats = {set_basic, set_basic, set_basic, 0, 0, 0}; + nextEntropy->litlength_repeatMode = FSE_repeat_none; + nextEntropy->offcode_repeatMode = FSE_repeat_none; + nextEntropy->matchlength_repeatMode = FSE_repeat_none; +@@ -3078,16 +3600,18 @@ ZSTD_buildDummySequencesStatistics(ZSTD_fseCTables_t* nextEntropy) { + * Builds entropy for the sequences. + * Stores symbol compression modes and fse table to fseMetadata. + * Requires ENTROPY_WORKSPACE_SIZE wksp. 
+- * @return : size of fse tables or error code */ +-static size_t ZSTD_buildBlockEntropyStats_sequences(seqStore_t* seqStorePtr, +- const ZSTD_fseCTables_t* prevEntropy, +- ZSTD_fseCTables_t* nextEntropy, +- const ZSTD_CCtx_params* cctxParams, +- ZSTD_fseCTablesMetadata_t* fseMetadata, +- void* workspace, size_t wkspSize) ++ * @return : size of fse tables or error code */ ++static size_t ++ZSTD_buildBlockEntropyStats_sequences( ++ const seqStore_t* seqStorePtr, ++ const ZSTD_fseCTables_t* prevEntropy, ++ ZSTD_fseCTables_t* nextEntropy, ++ const ZSTD_CCtx_params* cctxParams, ++ ZSTD_fseCTablesMetadata_t* fseMetadata, ++ void* workspace, size_t wkspSize) + { + ZSTD_strategy const strategy = cctxParams->cParams.strategy; +- size_t const nbSeq = seqStorePtr->sequences - seqStorePtr->sequencesStart; ++ size_t const nbSeq = (size_t)(seqStorePtr->sequences - seqStorePtr->sequencesStart); + BYTE* const ostart = fseMetadata->fseTablesBuffer; + BYTE* const oend = ostart + sizeof(fseMetadata->fseTablesBuffer); + BYTE* op = ostart; +@@ -3114,23 +3638,28 @@ static size_t ZSTD_buildBlockEntropyStats_sequences(seqStore_t* seqStorePtr, + /* ZSTD_buildBlockEntropyStats() : + * Builds entropy for the block. 
+ * Requires workspace size ENTROPY_WORKSPACE_SIZE +- * +- * @return : 0 on success or error code ++ * @return : 0 on success, or an error code ++ * Note : also employed in superblock + */ +-size_t ZSTD_buildBlockEntropyStats(seqStore_t* seqStorePtr, +- const ZSTD_entropyCTables_t* prevEntropy, +- ZSTD_entropyCTables_t* nextEntropy, +- const ZSTD_CCtx_params* cctxParams, +- ZSTD_entropyCTablesMetadata_t* entropyMetadata, +- void* workspace, size_t wkspSize) +-{ +- size_t const litSize = seqStorePtr->lit - seqStorePtr->litStart; ++size_t ZSTD_buildBlockEntropyStats( ++ const seqStore_t* seqStorePtr, ++ const ZSTD_entropyCTables_t* prevEntropy, ++ ZSTD_entropyCTables_t* nextEntropy, ++ const ZSTD_CCtx_params* cctxParams, ++ ZSTD_entropyCTablesMetadata_t* entropyMetadata, ++ void* workspace, size_t wkspSize) ++{ ++ size_t const litSize = (size_t)(seqStorePtr->lit - seqStorePtr->litStart); ++ int const huf_useOptDepth = (cctxParams->cParams.strategy >= HUF_OPTIMAL_DEPTH_THRESHOLD); ++ int const hufFlags = huf_useOptDepth ? 
HUF_flags_optimalDepth : 0; ++ + entropyMetadata->hufMetadata.hufDesSize = + ZSTD_buildBlockEntropyStats_literals(seqStorePtr->litStart, litSize, + &prevEntropy->huf, &nextEntropy->huf, + &entropyMetadata->hufMetadata, + ZSTD_literalsCompressionIsDisabled(cctxParams), +- workspace, wkspSize); ++ workspace, wkspSize, hufFlags); ++ + FORWARD_IF_ERROR(entropyMetadata->hufMetadata.hufDesSize, "ZSTD_buildBlockEntropyStats_literals failed"); + entropyMetadata->fseMetadata.fseTablesSize = + ZSTD_buildBlockEntropyStats_sequences(seqStorePtr, +@@ -3143,11 +3672,12 @@ size_t ZSTD_buildBlockEntropyStats(seqStore_t* seqStorePtr, + } + + /* Returns the size estimate for the literals section (header + content) of a block */ +-static size_t ZSTD_estimateBlockSize_literal(const BYTE* literals, size_t litSize, +- const ZSTD_hufCTables_t* huf, +- const ZSTD_hufCTablesMetadata_t* hufMetadata, +- void* workspace, size_t wkspSize, +- int writeEntropy) ++static size_t ++ZSTD_estimateBlockSize_literal(const BYTE* literals, size_t litSize, ++ const ZSTD_hufCTables_t* huf, ++ const ZSTD_hufCTablesMetadata_t* hufMetadata, ++ void* workspace, size_t wkspSize, ++ int writeEntropy) + { + unsigned* const countWksp = (unsigned*)workspace; + unsigned maxSymbolValue = HUF_SYMBOLVALUE_MAX; +@@ -3169,12 +3699,13 @@ static size_t ZSTD_estimateBlockSize_literal(const BYTE* literals, size_t litSiz + } + + /* Returns the size estimate for the FSE-compressed symbols (of, ml, ll) of a block */ +-static size_t ZSTD_estimateBlockSize_symbolType(symbolEncodingType_e type, +- const BYTE* codeTable, size_t nbSeq, unsigned maxCode, +- const FSE_CTable* fseCTable, +- const U8* additionalBits, +- short const* defaultNorm, U32 defaultNormLog, U32 defaultMax, +- void* workspace, size_t wkspSize) ++static size_t ++ZSTD_estimateBlockSize_symbolType(symbolEncodingType_e type, ++ const BYTE* codeTable, size_t nbSeq, unsigned maxCode, ++ const FSE_CTable* fseCTable, ++ const U8* additionalBits, ++ short const* 
defaultNorm, U32 defaultNormLog, U32 defaultMax, ++ void* workspace, size_t wkspSize) + { + unsigned* const countWksp = (unsigned*)workspace; + const BYTE* ctp = codeTable; +@@ -3206,99 +3737,107 @@ static size_t ZSTD_estimateBlockSize_symbolType(symbolEncodingType_e type, + } + + /* Returns the size estimate for the sequences section (header + content) of a block */ +-static size_t ZSTD_estimateBlockSize_sequences(const BYTE* ofCodeTable, +- const BYTE* llCodeTable, +- const BYTE* mlCodeTable, +- size_t nbSeq, +- const ZSTD_fseCTables_t* fseTables, +- const ZSTD_fseCTablesMetadata_t* fseMetadata, +- void* workspace, size_t wkspSize, +- int writeEntropy) ++static size_t ++ZSTD_estimateBlockSize_sequences(const BYTE* ofCodeTable, ++ const BYTE* llCodeTable, ++ const BYTE* mlCodeTable, ++ size_t nbSeq, ++ const ZSTD_fseCTables_t* fseTables, ++ const ZSTD_fseCTablesMetadata_t* fseMetadata, ++ void* workspace, size_t wkspSize, ++ int writeEntropy) + { + size_t sequencesSectionHeaderSize = 1 /* seqHead */ + 1 /* min seqSize size */ + (nbSeq >= 128) + (nbSeq >= LONGNBSEQ); + size_t cSeqSizeEstimate = 0; + cSeqSizeEstimate += ZSTD_estimateBlockSize_symbolType(fseMetadata->ofType, ofCodeTable, nbSeq, MaxOff, +- fseTables->offcodeCTable, NULL, +- OF_defaultNorm, OF_defaultNormLog, DefaultMaxOff, +- workspace, wkspSize); ++ fseTables->offcodeCTable, NULL, ++ OF_defaultNorm, OF_defaultNormLog, DefaultMaxOff, ++ workspace, wkspSize); + cSeqSizeEstimate += ZSTD_estimateBlockSize_symbolType(fseMetadata->llType, llCodeTable, nbSeq, MaxLL, +- fseTables->litlengthCTable, LL_bits, +- LL_defaultNorm, LL_defaultNormLog, MaxLL, +- workspace, wkspSize); ++ fseTables->litlengthCTable, LL_bits, ++ LL_defaultNorm, LL_defaultNormLog, MaxLL, ++ workspace, wkspSize); + cSeqSizeEstimate += ZSTD_estimateBlockSize_symbolType(fseMetadata->mlType, mlCodeTable, nbSeq, MaxML, +- fseTables->matchlengthCTable, ML_bits, +- ML_defaultNorm, ML_defaultNormLog, MaxML, +- workspace, wkspSize); ++ 
fseTables->matchlengthCTable, ML_bits, ++ ML_defaultNorm, ML_defaultNormLog, MaxML, ++ workspace, wkspSize); + if (writeEntropy) cSeqSizeEstimate += fseMetadata->fseTablesSize; + return cSeqSizeEstimate + sequencesSectionHeaderSize; + } + + /* Returns the size estimate for a given stream of literals, of, ll, ml */ +-static size_t ZSTD_estimateBlockSize(const BYTE* literals, size_t litSize, +- const BYTE* ofCodeTable, +- const BYTE* llCodeTable, +- const BYTE* mlCodeTable, +- size_t nbSeq, +- const ZSTD_entropyCTables_t* entropy, +- const ZSTD_entropyCTablesMetadata_t* entropyMetadata, +- void* workspace, size_t wkspSize, +- int writeLitEntropy, int writeSeqEntropy) { ++static size_t ++ZSTD_estimateBlockSize(const BYTE* literals, size_t litSize, ++ const BYTE* ofCodeTable, ++ const BYTE* llCodeTable, ++ const BYTE* mlCodeTable, ++ size_t nbSeq, ++ const ZSTD_entropyCTables_t* entropy, ++ const ZSTD_entropyCTablesMetadata_t* entropyMetadata, ++ void* workspace, size_t wkspSize, ++ int writeLitEntropy, int writeSeqEntropy) ++{ + size_t const literalsSize = ZSTD_estimateBlockSize_literal(literals, litSize, +- &entropy->huf, &entropyMetadata->hufMetadata, +- workspace, wkspSize, writeLitEntropy); ++ &entropy->huf, &entropyMetadata->hufMetadata, ++ workspace, wkspSize, writeLitEntropy); + size_t const seqSize = ZSTD_estimateBlockSize_sequences(ofCodeTable, llCodeTable, mlCodeTable, +- nbSeq, &entropy->fse, &entropyMetadata->fseMetadata, +- workspace, wkspSize, writeSeqEntropy); ++ nbSeq, &entropy->fse, &entropyMetadata->fseMetadata, ++ workspace, wkspSize, writeSeqEntropy); + return seqSize + literalsSize + ZSTD_blockHeaderSize; + } + + /* Builds entropy statistics and uses them for blocksize estimation. + * +- * Returns the estimated compressed size of the seqStore, or a zstd error. ++ * @return: estimated compressed size of the seqStore, or a zstd error. 
+ */ +-static size_t ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(seqStore_t* seqStore, ZSTD_CCtx* zc) { +- ZSTD_entropyCTablesMetadata_t* entropyMetadata = &zc->blockSplitCtx.entropyMetadata; ++static size_t ++ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(seqStore_t* seqStore, ZSTD_CCtx* zc) ++{ ++ ZSTD_entropyCTablesMetadata_t* const entropyMetadata = &zc->blockSplitCtx.entropyMetadata; + DEBUGLOG(6, "ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize()"); + FORWARD_IF_ERROR(ZSTD_buildBlockEntropyStats(seqStore, + &zc->blockState.prevCBlock->entropy, + &zc->blockState.nextCBlock->entropy, + &zc->appliedParams, + entropyMetadata, +- zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE /* statically allocated in resetCCtx */), ""); +- return ZSTD_estimateBlockSize(seqStore->litStart, (size_t)(seqStore->lit - seqStore->litStart), ++ zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE), ""); ++ return ZSTD_estimateBlockSize( ++ seqStore->litStart, (size_t)(seqStore->lit - seqStore->litStart), + seqStore->ofCode, seqStore->llCode, seqStore->mlCode, + (size_t)(seqStore->sequences - seqStore->sequencesStart), +- &zc->blockState.nextCBlock->entropy, entropyMetadata, zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE, ++ &zc->blockState.nextCBlock->entropy, ++ entropyMetadata, ++ zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE, + (int)(entropyMetadata->hufMetadata.hType == set_compressed), 1); + } + + /* Returns literals bytes represented in a seqStore */ +-static size_t ZSTD_countSeqStoreLiteralsBytes(const seqStore_t* const seqStore) { ++static size_t ZSTD_countSeqStoreLiteralsBytes(const seqStore_t* const seqStore) ++{ + size_t literalsBytes = 0; +- size_t const nbSeqs = seqStore->sequences - seqStore->sequencesStart; ++ size_t const nbSeqs = (size_t)(seqStore->sequences - seqStore->sequencesStart); + size_t i; + for (i = 0; i < nbSeqs; ++i) { +- seqDef seq = seqStore->sequencesStart[i]; ++ seqDef const seq = seqStore->sequencesStart[i]; + literalsBytes += seq.litLength; + if (i == 
seqStore->longLengthPos && seqStore->longLengthType == ZSTD_llt_literalLength) { + literalsBytes += 0x10000; +- } +- } ++ } } + return literalsBytes; + } + + /* Returns match bytes represented in a seqStore */ +-static size_t ZSTD_countSeqStoreMatchBytes(const seqStore_t* const seqStore) { ++static size_t ZSTD_countSeqStoreMatchBytes(const seqStore_t* const seqStore) ++{ + size_t matchBytes = 0; +- size_t const nbSeqs = seqStore->sequences - seqStore->sequencesStart; ++ size_t const nbSeqs = (size_t)(seqStore->sequences - seqStore->sequencesStart); + size_t i; + for (i = 0; i < nbSeqs; ++i) { + seqDef seq = seqStore->sequencesStart[i]; + matchBytes += seq.mlBase + MINMATCH; + if (i == seqStore->longLengthPos && seqStore->longLengthType == ZSTD_llt_matchLength) { + matchBytes += 0x10000; +- } +- } ++ } } + return matchBytes; + } + +@@ -3307,15 +3846,12 @@ static size_t ZSTD_countSeqStoreMatchBytes(const seqStore_t* const seqStore) { + */ + static void ZSTD_deriveSeqStoreChunk(seqStore_t* resultSeqStore, + const seqStore_t* originalSeqStore, +- size_t startIdx, size_t endIdx) { +- BYTE* const litEnd = originalSeqStore->lit; +- size_t literalsBytes; +- size_t literalsBytesPreceding = 0; +- ++ size_t startIdx, size_t endIdx) ++{ + *resultSeqStore = *originalSeqStore; + if (startIdx > 0) { + resultSeqStore->sequences = originalSeqStore->sequencesStart + startIdx; +- literalsBytesPreceding = ZSTD_countSeqStoreLiteralsBytes(resultSeqStore); ++ resultSeqStore->litStart += ZSTD_countSeqStoreLiteralsBytes(resultSeqStore); + } + + /* Move longLengthPos into the correct position if necessary */ +@@ -3328,13 +3864,12 @@ static void ZSTD_deriveSeqStoreChunk(seqStore_t* resultSeqStore, + } + resultSeqStore->sequencesStart = originalSeqStore->sequencesStart + startIdx; + resultSeqStore->sequences = originalSeqStore->sequencesStart + endIdx; +- literalsBytes = ZSTD_countSeqStoreLiteralsBytes(resultSeqStore); +- resultSeqStore->litStart += literalsBytesPreceding; + if (endIdx == 
(size_t)(originalSeqStore->sequences - originalSeqStore->sequencesStart)) { + /* This accounts for possible last literals if the derived chunk reaches the end of the block */ +- resultSeqStore->lit = litEnd; ++ assert(resultSeqStore->lit == originalSeqStore->lit); + } else { +- resultSeqStore->lit = resultSeqStore->litStart+literalsBytes; ++ size_t const literalsBytes = ZSTD_countSeqStoreLiteralsBytes(resultSeqStore); ++ resultSeqStore->lit = resultSeqStore->litStart + literalsBytes; + } + resultSeqStore->llCode += startIdx; + resultSeqStore->mlCode += startIdx; +@@ -3342,20 +3877,26 @@ static void ZSTD_deriveSeqStoreChunk(seqStore_t* resultSeqStore, + } + + /* +- * Returns the raw offset represented by the combination of offCode, ll0, and repcode history. +- * offCode must represent a repcode in the numeric representation of ZSTD_storeSeq(). ++ * Returns the raw offset represented by the combination of offBase, ll0, and repcode history. ++ * offBase must represent a repcode in the numeric representation of ZSTD_storeSeq(). + */ + static U32 +-ZSTD_resolveRepcodeToRawOffset(const U32 rep[ZSTD_REP_NUM], const U32 offCode, const U32 ll0) +-{ +- U32 const adjustedOffCode = STORED_REPCODE(offCode) - 1 + ll0; /* [ 0 - 3 ] */ +- assert(STORED_IS_REPCODE(offCode)); +- if (adjustedOffCode == ZSTD_REP_NUM) { +- /* litlength == 0 and offCode == 2 implies selection of first repcode - 1 */ +- assert(rep[0] > 0); ++ZSTD_resolveRepcodeToRawOffset(const U32 rep[ZSTD_REP_NUM], const U32 offBase, const U32 ll0) ++{ ++ U32 const adjustedRepCode = OFFBASE_TO_REPCODE(offBase) - 1 + ll0; /* [ 0 - 3 ] */ ++ assert(OFFBASE_IS_REPCODE(offBase)); ++ if (adjustedRepCode == ZSTD_REP_NUM) { ++ assert(ll0); ++ /* litlength == 0 and offCode == 2 implies selection of first repcode - 1 ++ * This is only valid if it results in a valid offset value, aka > 0. ++ * Note : it may happen that `rep[0]==1` in exceptional circumstances. 
++ * In which case this function will return 0, which is an invalid offset. ++ * It's not an issue though, since this value will be ++ * compared and discarded within ZSTD_seqStore_resolveOffCodes(). ++ */ + return rep[0] - 1; + } +- return rep[adjustedOffCode]; ++ return rep[adjustedRepCode]; + } + + /* +@@ -3371,30 +3912,33 @@ ZSTD_resolveRepcodeToRawOffset(const U32 rep[ZSTD_REP_NUM], const U32 offCode, c + * 1-3 : repcode 1-3 + * 4+ : real_offset+3 + */ +-static void ZSTD_seqStore_resolveOffCodes(repcodes_t* const dRepcodes, repcodes_t* const cRepcodes, +- seqStore_t* const seqStore, U32 const nbSeq) { ++static void ++ZSTD_seqStore_resolveOffCodes(repcodes_t* const dRepcodes, repcodes_t* const cRepcodes, ++ const seqStore_t* const seqStore, U32 const nbSeq) ++{ + U32 idx = 0; ++ U32 const longLitLenIdx = seqStore->longLengthType == ZSTD_llt_literalLength ? seqStore->longLengthPos : nbSeq; + for (; idx < nbSeq; ++idx) { + seqDef* const seq = seqStore->sequencesStart + idx; +- U32 const ll0 = (seq->litLength == 0); +- U32 const offCode = OFFBASE_TO_STORED(seq->offBase); +- assert(seq->offBase > 0); +- if (STORED_IS_REPCODE(offCode)) { +- U32 const dRawOffset = ZSTD_resolveRepcodeToRawOffset(dRepcodes->rep, offCode, ll0); +- U32 const cRawOffset = ZSTD_resolveRepcodeToRawOffset(cRepcodes->rep, offCode, ll0); ++ U32 const ll0 = (seq->litLength == 0) && (idx != longLitLenIdx); ++ U32 const offBase = seq->offBase; ++ assert(offBase > 0); ++ if (OFFBASE_IS_REPCODE(offBase)) { ++ U32 const dRawOffset = ZSTD_resolveRepcodeToRawOffset(dRepcodes->rep, offBase, ll0); ++ U32 const cRawOffset = ZSTD_resolveRepcodeToRawOffset(cRepcodes->rep, offBase, ll0); + /* Adjust simulated decompression repcode history if we come across a mismatch. Replace + * the repcode with the offset it actually references, determined by the compression + * repcode history. 
+ */ + if (dRawOffset != cRawOffset) { +- seq->offBase = cRawOffset + ZSTD_REP_NUM; ++ seq->offBase = OFFSET_TO_OFFBASE(cRawOffset); + } + } + /* Compression repcode history is always updated with values directly from the unmodified seqStore. + * Decompression repcode history may use modified seq->offset value taken from compression repcode history. + */ +- ZSTD_updateRep(dRepcodes->rep, OFFBASE_TO_STORED(seq->offBase), ll0); +- ZSTD_updateRep(cRepcodes->rep, offCode, ll0); ++ ZSTD_updateRep(dRepcodes->rep, seq->offBase, ll0); ++ ZSTD_updateRep(cRepcodes->rep, offBase, ll0); + } + } + +@@ -3404,10 +3948,11 @@ static void ZSTD_seqStore_resolveOffCodes(repcodes_t* const dRepcodes, repcodes_ + * Returns the total size of that block (including header) or a ZSTD error code. + */ + static size_t +-ZSTD_compressSeqStore_singleBlock(ZSTD_CCtx* zc, seqStore_t* const seqStore, ++ZSTD_compressSeqStore_singleBlock(ZSTD_CCtx* zc, ++ const seqStore_t* const seqStore, + repcodes_t* const dRep, repcodes_t* const cRep, + void* dst, size_t dstCapacity, +- const void* src, size_t srcSize, ++ const void* src, size_t srcSize, + U32 lastBlock, U32 isPartition) + { + const U32 rleMaxLength = 25; +@@ -3442,8 +3987,9 @@ ZSTD_compressSeqStore_singleBlock(ZSTD_CCtx* zc, seqStore_t* const seqStore, + cSeqsSize = 1; + } + ++ /* Sequence collection not supported when block splitting */ + if (zc->seqCollector.collectSequences) { +- ZSTD_copyBlockSequences(zc); ++ FORWARD_IF_ERROR(ZSTD_copyBlockSequences(&zc->seqCollector, seqStore, dRepOriginal.rep), "copyBlockSequences failed"); + ZSTD_blockState_confirmRepcodesAndEntropyTables(&zc->blockState); + return 0; + } +@@ -3481,45 +4027,49 @@ typedef struct { + + /* Helper function to perform the recursive search for block splits. + * Estimates the cost of seqStore prior to split, and estimates the cost of splitting the sequences in half. +- * If advantageous to split, then we recurse down the two sub-blocks. 
If not, or if an error occurred in estimation, then +- * we do not recurse. ++ * If advantageous to split, then we recurse down the two sub-blocks. ++ * If not, or if an error occurred in estimation, then we do not recurse. + * +- * Note: The recursion depth is capped by a heuristic minimum number of sequences, defined by MIN_SEQUENCES_BLOCK_SPLITTING. ++ * Note: The recursion depth is capped by a heuristic minimum number of sequences, ++ * defined by MIN_SEQUENCES_BLOCK_SPLITTING. + * In theory, this means the absolute largest recursion depth is 10 == log2(maxNbSeqInBlock/MIN_SEQUENCES_BLOCK_SPLITTING). + * In practice, recursion depth usually doesn't go beyond 4. + * +- * Furthermore, the number of splits is capped by ZSTD_MAX_NB_BLOCK_SPLITS. At ZSTD_MAX_NB_BLOCK_SPLITS == 196 with the current existing blockSize ++ * Furthermore, the number of splits is capped by ZSTD_MAX_NB_BLOCK_SPLITS. ++ * At ZSTD_MAX_NB_BLOCK_SPLITS == 196 with the current existing blockSize + * maximum of 128 KB, this value is actually impossible to reach. 
+ */ + static void + ZSTD_deriveBlockSplitsHelper(seqStoreSplits* splits, size_t startIdx, size_t endIdx, + ZSTD_CCtx* zc, const seqStore_t* origSeqStore) + { +- seqStore_t* fullSeqStoreChunk = &zc->blockSplitCtx.fullSeqStoreChunk; +- seqStore_t* firstHalfSeqStore = &zc->blockSplitCtx.firstHalfSeqStore; +- seqStore_t* secondHalfSeqStore = &zc->blockSplitCtx.secondHalfSeqStore; ++ seqStore_t* const fullSeqStoreChunk = &zc->blockSplitCtx.fullSeqStoreChunk; ++ seqStore_t* const firstHalfSeqStore = &zc->blockSplitCtx.firstHalfSeqStore; ++ seqStore_t* const secondHalfSeqStore = &zc->blockSplitCtx.secondHalfSeqStore; + size_t estimatedOriginalSize; + size_t estimatedFirstHalfSize; + size_t estimatedSecondHalfSize; + size_t midIdx = (startIdx + endIdx)/2; + ++ DEBUGLOG(5, "ZSTD_deriveBlockSplitsHelper: startIdx=%zu endIdx=%zu", startIdx, endIdx); ++ assert(endIdx >= startIdx); + if (endIdx - startIdx < MIN_SEQUENCES_BLOCK_SPLITTING || splits->idx >= ZSTD_MAX_NB_BLOCK_SPLITS) { +- DEBUGLOG(6, "ZSTD_deriveBlockSplitsHelper: Too few sequences"); ++ DEBUGLOG(6, "ZSTD_deriveBlockSplitsHelper: Too few sequences (%zu)", endIdx - startIdx); + return; + } +- DEBUGLOG(4, "ZSTD_deriveBlockSplitsHelper: startIdx=%zu endIdx=%zu", startIdx, endIdx); + ZSTD_deriveSeqStoreChunk(fullSeqStoreChunk, origSeqStore, startIdx, endIdx); + ZSTD_deriveSeqStoreChunk(firstHalfSeqStore, origSeqStore, startIdx, midIdx); + ZSTD_deriveSeqStoreChunk(secondHalfSeqStore, origSeqStore, midIdx, endIdx); + estimatedOriginalSize = ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(fullSeqStoreChunk, zc); + estimatedFirstHalfSize = ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(firstHalfSeqStore, zc); + estimatedSecondHalfSize = ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(secondHalfSeqStore, zc); +- DEBUGLOG(4, "Estimated original block size: %zu -- First half split: %zu -- Second half split: %zu", ++ DEBUGLOG(5, "Estimated original block size: %zu -- First half split: %zu -- Second half split: %zu", 
+ estimatedOriginalSize, estimatedFirstHalfSize, estimatedSecondHalfSize); + if (ZSTD_isError(estimatedOriginalSize) || ZSTD_isError(estimatedFirstHalfSize) || ZSTD_isError(estimatedSecondHalfSize)) { + return; + } + if (estimatedFirstHalfSize + estimatedSecondHalfSize < estimatedOriginalSize) { ++ DEBUGLOG(5, "split decided at seqNb:%zu", midIdx); + ZSTD_deriveBlockSplitsHelper(splits, startIdx, midIdx, zc, origSeqStore); + splits->splitLocations[splits->idx] = (U32)midIdx; + splits->idx++; +@@ -3527,14 +4077,18 @@ ZSTD_deriveBlockSplitsHelper(seqStoreSplits* splits, size_t startIdx, size_t end + } + } + +-/* Base recursive function. Populates a table with intra-block partition indices that can improve compression ratio. ++/* Base recursive function. ++ * Populates a table with intra-block partition indices that can improve compression ratio. + * +- * Returns the number of splits made (which equals the size of the partition table - 1). ++ * @return: number of splits made (which equals the size of the partition table - 1). + */ +-static size_t ZSTD_deriveBlockSplits(ZSTD_CCtx* zc, U32 partitions[], U32 nbSeq) { +- seqStoreSplits splits = {partitions, 0}; ++static size_t ZSTD_deriveBlockSplits(ZSTD_CCtx* zc, U32 partitions[], U32 nbSeq) ++{ ++ seqStoreSplits splits; ++ splits.splitLocations = partitions; ++ splits.idx = 0; + if (nbSeq <= 4) { +- DEBUGLOG(4, "ZSTD_deriveBlockSplits: Too few sequences to split"); ++ DEBUGLOG(5, "ZSTD_deriveBlockSplits: Too few sequences to split (%u <= 4)", nbSeq); + /* Refuse to try and split anything with less than 4 sequences */ + return 0; + } +@@ -3550,18 +4104,20 @@ static size_t ZSTD_deriveBlockSplits(ZSTD_CCtx* zc, U32 partitions[], U32 nbSeq) + * Returns combined size of all blocks (which includes headers), or a ZSTD error code. 
+ */ + static size_t +-ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapacity, +- const void* src, size_t blockSize, U32 lastBlock, U32 nbSeq) ++ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, ++ void* dst, size_t dstCapacity, ++ const void* src, size_t blockSize, ++ U32 lastBlock, U32 nbSeq) + { + size_t cSize = 0; + const BYTE* ip = (const BYTE*)src; + BYTE* op = (BYTE*)dst; + size_t i = 0; + size_t srcBytesTotal = 0; +- U32* partitions = zc->blockSplitCtx.partitions; /* size == ZSTD_MAX_NB_BLOCK_SPLITS */ +- seqStore_t* nextSeqStore = &zc->blockSplitCtx.nextSeqStore; +- seqStore_t* currSeqStore = &zc->blockSplitCtx.currSeqStore; +- size_t numSplits = ZSTD_deriveBlockSplits(zc, partitions, nbSeq); ++ U32* const partitions = zc->blockSplitCtx.partitions; /* size == ZSTD_MAX_NB_BLOCK_SPLITS */ ++ seqStore_t* const nextSeqStore = &zc->blockSplitCtx.nextSeqStore; ++ seqStore_t* const currSeqStore = &zc->blockSplitCtx.currSeqStore; ++ size_t const numSplits = ZSTD_deriveBlockSplits(zc, partitions, nbSeq); + + /* If a block is split and some partitions are emitted as RLE/uncompressed, then repcode history + * may become invalid. 
In order to reconcile potentially invalid repcodes, we keep track of two +@@ -3583,30 +4139,31 @@ ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapac + ZSTD_memcpy(cRep.rep, zc->blockState.prevCBlock->rep, sizeof(repcodes_t)); + ZSTD_memset(nextSeqStore, 0, sizeof(seqStore_t)); + +- DEBUGLOG(4, "ZSTD_compressBlock_splitBlock_internal (dstCapacity=%u, dictLimit=%u, nextToUpdate=%u)", ++ DEBUGLOG(5, "ZSTD_compressBlock_splitBlock_internal (dstCapacity=%u, dictLimit=%u, nextToUpdate=%u)", + (unsigned)dstCapacity, (unsigned)zc->blockState.matchState.window.dictLimit, + (unsigned)zc->blockState.matchState.nextToUpdate); + + if (numSplits == 0) { +- size_t cSizeSingleBlock = ZSTD_compressSeqStore_singleBlock(zc, &zc->seqStore, +- &dRep, &cRep, +- op, dstCapacity, +- ip, blockSize, +- lastBlock, 0 /* isPartition */); ++ size_t cSizeSingleBlock = ++ ZSTD_compressSeqStore_singleBlock(zc, &zc->seqStore, ++ &dRep, &cRep, ++ op, dstCapacity, ++ ip, blockSize, ++ lastBlock, 0 /* isPartition */); + FORWARD_IF_ERROR(cSizeSingleBlock, "Compressing single block from splitBlock_internal() failed!"); + DEBUGLOG(5, "ZSTD_compressBlock_splitBlock_internal: No splits"); +- assert(cSizeSingleBlock <= ZSTD_BLOCKSIZE_MAX + ZSTD_blockHeaderSize); ++ assert(zc->blockSize <= ZSTD_BLOCKSIZE_MAX); ++ assert(cSizeSingleBlock <= zc->blockSize + ZSTD_blockHeaderSize); + return cSizeSingleBlock; + } + + ZSTD_deriveSeqStoreChunk(currSeqStore, &zc->seqStore, 0, partitions[0]); + for (i = 0; i <= numSplits; ++i) { +- size_t srcBytes; + size_t cSizeChunk; + U32 const lastPartition = (i == numSplits); + U32 lastBlockEntireSrc = 0; + +- srcBytes = ZSTD_countSeqStoreLiteralsBytes(currSeqStore) + ZSTD_countSeqStoreMatchBytes(currSeqStore); ++ size_t srcBytes = ZSTD_countSeqStoreLiteralsBytes(currSeqStore) + ZSTD_countSeqStoreMatchBytes(currSeqStore); + srcBytesTotal += srcBytes; + if (lastPartition) { + /* This is the final partition, need to account for possible last literals 
*/ +@@ -3621,7 +4178,8 @@ ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapac + op, dstCapacity, + ip, srcBytes, + lastBlockEntireSrc, 1 /* isPartition */); +- DEBUGLOG(5, "Estimated size: %zu actual size: %zu", ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(currSeqStore, zc), cSizeChunk); ++ DEBUGLOG(5, "Estimated size: %zu vs %zu : actual size", ++ ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(currSeqStore, zc), cSizeChunk); + FORWARD_IF_ERROR(cSizeChunk, "Compressing chunk failed!"); + + ip += srcBytes; +@@ -3629,10 +4187,10 @@ ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapac + dstCapacity -= cSizeChunk; + cSize += cSizeChunk; + *currSeqStore = *nextSeqStore; +- assert(cSizeChunk <= ZSTD_BLOCKSIZE_MAX + ZSTD_blockHeaderSize); ++ assert(cSizeChunk <= zc->blockSize + ZSTD_blockHeaderSize); + } +- /* cRep and dRep may have diverged during the compression. If so, we use the dRep repcodes +- * for the next block. ++ /* cRep and dRep may have diverged during the compression. ++ * If so, we use the dRep repcodes for the next block. 
+ */ + ZSTD_memcpy(zc->blockState.prevCBlock->rep, dRep.rep, sizeof(repcodes_t)); + return cSize; +@@ -3643,8 +4201,6 @@ ZSTD_compressBlock_splitBlock(ZSTD_CCtx* zc, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize, U32 lastBlock) + { +- const BYTE* ip = (const BYTE*)src; +- BYTE* op = (BYTE*)dst; + U32 nbSeq; + size_t cSize; + DEBUGLOG(4, "ZSTD_compressBlock_splitBlock"); +@@ -3655,7 +4211,8 @@ ZSTD_compressBlock_splitBlock(ZSTD_CCtx* zc, + if (bss == ZSTDbss_noCompress) { + if (zc->blockState.prevCBlock->entropy.fse.offcode_repeatMode == FSE_repeat_valid) + zc->blockState.prevCBlock->entropy.fse.offcode_repeatMode = FSE_repeat_check; +- cSize = ZSTD_noCompressBlock(op, dstCapacity, ip, srcSize, lastBlock); ++ RETURN_ERROR_IF(zc->seqCollector.collectSequences, sequenceProducer_failed, "Uncompressible block"); ++ cSize = ZSTD_noCompressBlock(dst, dstCapacity, src, srcSize, lastBlock); + FORWARD_IF_ERROR(cSize, "ZSTD_noCompressBlock failed"); + DEBUGLOG(4, "ZSTD_compressBlock_splitBlock: Nocompress block"); + return cSize; +@@ -3673,9 +4230,9 @@ ZSTD_compressBlock_internal(ZSTD_CCtx* zc, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize, U32 frame) + { +- /* This the upper bound for the length of an rle block. +- * This isn't the actual upper bound. Finding the real threshold +- * needs further investigation. ++ /* This is an estimated upper bound for the length of an rle block. ++ * This isn't the actual upper bound. ++ * Finding the real threshold needs further investigation. 
+ */ + const U32 rleMaxLength = 25; + size_t cSize; +@@ -3687,11 +4244,15 @@ ZSTD_compressBlock_internal(ZSTD_CCtx* zc, + + { const size_t bss = ZSTD_buildSeqStore(zc, src, srcSize); + FORWARD_IF_ERROR(bss, "ZSTD_buildSeqStore failed"); +- if (bss == ZSTDbss_noCompress) { cSize = 0; goto out; } ++ if (bss == ZSTDbss_noCompress) { ++ RETURN_ERROR_IF(zc->seqCollector.collectSequences, sequenceProducer_failed, "Uncompressible block"); ++ cSize = 0; ++ goto out; ++ } + } + + if (zc->seqCollector.collectSequences) { +- ZSTD_copyBlockSequences(zc); ++ FORWARD_IF_ERROR(ZSTD_copyBlockSequences(&zc->seqCollector, ZSTD_getSeqStore(zc), zc->blockState.prevCBlock->rep), "copyBlockSequences failed"); + ZSTD_blockState_confirmRepcodesAndEntropyTables(&zc->blockState); + return 0; + } +@@ -3767,10 +4328,11 @@ static size_t ZSTD_compressBlock_targetCBlockSize_body(ZSTD_CCtx* zc, + * * cSize >= blockBound(srcSize): We have expanded the block too much so + * emit an uncompressed block. + */ +- { +- size_t const cSize = ZSTD_compressSuperBlock(zc, dst, dstCapacity, src, srcSize, lastBlock); ++ { size_t const cSize = ++ ZSTD_compressSuperBlock(zc, dst, dstCapacity, src, srcSize, lastBlock); + if (cSize != ERROR(dstSize_tooSmall)) { +- size_t const maxCSize = srcSize - ZSTD_minGain(srcSize, zc->appliedParams.cParams.strategy); ++ size_t const maxCSize = ++ srcSize - ZSTD_minGain(srcSize, zc->appliedParams.cParams.strategy); + FORWARD_IF_ERROR(cSize, "ZSTD_compressSuperBlock failed"); + if (cSize != 0 && cSize < maxCSize + ZSTD_blockHeaderSize) { + ZSTD_blockState_confirmRepcodesAndEntropyTables(&zc->blockState); +@@ -3778,7 +4340,7 @@ static size_t ZSTD_compressBlock_targetCBlockSize_body(ZSTD_CCtx* zc, + } + } + } +- } ++ } /* if (bss == ZSTDbss_compress)*/ + + DEBUGLOG(6, "Resorting to ZSTD_noCompressBlock()"); + /* Superblock compression failed, attempt to emit a single no compress block. 
+@@ -3836,7 +4398,7 @@ static void ZSTD_overflowCorrectIfNeeded(ZSTD_matchState_t* ms, + * All blocks will be terminated, all input will be consumed. + * Function will issue an error if there is not enough `dstCapacity` to hold the compressed content. + * Frame is supposed already started (header already produced) +-* @return : compressed size, or an error code ++* @return : compressed size, or an error code + */ + static size_t ZSTD_compress_frameChunk(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, +@@ -3860,7 +4422,9 @@ static size_t ZSTD_compress_frameChunk(ZSTD_CCtx* cctx, + ZSTD_matchState_t* const ms = &cctx->blockState.matchState; + U32 const lastBlock = lastFrameChunk & (blockSize >= remaining); + +- RETURN_ERROR_IF(dstCapacity < ZSTD_blockHeaderSize + MIN_CBLOCK_SIZE, ++ /* TODO: See 3090. We reduced MIN_CBLOCK_SIZE from 3 to 2 so to compensate we are adding ++ * additional 1. We need to revisit and change this logic to be more consistent */ ++ RETURN_ERROR_IF(dstCapacity < ZSTD_blockHeaderSize + MIN_CBLOCK_SIZE + 1, + dstSize_tooSmall, + "not enough space to store compressed block"); + if (remaining < blockSize) blockSize = remaining; +@@ -3899,7 +4463,7 @@ static size_t ZSTD_compress_frameChunk(ZSTD_CCtx* cctx, + MEM_writeLE24(op, cBlockHeader); + cSize += ZSTD_blockHeaderSize; + } +- } ++ } /* if (ZSTD_useTargetCBlockSize(&cctx->appliedParams))*/ + + + ip += blockSize; +@@ -4001,19 +4565,15 @@ size_t ZSTD_writeLastEmptyBlock(void* dst, size_t dstCapacity) + } + } + +-size_t ZSTD_referenceExternalSequences(ZSTD_CCtx* cctx, rawSeq* seq, size_t nbSeq) ++void ZSTD_referenceExternalSequences(ZSTD_CCtx* cctx, rawSeq* seq, size_t nbSeq) + { +- RETURN_ERROR_IF(cctx->stage != ZSTDcs_init, stage_wrong, +- "wrong cctx stage"); +- RETURN_ERROR_IF(cctx->appliedParams.ldmParams.enableLdm == ZSTD_ps_enable, +- parameter_unsupported, +- "incompatible with ldm"); ++ assert(cctx->stage == ZSTDcs_init); ++ assert(nbSeq == 0 || cctx->appliedParams.ldmParams.enableLdm != 
ZSTD_ps_enable); + cctx->externSeqStore.seq = seq; + cctx->externSeqStore.size = nbSeq; + cctx->externSeqStore.capacity = nbSeq; + cctx->externSeqStore.pos = 0; + cctx->externSeqStore.posInSequence = 0; +- return 0; + } + + +@@ -4078,31 +4638,51 @@ static size_t ZSTD_compressContinue_internal (ZSTD_CCtx* cctx, + } + } + +-size_t ZSTD_compressContinue (ZSTD_CCtx* cctx, +- void* dst, size_t dstCapacity, +- const void* src, size_t srcSize) ++size_t ZSTD_compressContinue_public(ZSTD_CCtx* cctx, ++ void* dst, size_t dstCapacity, ++ const void* src, size_t srcSize) + { + DEBUGLOG(5, "ZSTD_compressContinue (srcSize=%u)", (unsigned)srcSize); + return ZSTD_compressContinue_internal(cctx, dst, dstCapacity, src, srcSize, 1 /* frame mode */, 0 /* last chunk */); + } + ++/* NOTE: Must just wrap ZSTD_compressContinue_public() */ ++size_t ZSTD_compressContinue(ZSTD_CCtx* cctx, ++ void* dst, size_t dstCapacity, ++ const void* src, size_t srcSize) ++{ ++ return ZSTD_compressContinue_public(cctx, dst, dstCapacity, src, srcSize); ++} + +-size_t ZSTD_getBlockSize(const ZSTD_CCtx* cctx) ++static size_t ZSTD_getBlockSize_deprecated(const ZSTD_CCtx* cctx) + { + ZSTD_compressionParameters const cParams = cctx->appliedParams.cParams; + assert(!ZSTD_checkCParams(cParams)); +- return MIN (ZSTD_BLOCKSIZE_MAX, (U32)1 << cParams.windowLog); ++ return MIN(cctx->appliedParams.maxBlockSize, (size_t)1 << cParams.windowLog); + } + +-size_t ZSTD_compressBlock(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize) ++/* NOTE: Must just wrap ZSTD_getBlockSize_deprecated() */ ++size_t ZSTD_getBlockSize(const ZSTD_CCtx* cctx) ++{ ++ return ZSTD_getBlockSize_deprecated(cctx); ++} ++ ++/* NOTE: Must just wrap ZSTD_compressBlock_deprecated() */ ++size_t ZSTD_compressBlock_deprecated(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize) + { + DEBUGLOG(5, "ZSTD_compressBlock: srcSize = %u", (unsigned)srcSize); +- { size_t const blockSizeMax = 
ZSTD_getBlockSize(cctx); ++ { size_t const blockSizeMax = ZSTD_getBlockSize_deprecated(cctx); + RETURN_ERROR_IF(srcSize > blockSizeMax, srcSize_wrong, "input is larger than a block"); } + + return ZSTD_compressContinue_internal(cctx, dst, dstCapacity, src, srcSize, 0 /* frame mode */, 0 /* last chunk */); + } + ++/* NOTE: Must just wrap ZSTD_compressBlock_deprecated() */ ++size_t ZSTD_compressBlock(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize) ++{ ++ return ZSTD_compressBlock_deprecated(cctx, dst, dstCapacity, src, srcSize); ++} ++ + /*! ZSTD_loadDictionaryContent() : + * @return : 0, or an error code + */ +@@ -4111,25 +4691,36 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms, + ZSTD_cwksp* ws, + ZSTD_CCtx_params const* params, + const void* src, size_t srcSize, +- ZSTD_dictTableLoadMethod_e dtlm) ++ ZSTD_dictTableLoadMethod_e dtlm, ++ ZSTD_tableFillPurpose_e tfp) + { + const BYTE* ip = (const BYTE*) src; + const BYTE* const iend = ip + srcSize; + int const loadLdmDict = params->ldmParams.enableLdm == ZSTD_ps_enable && ls != NULL; + +- /* Assert that we the ms params match the params we're being given */ ++ /* Assert that the ms params match the params we're being given */ + ZSTD_assertEqualCParams(params->cParams, ms->cParams); + +- if (srcSize > ZSTD_CHUNKSIZE_MAX) { ++ { /* Ensure large dictionaries can't cause index overflow */ ++ + /* Allow the dictionary to set indices up to exactly ZSTD_CURRENT_MAX. + * Dictionaries right at the edge will immediately trigger overflow + * correction, but I don't want to insert extra constraints here. + */ +- U32 const maxDictSize = ZSTD_CURRENT_MAX - 1; +- /* We must have cleared our windows when our source is this large. 
*/ +- assert(ZSTD_window_isEmpty(ms->window)); +- if (loadLdmDict) +- assert(ZSTD_window_isEmpty(ls->window)); ++ U32 maxDictSize = ZSTD_CURRENT_MAX - ZSTD_WINDOW_START_INDEX; ++ ++ int const CDictTaggedIndices = ZSTD_CDictIndicesAreTagged(&params->cParams); ++ if (CDictTaggedIndices && tfp == ZSTD_tfp_forCDict) { ++ /* Some dictionary matchfinders in zstd use "short cache", ++ * which treats the lower ZSTD_SHORT_CACHE_TAG_BITS of each ++ * CDict hashtable entry as a tag rather than as part of an index. ++ * When short cache is used, we need to truncate the dictionary ++ * so that its indices don't overlap with the tag. */ ++ U32 const shortCacheMaxDictSize = (1u << (32 - ZSTD_SHORT_CACHE_TAG_BITS)) - ZSTD_WINDOW_START_INDEX; ++ maxDictSize = MIN(maxDictSize, shortCacheMaxDictSize); ++ assert(!loadLdmDict); ++ } ++ + /* If the dictionary is too large, only load the suffix of the dictionary. */ + if (srcSize > maxDictSize) { + ip = iend - maxDictSize; +@@ -4138,35 +4729,58 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms, + } + } + +- DEBUGLOG(4, "ZSTD_loadDictionaryContent(): useRowMatchFinder=%d", (int)params->useRowMatchFinder); ++ if (srcSize > ZSTD_CHUNKSIZE_MAX) { ++ /* We must have cleared our windows when our source is this large. */ ++ assert(ZSTD_window_isEmpty(ms->window)); ++ if (loadLdmDict) assert(ZSTD_window_isEmpty(ls->window)); ++ } + ZSTD_window_update(&ms->window, src, srcSize, /* forceNonContiguous */ 0); +- ms->loadedDictEnd = params->forceWindow ? 0 : (U32)(iend - ms->window.base); +- ms->forceNonContiguous = params->deterministicRefPrefix; + +- if (loadLdmDict) { ++ DEBUGLOG(4, "ZSTD_loadDictionaryContent(): useRowMatchFinder=%d", (int)params->useRowMatchFinder); ++ ++ if (loadLdmDict) { /* Load the entire dict into LDM matchfinders. */ + ZSTD_window_update(&ls->window, src, srcSize, /* forceNonContiguous */ 0); + ls->loadedDictEnd = params->forceWindow ? 
0 : (U32)(iend - ls->window.base); ++ ZSTD_ldm_fillHashTable(ls, ip, iend, &params->ldmParams); + } + ++ /* If the dict is larger than we can reasonably index in our tables, only load the suffix. */ ++ if (params->cParams.strategy < ZSTD_btultra) { ++ U32 maxDictSize = 8U << MIN(MAX(params->cParams.hashLog, params->cParams.chainLog), 28); ++ if (srcSize > maxDictSize) { ++ ip = iend - maxDictSize; ++ src = ip; ++ srcSize = maxDictSize; ++ } ++ } ++ ++ ms->nextToUpdate = (U32)(ip - ms->window.base); ++ ms->loadedDictEnd = params->forceWindow ? 0 : (U32)(iend - ms->window.base); ++ ms->forceNonContiguous = params->deterministicRefPrefix; ++ + if (srcSize <= HASH_READ_SIZE) return 0; + + ZSTD_overflowCorrectIfNeeded(ms, ws, params, ip, iend); + +- if (loadLdmDict) +- ZSTD_ldm_fillHashTable(ls, ip, iend, &params->ldmParams); +- + switch(params->cParams.strategy) + { + case ZSTD_fast: +- ZSTD_fillHashTable(ms, iend, dtlm); ++ ZSTD_fillHashTable(ms, iend, dtlm, tfp); + break; + case ZSTD_dfast: +- ZSTD_fillDoubleHashTable(ms, iend, dtlm); ++#ifndef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR ++ ZSTD_fillDoubleHashTable(ms, iend, dtlm, tfp); ++#else ++ assert(0); /* shouldn't be called: cparams should've been adjusted. 
*/ ++#endif + break; + + case ZSTD_greedy: + case ZSTD_lazy: + case ZSTD_lazy2: ++#if !defined(ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR) + assert(srcSize >= HASH_READ_SIZE); + if (ms->dedicatedDictSearch) { + assert(ms->chainTable != NULL); +@@ -4174,7 +4788,7 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms, + } else { + assert(params->useRowMatchFinder != ZSTD_ps_auto); + if (params->useRowMatchFinder == ZSTD_ps_enable) { +- size_t const tagTableSize = ((size_t)1 << params->cParams.hashLog) * sizeof(U16); ++ size_t const tagTableSize = ((size_t)1 << params->cParams.hashLog); + ZSTD_memset(ms->tagTable, 0, tagTableSize); + ZSTD_row_update(ms, iend-HASH_READ_SIZE); + DEBUGLOG(4, "Using row-based hash table for lazy dict"); +@@ -4183,14 +4797,23 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms, + DEBUGLOG(4, "Using chain-based hash table for lazy dict"); + } + } ++#else ++ assert(0); /* shouldn't be called: cparams should've been adjusted. */ ++#endif + break; + + case ZSTD_btlazy2: /* we want the dictionary table fully sorted */ + case ZSTD_btopt: + case ZSTD_btultra: + case ZSTD_btultra2: ++#if !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR) + assert(srcSize >= HASH_READ_SIZE); + ZSTD_updateTree(ms, iend-HASH_READ_SIZE, iend); ++#else ++ assert(0); /* shouldn't be called: cparams should've been adjusted. */ ++#endif + break; + + default: +@@ -4237,11 +4860,10 @@ size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace, + + /* We only set the loaded table as valid if it contains all non-zero + * weights. 
Otherwise, we set it to check */ +- if (!hasZeroWeights) ++ if (!hasZeroWeights && maxSymbolValue == 255) + bs->entropy.huf.repeatMode = HUF_repeat_valid; + + RETURN_ERROR_IF(HUF_isError(hufHeaderSize), dictionary_corrupted, ""); +- RETURN_ERROR_IF(maxSymbolValue < 255, dictionary_corrupted, ""); + dictPtr += hufHeaderSize; + } + +@@ -4327,6 +4949,7 @@ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, + ZSTD_CCtx_params const* params, + const void* dict, size_t dictSize, + ZSTD_dictTableLoadMethod_e dtlm, ++ ZSTD_tableFillPurpose_e tfp, + void* workspace) + { + const BYTE* dictPtr = (const BYTE*)dict; +@@ -4345,7 +4968,7 @@ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, + { + size_t const dictContentSize = (size_t)(dictEnd - dictPtr); + FORWARD_IF_ERROR(ZSTD_loadDictionaryContent( +- ms, NULL, ws, params, dictPtr, dictContentSize, dtlm), ""); ++ ms, NULL, ws, params, dictPtr, dictContentSize, dtlm, tfp), ""); + } + return dictID; + } +@@ -4361,6 +4984,7 @@ ZSTD_compress_insertDictionary(ZSTD_compressedBlockState_t* bs, + const void* dict, size_t dictSize, + ZSTD_dictContentType_e dictContentType, + ZSTD_dictTableLoadMethod_e dtlm, ++ ZSTD_tableFillPurpose_e tfp, + void* workspace) + { + DEBUGLOG(4, "ZSTD_compress_insertDictionary (dictSize=%u)", (U32)dictSize); +@@ -4373,13 +4997,13 @@ ZSTD_compress_insertDictionary(ZSTD_compressedBlockState_t* bs, + + /* dict restricted modes */ + if (dictContentType == ZSTD_dct_rawContent) +- return ZSTD_loadDictionaryContent(ms, ls, ws, params, dict, dictSize, dtlm); ++ return ZSTD_loadDictionaryContent(ms, ls, ws, params, dict, dictSize, dtlm, tfp); + + if (MEM_readLE32(dict) != ZSTD_MAGIC_DICTIONARY) { + if (dictContentType == ZSTD_dct_auto) { + DEBUGLOG(4, "raw content dictionary detected"); + return ZSTD_loadDictionaryContent( +- ms, ls, ws, params, dict, dictSize, dtlm); ++ ms, ls, ws, params, dict, dictSize, dtlm, tfp); + } + RETURN_ERROR_IF(dictContentType == ZSTD_dct_fullDict, 
dictionary_wrong, ""); + assert(0); /* impossible */ +@@ -4387,13 +5011,14 @@ ZSTD_compress_insertDictionary(ZSTD_compressedBlockState_t* bs, + + /* dict as full zstd dictionary */ + return ZSTD_loadZstdDictionary( +- bs, ms, ws, params, dict, dictSize, dtlm, workspace); ++ bs, ms, ws, params, dict, dictSize, dtlm, tfp, workspace); + } + + #define ZSTD_USE_CDICT_PARAMS_SRCSIZE_CUTOFF (128 KB) + #define ZSTD_USE_CDICT_PARAMS_DICTSIZE_MULTIPLIER (6ULL) + + /*! ZSTD_compressBegin_internal() : ++ * Assumption : either @dict OR @cdict (or none) is non-NULL, never both + * @return : 0, or an error code */ + static size_t ZSTD_compressBegin_internal(ZSTD_CCtx* cctx, + const void* dict, size_t dictSize, +@@ -4426,11 +5051,11 @@ static size_t ZSTD_compressBegin_internal(ZSTD_CCtx* cctx, + cctx->blockState.prevCBlock, &cctx->blockState.matchState, + &cctx->ldmState, &cctx->workspace, &cctx->appliedParams, cdict->dictContent, + cdict->dictContentSize, cdict->dictContentType, dtlm, +- cctx->entropyWorkspace) ++ ZSTD_tfp_forCCtx, cctx->entropyWorkspace) + : ZSTD_compress_insertDictionary( + cctx->blockState.prevCBlock, &cctx->blockState.matchState, + &cctx->ldmState, &cctx->workspace, &cctx->appliedParams, dict, dictSize, +- dictContentType, dtlm, cctx->entropyWorkspace); ++ dictContentType, dtlm, ZSTD_tfp_forCCtx, cctx->entropyWorkspace); + FORWARD_IF_ERROR(dictID, "ZSTD_compress_insertDictionary failed"); + assert(dictID <= UINT_MAX); + cctx->dictID = (U32)dictID; +@@ -4471,11 +5096,11 @@ size_t ZSTD_compressBegin_advanced(ZSTD_CCtx* cctx, + &cctxParams, pledgedSrcSize); + } + +-size_t ZSTD_compressBegin_usingDict(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, int compressionLevel) ++static size_t ++ZSTD_compressBegin_usingDict_deprecated(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, int compressionLevel) + { + ZSTD_CCtx_params cctxParams; +- { +- ZSTD_parameters const params = ZSTD_getParams_internal(compressionLevel, ZSTD_CONTENTSIZE_UNKNOWN, dictSize, 
ZSTD_cpm_noAttachDict); ++ { ZSTD_parameters const params = ZSTD_getParams_internal(compressionLevel, ZSTD_CONTENTSIZE_UNKNOWN, dictSize, ZSTD_cpm_noAttachDict); + ZSTD_CCtxParams_init_internal(&cctxParams, &params, (compressionLevel == 0) ? ZSTD_CLEVEL_DEFAULT : compressionLevel); + } + DEBUGLOG(4, "ZSTD_compressBegin_usingDict (dictSize=%u)", (unsigned)dictSize); +@@ -4483,9 +5108,15 @@ size_t ZSTD_compressBegin_usingDict(ZSTD_CCtx* cctx, const void* dict, size_t di + &cctxParams, ZSTD_CONTENTSIZE_UNKNOWN, ZSTDb_not_buffered); + } + ++size_t ++ZSTD_compressBegin_usingDict(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, int compressionLevel) ++{ ++ return ZSTD_compressBegin_usingDict_deprecated(cctx, dict, dictSize, compressionLevel); ++} ++ + size_t ZSTD_compressBegin(ZSTD_CCtx* cctx, int compressionLevel) + { +- return ZSTD_compressBegin_usingDict(cctx, NULL, 0, compressionLevel); ++ return ZSTD_compressBegin_usingDict_deprecated(cctx, NULL, 0, compressionLevel); + } + + +@@ -4496,14 +5127,13 @@ static size_t ZSTD_writeEpilogue(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity) + { + BYTE* const ostart = (BYTE*)dst; + BYTE* op = ostart; +- size_t fhSize = 0; + + DEBUGLOG(4, "ZSTD_writeEpilogue"); + RETURN_ERROR_IF(cctx->stage == ZSTDcs_created, stage_wrong, "init missing"); + + /* special case : empty frame */ + if (cctx->stage == ZSTDcs_init) { +- fhSize = ZSTD_writeFrameHeader(dst, dstCapacity, &cctx->appliedParams, 0, 0); ++ size_t fhSize = ZSTD_writeFrameHeader(dst, dstCapacity, &cctx->appliedParams, 0, 0); + FORWARD_IF_ERROR(fhSize, "ZSTD_writeFrameHeader failed"); + dstCapacity -= fhSize; + op += fhSize; +@@ -4513,8 +5143,9 @@ static size_t ZSTD_writeEpilogue(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity) + if (cctx->stage != ZSTDcs_ending) { + /* write one last empty block, make it the "last" block */ + U32 const cBlockHeader24 = 1 /* last block */ + (((U32)bt_raw)<<1) + 0; +- RETURN_ERROR_IF(dstCapacity<4, dstSize_tooSmall, "no room for epilogue"); +- 
MEM_writeLE32(op, cBlockHeader24); ++ ZSTD_STATIC_ASSERT(ZSTD_BLOCKHEADERSIZE == 3); ++ RETURN_ERROR_IF(dstCapacity<3, dstSize_tooSmall, "no room for epilogue"); ++ MEM_writeLE24(op, cBlockHeader24); + op += ZSTD_blockHeaderSize; + dstCapacity -= ZSTD_blockHeaderSize; + } +@@ -4537,9 +5168,9 @@ void ZSTD_CCtx_trace(ZSTD_CCtx* cctx, size_t extraCSize) + (void)extraCSize; + } + +-size_t ZSTD_compressEnd (ZSTD_CCtx* cctx, +- void* dst, size_t dstCapacity, +- const void* src, size_t srcSize) ++size_t ZSTD_compressEnd_public(ZSTD_CCtx* cctx, ++ void* dst, size_t dstCapacity, ++ const void* src, size_t srcSize) + { + size_t endResult; + size_t const cSize = ZSTD_compressContinue_internal(cctx, +@@ -4563,6 +5194,14 @@ size_t ZSTD_compressEnd (ZSTD_CCtx* cctx, + return cSize + endResult; + } + ++/* NOTE: Must just wrap ZSTD_compressEnd_public() */ ++size_t ZSTD_compressEnd(ZSTD_CCtx* cctx, ++ void* dst, size_t dstCapacity, ++ const void* src, size_t srcSize) ++{ ++ return ZSTD_compressEnd_public(cctx, dst, dstCapacity, src, srcSize); ++} ++ + size_t ZSTD_compress_advanced (ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize, +@@ -4591,7 +5230,7 @@ size_t ZSTD_compress_advanced_internal( + FORWARD_IF_ERROR( ZSTD_compressBegin_internal(cctx, + dict, dictSize, ZSTD_dct_auto, ZSTD_dtlm_fast, NULL, + params, srcSize, ZSTDb_not_buffered) , ""); +- return ZSTD_compressEnd(cctx, dst, dstCapacity, src, srcSize); ++ return ZSTD_compressEnd_public(cctx, dst, dstCapacity, src, srcSize); + } + + size_t ZSTD_compress_usingDict(ZSTD_CCtx* cctx, +@@ -4709,7 +5348,7 @@ static size_t ZSTD_initCDict_internal( + { size_t const dictID = ZSTD_compress_insertDictionary( + &cdict->cBlockState, &cdict->matchState, NULL, &cdict->workspace, + &params, cdict->dictContent, cdict->dictContentSize, +- dictContentType, ZSTD_dtlm_full, cdict->entropyWorkspace); ++ dictContentType, ZSTD_dtlm_full, ZSTD_tfp_forCDict, cdict->entropyWorkspace); + FORWARD_IF_ERROR(dictID, 
"ZSTD_compress_insertDictionary failed"); + assert(dictID <= (size_t)(U32)-1); + cdict->dictID = (U32)dictID; +@@ -4811,7 +5450,7 @@ ZSTD_CDict* ZSTD_createCDict_advanced2( + cctxParams.useRowMatchFinder, cctxParams.enableDedicatedDictSearch, + customMem); + +- if (ZSTD_isError( ZSTD_initCDict_internal(cdict, ++ if (!cdict || ZSTD_isError( ZSTD_initCDict_internal(cdict, + dict, dictSize, + dictLoadMethod, dictContentType, + cctxParams) )) { +@@ -4906,6 +5545,7 @@ const ZSTD_CDict* ZSTD_initStaticCDict( + params.cParams = cParams; + params.useRowMatchFinder = useRowMatchFinder; + cdict->useRowMatchFinder = useRowMatchFinder; ++ cdict->compressionLevel = ZSTD_NO_CLEVEL; + + if (ZSTD_isError( ZSTD_initCDict_internal(cdict, + dict, dictSize, +@@ -4985,12 +5625,17 @@ size_t ZSTD_compressBegin_usingCDict_advanced( + + /* ZSTD_compressBegin_usingCDict() : + * cdict must be != NULL */ +-size_t ZSTD_compressBegin_usingCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict) ++size_t ZSTD_compressBegin_usingCDict_deprecated(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict) + { + ZSTD_frameParameters const fParams = { 0 /*content*/, 0 /*checksum*/, 0 /*noDictID*/ }; + return ZSTD_compressBegin_usingCDict_internal(cctx, cdict, fParams, ZSTD_CONTENTSIZE_UNKNOWN); + } + ++size_t ZSTD_compressBegin_usingCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict) ++{ ++ return ZSTD_compressBegin_usingCDict_deprecated(cctx, cdict); ++} ++ + /*! ZSTD_compress_usingCDict_internal(): + * Implementation of various ZSTD_compress_usingCDict* functions. + */ +@@ -5000,7 +5645,7 @@ static size_t ZSTD_compress_usingCDict_internal(ZSTD_CCtx* cctx, + const ZSTD_CDict* cdict, ZSTD_frameParameters fParams) + { + FORWARD_IF_ERROR(ZSTD_compressBegin_usingCDict_internal(cctx, cdict, fParams, srcSize), ""); /* will check if cdict != NULL */ +- return ZSTD_compressEnd(cctx, dst, dstCapacity, src, srcSize); ++ return ZSTD_compressEnd_public(cctx, dst, dstCapacity, src, srcSize); + } + + /*! 
ZSTD_compress_usingCDict_advanced(): +@@ -5197,30 +5842,41 @@ size_t ZSTD_initCStream(ZSTD_CStream* zcs, int compressionLevel) + + static size_t ZSTD_nextInputSizeHint(const ZSTD_CCtx* cctx) + { +- size_t hintInSize = cctx->inBuffTarget - cctx->inBuffPos; +- if (hintInSize==0) hintInSize = cctx->blockSize; +- return hintInSize; ++ if (cctx->appliedParams.inBufferMode == ZSTD_bm_stable) { ++ return cctx->blockSize - cctx->stableIn_notConsumed; ++ } ++ assert(cctx->appliedParams.inBufferMode == ZSTD_bm_buffered); ++ { size_t hintInSize = cctx->inBuffTarget - cctx->inBuffPos; ++ if (hintInSize==0) hintInSize = cctx->blockSize; ++ return hintInSize; ++ } + } + + /* ZSTD_compressStream_generic(): + * internal function for all *compressStream*() variants +- * non-static, because can be called from zstdmt_compress.c +- * @return : hint size for next input */ ++ * @return : hint size for next input to complete ongoing block */ + static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, + ZSTD_outBuffer* output, + ZSTD_inBuffer* input, + ZSTD_EndDirective const flushMode) + { +- const char* const istart = (const char*)input->src; +- const char* const iend = input->size != 0 ? istart + input->size : istart; +- const char* ip = input->pos != 0 ? istart + input->pos : istart; +- char* const ostart = (char*)output->dst; +- char* const oend = output->size != 0 ? ostart + output->size : ostart; +- char* op = output->pos != 0 ? ostart + output->pos : ostart; ++ const char* const istart = (assert(input != NULL), (const char*)input->src); ++ const char* const iend = (istart != NULL) ? istart + input->size : istart; ++ const char* ip = (istart != NULL) ? istart + input->pos : istart; ++ char* const ostart = (assert(output != NULL), (char*)output->dst); ++ char* const oend = (ostart != NULL) ? ostart + output->size : ostart; ++ char* op = (ostart != NULL) ? 
ostart + output->pos : ostart; + U32 someMoreWork = 1; + + /* check expectations */ +- DEBUGLOG(5, "ZSTD_compressStream_generic, flush=%u", (unsigned)flushMode); ++ DEBUGLOG(5, "ZSTD_compressStream_generic, flush=%i, srcSize = %zu", (int)flushMode, input->size - input->pos); ++ assert(zcs != NULL); ++ if (zcs->appliedParams.inBufferMode == ZSTD_bm_stable) { ++ assert(input->pos >= zcs->stableIn_notConsumed); ++ input->pos -= zcs->stableIn_notConsumed; ++ if (ip) ip -= zcs->stableIn_notConsumed; ++ zcs->stableIn_notConsumed = 0; ++ } + if (zcs->appliedParams.inBufferMode == ZSTD_bm_buffered) { + assert(zcs->inBuff != NULL); + assert(zcs->inBuffSize > 0); +@@ -5229,8 +5885,10 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, + assert(zcs->outBuff != NULL); + assert(zcs->outBuffSize > 0); + } +- assert(output->pos <= output->size); ++ if (input->src == NULL) assert(input->size == 0); + assert(input->pos <= input->size); ++ if (output->dst == NULL) assert(output->size == 0); ++ assert(output->pos <= output->size); + assert((U32)flushMode <= (U32)ZSTD_e_end); + + while (someMoreWork) { +@@ -5245,7 +5903,7 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, + || zcs->appliedParams.outBufferMode == ZSTD_bm_stable) /* OR we are allowed to return dstSizeTooSmall */ + && (zcs->inBuffPos == 0) ) { + /* shortcut to compression pass directly into output buffer */ +- size_t const cSize = ZSTD_compressEnd(zcs, ++ size_t const cSize = ZSTD_compressEnd_public(zcs, + op, oend-op, ip, iend-ip); + DEBUGLOG(4, "ZSTD_compressEnd : cSize=%u", (unsigned)cSize); + FORWARD_IF_ERROR(cSize, "ZSTD_compressEnd failed"); +@@ -5262,8 +5920,7 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, + zcs->inBuff + zcs->inBuffPos, toLoad, + ip, iend-ip); + zcs->inBuffPos += loaded; +- if (loaded != 0) +- ip += loaded; ++ if (ip) ip += loaded; + if ( (flushMode == ZSTD_e_continue) + && (zcs->inBuffPos < zcs->inBuffTarget) ) { + /* not enough input to fill full block 
: stop here */ +@@ -5274,6 +5931,20 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, + /* empty */ + someMoreWork = 0; break; + } ++ } else { ++ assert(zcs->appliedParams.inBufferMode == ZSTD_bm_stable); ++ if ( (flushMode == ZSTD_e_continue) ++ && ( (size_t)(iend - ip) < zcs->blockSize) ) { ++ /* can't compress a full block : stop here */ ++ zcs->stableIn_notConsumed = (size_t)(iend - ip); ++ ip = iend; /* pretend to have consumed input */ ++ someMoreWork = 0; break; ++ } ++ if ( (flushMode == ZSTD_e_flush) ++ && (ip == iend) ) { ++ /* empty */ ++ someMoreWork = 0; break; ++ } + } + /* compress current block (note : this stage cannot be stopped in the middle) */ + DEBUGLOG(5, "stream compression stage (flushMode==%u)", flushMode); +@@ -5281,9 +5952,8 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, + void* cDst; + size_t cSize; + size_t oSize = oend-op; +- size_t const iSize = inputBuffered +- ? zcs->inBuffPos - zcs->inToCompress +- : MIN((size_t)(iend - ip), zcs->blockSize); ++ size_t const iSize = inputBuffered ? zcs->inBuffPos - zcs->inToCompress ++ : MIN((size_t)(iend - ip), zcs->blockSize); + if (oSize >= ZSTD_compressBound(iSize) || zcs->appliedParams.outBufferMode == ZSTD_bm_stable) + cDst = op; /* compress into output buffer, to skip flush stage */ + else +@@ -5291,9 +5961,9 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, + if (inputBuffered) { + unsigned const lastBlock = (flushMode == ZSTD_e_end) && (ip==iend); + cSize = lastBlock ? +- ZSTD_compressEnd(zcs, cDst, oSize, ++ ZSTD_compressEnd_public(zcs, cDst, oSize, + zcs->inBuff + zcs->inToCompress, iSize) : +- ZSTD_compressContinue(zcs, cDst, oSize, ++ ZSTD_compressContinue_public(zcs, cDst, oSize, + zcs->inBuff + zcs->inToCompress, iSize); + FORWARD_IF_ERROR(cSize, "%s", lastBlock ? 
"ZSTD_compressEnd failed" : "ZSTD_compressContinue failed"); + zcs->frameEnded = lastBlock; +@@ -5306,19 +5976,16 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, + if (!lastBlock) + assert(zcs->inBuffTarget <= zcs->inBuffSize); + zcs->inToCompress = zcs->inBuffPos; +- } else { +- unsigned const lastBlock = (ip + iSize == iend); +- assert(flushMode == ZSTD_e_end /* Already validated */); ++ } else { /* !inputBuffered, hence ZSTD_bm_stable */ ++ unsigned const lastBlock = (flushMode == ZSTD_e_end) && (ip + iSize == iend); + cSize = lastBlock ? +- ZSTD_compressEnd(zcs, cDst, oSize, ip, iSize) : +- ZSTD_compressContinue(zcs, cDst, oSize, ip, iSize); ++ ZSTD_compressEnd_public(zcs, cDst, oSize, ip, iSize) : ++ ZSTD_compressContinue_public(zcs, cDst, oSize, ip, iSize); + /* Consume the input prior to error checking to mirror buffered mode. */ +- if (iSize > 0) +- ip += iSize; ++ if (ip) ip += iSize; + FORWARD_IF_ERROR(cSize, "%s", lastBlock ? "ZSTD_compressEnd failed" : "ZSTD_compressContinue failed"); + zcs->frameEnded = lastBlock; +- if (lastBlock) +- assert(ip == iend); ++ if (lastBlock) assert(ip == iend); + } + if (cDst == op) { /* no need to flush */ + op += cSize; +@@ -5388,8 +6055,10 @@ size_t ZSTD_compressStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output, ZSTD_inBuf + /* After a compression call set the expected input/output buffer. + * This is validated at the start of the next compression call. 
+ */ +-static void ZSTD_setBufferExpectations(ZSTD_CCtx* cctx, ZSTD_outBuffer const* output, ZSTD_inBuffer const* input) ++static void ++ZSTD_setBufferExpectations(ZSTD_CCtx* cctx, const ZSTD_outBuffer* output, const ZSTD_inBuffer* input) + { ++ DEBUGLOG(5, "ZSTD_setBufferExpectations (for advanced stable in/out modes)"); + if (cctx->appliedParams.inBufferMode == ZSTD_bm_stable) { + cctx->expectedInBuffer = *input; + } +@@ -5408,22 +6077,22 @@ static size_t ZSTD_checkBufferStability(ZSTD_CCtx const* cctx, + { + if (cctx->appliedParams.inBufferMode == ZSTD_bm_stable) { + ZSTD_inBuffer const expect = cctx->expectedInBuffer; +- if (expect.src != input->src || expect.pos != input->pos || expect.size != input->size) +- RETURN_ERROR(srcBuffer_wrong, "ZSTD_c_stableInBuffer enabled but input differs!"); +- if (endOp != ZSTD_e_end) +- RETURN_ERROR(srcBuffer_wrong, "ZSTD_c_stableInBuffer can only be used with ZSTD_e_end!"); ++ if (expect.src != input->src || expect.pos != input->pos) ++ RETURN_ERROR(stabilityCondition_notRespected, "ZSTD_c_stableInBuffer enabled but input differs!"); + } ++ (void)endOp; + if (cctx->appliedParams.outBufferMode == ZSTD_bm_stable) { + size_t const outBufferSize = output->size - output->pos; + if (cctx->expectedOutBufferSize != outBufferSize) +- RETURN_ERROR(dstBuffer_wrong, "ZSTD_c_stableOutBuffer enabled but output size differs!"); ++ RETURN_ERROR(stabilityCondition_notRespected, "ZSTD_c_stableOutBuffer enabled but output size differs!"); + } + return 0; + } + + static size_t ZSTD_CCtx_init_compressStream2(ZSTD_CCtx* cctx, + ZSTD_EndDirective endOp, +- size_t inSize) { ++ size_t inSize) ++{ + ZSTD_CCtx_params params = cctx->requestedParams; + ZSTD_prefixDict const prefixDict = cctx->prefixDict; + FORWARD_IF_ERROR( ZSTD_initLocalDict(cctx) , ""); /* Init the local dict if present. 
*/ +@@ -5437,9 +6106,9 @@ static size_t ZSTD_CCtx_init_compressStream2(ZSTD_CCtx* cctx, + params.compressionLevel = cctx->cdict->compressionLevel; + } + DEBUGLOG(4, "ZSTD_compressStream2 : transparent init stage"); +- if (endOp == ZSTD_e_end) cctx->pledgedSrcSizePlusOne = inSize + 1; /* auto-fix pledgedSrcSize */ +- { +- size_t const dictSize = prefixDict.dict ++ if (endOp == ZSTD_e_end) cctx->pledgedSrcSizePlusOne = inSize + 1; /* auto-determine pledgedSrcSize */ ++ ++ { size_t const dictSize = prefixDict.dict + ? prefixDict.dictSize + : (cctx->cdict ? cctx->cdict->dictContentSize : 0); + ZSTD_cParamMode_e const mode = ZSTD_getCParamMode(cctx->cdict, ¶ms, cctx->pledgedSrcSizePlusOne - 1); +@@ -5451,6 +6120,9 @@ static size_t ZSTD_CCtx_init_compressStream2(ZSTD_CCtx* cctx, + params.useBlockSplitter = ZSTD_resolveBlockSplitterMode(params.useBlockSplitter, ¶ms.cParams); + params.ldmParams.enableLdm = ZSTD_resolveEnableLdm(params.ldmParams.enableLdm, ¶ms.cParams); + params.useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(params.useRowMatchFinder, ¶ms.cParams); ++ params.validateSequences = ZSTD_resolveExternalSequenceValidation(params.validateSequences); ++ params.maxBlockSize = ZSTD_resolveMaxBlockSize(params.maxBlockSize); ++ params.searchForExternalRepcodes = ZSTD_resolveExternalRepcodeSearch(params.searchForExternalRepcodes, params.compressionLevel); + + { U64 const pledgedSrcSize = cctx->pledgedSrcSizePlusOne - 1; + assert(!ZSTD_isError(ZSTD_checkCParams(params.cParams))); +@@ -5477,6 +6149,8 @@ static size_t ZSTD_CCtx_init_compressStream2(ZSTD_CCtx* cctx, + return 0; + } + ++/* @return provides a minimum amount of data remaining to be flushed from internal buffers ++ */ + size_t ZSTD_compressStream2( ZSTD_CCtx* cctx, + ZSTD_outBuffer* output, + ZSTD_inBuffer* input, +@@ -5491,8 +6165,27 @@ size_t ZSTD_compressStream2( ZSTD_CCtx* cctx, + + /* transparent initialization stage */ + if (cctx->streamStage == zcss_init) { +- 
FORWARD_IF_ERROR(ZSTD_CCtx_init_compressStream2(cctx, endOp, input->size), "CompressStream2 initialization failed"); +- ZSTD_setBufferExpectations(cctx, output, input); /* Set initial buffer expectations now that we've initialized */ ++ size_t const inputSize = input->size - input->pos; /* no obligation to start from pos==0 */ ++ size_t const totalInputSize = inputSize + cctx->stableIn_notConsumed; ++ if ( (cctx->requestedParams.inBufferMode == ZSTD_bm_stable) /* input is presumed stable, across invocations */ ++ && (endOp == ZSTD_e_continue) /* no flush requested, more input to come */ ++ && (totalInputSize < ZSTD_BLOCKSIZE_MAX) ) { /* not even reached one block yet */ ++ if (cctx->stableIn_notConsumed) { /* not the first time */ ++ /* check stable source guarantees */ ++ RETURN_ERROR_IF(input->src != cctx->expectedInBuffer.src, stabilityCondition_notRespected, "stableInBuffer condition not respected: wrong src pointer"); ++ RETURN_ERROR_IF(input->pos != cctx->expectedInBuffer.size, stabilityCondition_notRespected, "stableInBuffer condition not respected: externally modified pos"); ++ } ++ /* pretend input was consumed, to give a sense forward progress */ ++ input->pos = input->size; ++ /* save stable inBuffer, for later control, and flush/end */ ++ cctx->expectedInBuffer = *input; ++ /* but actually input wasn't consumed, so keep track of position from where compression shall resume */ ++ cctx->stableIn_notConsumed += inputSize; ++ /* don't initialize yet, wait for the first block of flush() order, for better parameters adaptation */ ++ return ZSTD_FRAMEHEADERSIZE_MIN(cctx->requestedParams.format); /* at least some header to produce */ ++ } ++ FORWARD_IF_ERROR(ZSTD_CCtx_init_compressStream2(cctx, endOp, totalInputSize), "compressStream2 initialization failed"); ++ ZSTD_setBufferExpectations(cctx, output, input); /* Set initial buffer expectations now that we've initialized */ + } + /* end of transparent initialization stage */ + +@@ -5510,13 +6203,20 @@ size_t 
ZSTD_compressStream2_simpleArgs ( + const void* src, size_t srcSize, size_t* srcPos, + ZSTD_EndDirective endOp) + { +- ZSTD_outBuffer output = { dst, dstCapacity, *dstPos }; +- ZSTD_inBuffer input = { src, srcSize, *srcPos }; ++ ZSTD_outBuffer output; ++ ZSTD_inBuffer input; ++ output.dst = dst; ++ output.size = dstCapacity; ++ output.pos = *dstPos; ++ input.src = src; ++ input.size = srcSize; ++ input.pos = *srcPos; + /* ZSTD_compressStream2() will check validity of dstPos and srcPos */ +- size_t const cErr = ZSTD_compressStream2(cctx, &output, &input, endOp); +- *dstPos = output.pos; +- *srcPos = input.pos; +- return cErr; ++ { size_t const cErr = ZSTD_compressStream2(cctx, &output, &input, endOp); ++ *dstPos = output.pos; ++ *srcPos = input.pos; ++ return cErr; ++ } + } + + size_t ZSTD_compress2(ZSTD_CCtx* cctx, +@@ -5539,6 +6239,7 @@ size_t ZSTD_compress2(ZSTD_CCtx* cctx, + /* Reset to the original values. */ + cctx->requestedParams.inBufferMode = originalInBufferMode; + cctx->requestedParams.outBufferMode = originalOutBufferMode; ++ + FORWARD_IF_ERROR(result, "ZSTD_compressStream2_simpleArgs failed"); + if (result != 0) { /* compression not completed, due to lack of output space */ + assert(oPos == dstCapacity); +@@ -5549,64 +6250,61 @@ size_t ZSTD_compress2(ZSTD_CCtx* cctx, + } + } + +-typedef struct { +- U32 idx; /* Index in array of ZSTD_Sequence */ +- U32 posInSequence; /* Position within sequence at idx */ +- size_t posInSrc; /* Number of bytes given by sequences provided so far */ +-} ZSTD_sequencePosition; +- + /* ZSTD_validateSequence() : + * @offCode : is presumed to follow format required by ZSTD_storeSeq() + * @returns a ZSTD error code if sequence is not valid + */ + static size_t +-ZSTD_validateSequence(U32 offCode, U32 matchLength, +- size_t posInSrc, U32 windowLog, size_t dictSize) ++ZSTD_validateSequence(U32 offCode, U32 matchLength, U32 minMatch, ++ size_t posInSrc, U32 windowLog, size_t dictSize, int useSequenceProducer) + { +- U32 const 
windowSize = 1 << windowLog; ++ U32 const windowSize = 1u << windowLog; + /* posInSrc represents the amount of data the decoder would decode up to this point. + * As long as the amount of data decoded is less than or equal to window size, offsets may be + * larger than the total length of output decoded in order to reference the dict, even larger than + * window size. After output surpasses windowSize, we're limited to windowSize offsets again. + */ + size_t const offsetBound = posInSrc > windowSize ? (size_t)windowSize : posInSrc + (size_t)dictSize; +- RETURN_ERROR_IF(offCode > STORE_OFFSET(offsetBound), corruption_detected, "Offset too large!"); +- RETURN_ERROR_IF(matchLength < MINMATCH, corruption_detected, "Matchlength too small"); ++ size_t const matchLenLowerBound = (minMatch == 3 || useSequenceProducer) ? 3 : 4; ++ RETURN_ERROR_IF(offCode > OFFSET_TO_OFFBASE(offsetBound), externalSequences_invalid, "Offset too large!"); ++ /* Validate maxNbSeq is large enough for the given matchLength and minMatch */ ++ RETURN_ERROR_IF(matchLength < matchLenLowerBound, externalSequences_invalid, "Matchlength too small for the minMatch"); + return 0; + } + + /* Returns an offset code, given a sequence's raw offset, the ongoing repcode array, and whether litLength == 0 */ +-static U32 ZSTD_finalizeOffCode(U32 rawOffset, const U32 rep[ZSTD_REP_NUM], U32 ll0) ++static U32 ZSTD_finalizeOffBase(U32 rawOffset, const U32 rep[ZSTD_REP_NUM], U32 ll0) + { +- U32 offCode = STORE_OFFSET(rawOffset); ++ U32 offBase = OFFSET_TO_OFFBASE(rawOffset); + + if (!ll0 && rawOffset == rep[0]) { +- offCode = STORE_REPCODE_1; ++ offBase = REPCODE1_TO_OFFBASE; + } else if (rawOffset == rep[1]) { +- offCode = STORE_REPCODE(2 - ll0); ++ offBase = REPCODE_TO_OFFBASE(2 - ll0); + } else if (rawOffset == rep[2]) { +- offCode = STORE_REPCODE(3 - ll0); ++ offBase = REPCODE_TO_OFFBASE(3 - ll0); + } else if (ll0 && rawOffset == rep[0] - 1) { +- offCode = STORE_REPCODE_3; ++ offBase = REPCODE3_TO_OFFBASE; + } +- 
return offCode; ++ return offBase; + } + +-/* Returns 0 on success, and a ZSTD_error otherwise. This function scans through an array of +- * ZSTD_Sequence, storing the sequences it finds, until it reaches a block delimiter. +- */ +-static size_t ++size_t + ZSTD_copySequencesToSeqStoreExplicitBlockDelim(ZSTD_CCtx* cctx, + ZSTD_sequencePosition* seqPos, + const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, +- const void* src, size_t blockSize) ++ const void* src, size_t blockSize, ++ ZSTD_paramSwitch_e externalRepSearch) + { + U32 idx = seqPos->idx; ++ U32 const startIdx = idx; + BYTE const* ip = (BYTE const*)(src); + const BYTE* const iend = ip + blockSize; + repcodes_t updatedRepcodes; + U32 dictSize; + ++ DEBUGLOG(5, "ZSTD_copySequencesToSeqStoreExplicitBlockDelim (blockSize = %zu)", blockSize); ++ + if (cctx->cdict) { + dictSize = (U32)cctx->cdict->dictContentSize; + } else if (cctx->prefixDict.dict) { +@@ -5615,25 +6313,55 @@ ZSTD_copySequencesToSeqStoreExplicitBlockDelim(ZSTD_CCtx* cctx, + dictSize = 0; + } + ZSTD_memcpy(updatedRepcodes.rep, cctx->blockState.prevCBlock->rep, sizeof(repcodes_t)); +- for (; (inSeqs[idx].matchLength != 0 || inSeqs[idx].offset != 0) && idx < inSeqsSize; ++idx) { ++ for (; idx < inSeqsSize && (inSeqs[idx].matchLength != 0 || inSeqs[idx].offset != 0); ++idx) { + U32 const litLength = inSeqs[idx].litLength; +- U32 const ll0 = (litLength == 0); + U32 const matchLength = inSeqs[idx].matchLength; +- U32 const offCode = ZSTD_finalizeOffCode(inSeqs[idx].offset, updatedRepcodes.rep, ll0); +- ZSTD_updateRep(updatedRepcodes.rep, offCode, ll0); ++ U32 offBase; ++ ++ if (externalRepSearch == ZSTD_ps_disable) { ++ offBase = OFFSET_TO_OFFBASE(inSeqs[idx].offset); ++ } else { ++ U32 const ll0 = (litLength == 0); ++ offBase = ZSTD_finalizeOffBase(inSeqs[idx].offset, updatedRepcodes.rep, ll0); ++ ZSTD_updateRep(updatedRepcodes.rep, offBase, ll0); ++ } + +- DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offCode, matchLength, litLength); 
++ DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offBase, matchLength, litLength); + if (cctx->appliedParams.validateSequences) { + seqPos->posInSrc += litLength + matchLength; +- FORWARD_IF_ERROR(ZSTD_validateSequence(offCode, matchLength, seqPos->posInSrc, +- cctx->appliedParams.cParams.windowLog, dictSize), ++ FORWARD_IF_ERROR(ZSTD_validateSequence(offBase, matchLength, cctx->appliedParams.cParams.minMatch, seqPos->posInSrc, ++ cctx->appliedParams.cParams.windowLog, dictSize, ZSTD_hasExtSeqProd(&cctx->appliedParams)), + "Sequence validation failed"); + } +- RETURN_ERROR_IF(idx - seqPos->idx > cctx->seqStore.maxNbSeq, memory_allocation, ++ RETURN_ERROR_IF(idx - seqPos->idx >= cctx->seqStore.maxNbSeq, externalSequences_invalid, + "Not enough memory allocated. Try adjusting ZSTD_c_minMatch."); +- ZSTD_storeSeq(&cctx->seqStore, litLength, ip, iend, offCode, matchLength); ++ ZSTD_storeSeq(&cctx->seqStore, litLength, ip, iend, offBase, matchLength); + ip += matchLength + litLength; + } ++ ++ /* If we skipped repcode search while parsing, we need to update repcodes now */ ++ assert(externalRepSearch != ZSTD_ps_auto); ++ assert(idx >= startIdx); ++ if (externalRepSearch == ZSTD_ps_disable && idx != startIdx) { ++ U32* const rep = updatedRepcodes.rep; ++ U32 lastSeqIdx = idx - 1; /* index of last non-block-delimiter sequence */ ++ ++ if (lastSeqIdx >= startIdx + 2) { ++ rep[2] = inSeqs[lastSeqIdx - 2].offset; ++ rep[1] = inSeqs[lastSeqIdx - 1].offset; ++ rep[0] = inSeqs[lastSeqIdx].offset; ++ } else if (lastSeqIdx == startIdx + 1) { ++ rep[2] = rep[0]; ++ rep[1] = inSeqs[lastSeqIdx - 1].offset; ++ rep[0] = inSeqs[lastSeqIdx].offset; ++ } else { ++ assert(lastSeqIdx == startIdx); ++ rep[2] = rep[1]; ++ rep[1] = rep[0]; ++ rep[0] = inSeqs[lastSeqIdx].offset; ++ } ++ } ++ + ZSTD_memcpy(cctx->blockState.nextCBlock->rep, updatedRepcodes.rep, sizeof(repcodes_t)); + + if (inSeqs[idx].litLength) { +@@ -5642,26 +6370,15 @@ 
ZSTD_copySequencesToSeqStoreExplicitBlockDelim(ZSTD_CCtx* cctx, + ip += inSeqs[idx].litLength; + seqPos->posInSrc += inSeqs[idx].litLength; + } +- RETURN_ERROR_IF(ip != iend, corruption_detected, "Blocksize doesn't agree with block delimiter!"); ++ RETURN_ERROR_IF(ip != iend, externalSequences_invalid, "Blocksize doesn't agree with block delimiter!"); + seqPos->idx = idx+1; + return 0; + } + +-/* Returns the number of bytes to move the current read position back by. Only non-zero +- * if we ended up splitting a sequence. Otherwise, it may return a ZSTD error if something +- * went wrong. +- * +- * This function will attempt to scan through blockSize bytes represented by the sequences +- * in inSeqs, storing any (partial) sequences. +- * +- * Occasionally, we may want to change the actual number of bytes we consumed from inSeqs to +- * avoid splitting a match, or to avoid splitting a match such that it would produce a match +- * smaller than MINMATCH. In this case, we return the number of bytes that we didn't read from this block. 
+- */ +-static size_t ++size_t + ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* seqPos, + const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, +- const void* src, size_t blockSize) ++ const void* src, size_t blockSize, ZSTD_paramSwitch_e externalRepSearch) + { + U32 idx = seqPos->idx; + U32 startPosInSequence = seqPos->posInSequence; +@@ -5673,6 +6390,9 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* + U32 bytesAdjustment = 0; + U32 finalMatchSplit = 0; + ++ /* TODO(embg) support fast parsing mode in noBlockDelim mode */ ++ (void)externalRepSearch; ++ + if (cctx->cdict) { + dictSize = cctx->cdict->dictContentSize; + } else if (cctx->prefixDict.dict) { +@@ -5680,7 +6400,7 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* + } else { + dictSize = 0; + } +- DEBUGLOG(5, "ZSTD_copySequencesToSeqStore: idx: %u PIS: %u blockSize: %zu", idx, startPosInSequence, blockSize); ++ DEBUGLOG(5, "ZSTD_copySequencesToSeqStoreNoBlockDelim: idx: %u PIS: %u blockSize: %zu", idx, startPosInSequence, blockSize); + DEBUGLOG(5, "Start seq: idx: %u (of: %u ml: %u ll: %u)", idx, inSeqs[idx].offset, inSeqs[idx].matchLength, inSeqs[idx].litLength); + ZSTD_memcpy(updatedRepcodes.rep, cctx->blockState.prevCBlock->rep, sizeof(repcodes_t)); + while (endPosInSequence && idx < inSeqsSize && !finalMatchSplit) { +@@ -5688,7 +6408,7 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* + U32 litLength = currSeq.litLength; + U32 matchLength = currSeq.matchLength; + U32 const rawOffset = currSeq.offset; +- U32 offCode; ++ U32 offBase; + + /* Modify the sequence depending on where endPosInSequence lies */ + if (endPosInSequence >= currSeq.litLength + currSeq.matchLength) { +@@ -5702,7 +6422,6 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* + /* Move to the next sequence */ + endPosInSequence -= currSeq.litLength + currSeq.matchLength; + 
startPosInSequence = 0; +- idx++; + } else { + /* This is the final (partial) sequence we're adding from inSeqs, and endPosInSequence + does not reach the end of the match. So, we have to split the sequence */ +@@ -5742,21 +6461,23 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* + } + /* Check if this offset can be represented with a repcode */ + { U32 const ll0 = (litLength == 0); +- offCode = ZSTD_finalizeOffCode(rawOffset, updatedRepcodes.rep, ll0); +- ZSTD_updateRep(updatedRepcodes.rep, offCode, ll0); ++ offBase = ZSTD_finalizeOffBase(rawOffset, updatedRepcodes.rep, ll0); ++ ZSTD_updateRep(updatedRepcodes.rep, offBase, ll0); + } + + if (cctx->appliedParams.validateSequences) { + seqPos->posInSrc += litLength + matchLength; +- FORWARD_IF_ERROR(ZSTD_validateSequence(offCode, matchLength, seqPos->posInSrc, +- cctx->appliedParams.cParams.windowLog, dictSize), ++ FORWARD_IF_ERROR(ZSTD_validateSequence(offBase, matchLength, cctx->appliedParams.cParams.minMatch, seqPos->posInSrc, ++ cctx->appliedParams.cParams.windowLog, dictSize, ZSTD_hasExtSeqProd(&cctx->appliedParams)), + "Sequence validation failed"); + } +- DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offCode, matchLength, litLength); +- RETURN_ERROR_IF(idx - seqPos->idx > cctx->seqStore.maxNbSeq, memory_allocation, ++ DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offBase, matchLength, litLength); ++ RETURN_ERROR_IF(idx - seqPos->idx >= cctx->seqStore.maxNbSeq, externalSequences_invalid, + "Not enough memory allocated. 
Try adjusting ZSTD_c_minMatch."); +- ZSTD_storeSeq(&cctx->seqStore, litLength, ip, iend, offCode, matchLength); ++ ZSTD_storeSeq(&cctx->seqStore, litLength, ip, iend, offBase, matchLength); + ip += matchLength + litLength; ++ if (!finalMatchSplit) ++ idx++; /* Next Sequence */ + } + DEBUGLOG(5, "Ending seq: idx: %u (of: %u ml: %u ll: %u)", idx, inSeqs[idx].offset, inSeqs[idx].matchLength, inSeqs[idx].litLength); + assert(idx == inSeqsSize || endPosInSequence <= inSeqs[idx].litLength + inSeqs[idx].matchLength); +@@ -5779,7 +6500,7 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* + + typedef size_t (*ZSTD_sequenceCopier) (ZSTD_CCtx* cctx, ZSTD_sequencePosition* seqPos, + const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, +- const void* src, size_t blockSize); ++ const void* src, size_t blockSize, ZSTD_paramSwitch_e externalRepSearch); + static ZSTD_sequenceCopier ZSTD_selectSequenceCopier(ZSTD_sequenceFormat_e mode) + { + ZSTD_sequenceCopier sequenceCopier = NULL; +@@ -5793,6 +6514,57 @@ static ZSTD_sequenceCopier ZSTD_selectSequenceCopier(ZSTD_sequenceFormat_e mode) + return sequenceCopier; + } + ++/* Discover the size of next block by searching for the delimiter. ++ * Note that a block delimiter **must** exist in this mode, ++ * otherwise it's an input error. 
++ * The block size retrieved will be later compared to ensure it remains within bounds */ ++static size_t ++blockSize_explicitDelimiter(const ZSTD_Sequence* inSeqs, size_t inSeqsSize, ZSTD_sequencePosition seqPos) ++{ ++ int end = 0; ++ size_t blockSize = 0; ++ size_t spos = seqPos.idx; ++ DEBUGLOG(6, "blockSize_explicitDelimiter : seq %zu / %zu", spos, inSeqsSize); ++ assert(spos <= inSeqsSize); ++ while (spos < inSeqsSize) { ++ end = (inSeqs[spos].offset == 0); ++ blockSize += inSeqs[spos].litLength + inSeqs[spos].matchLength; ++ if (end) { ++ if (inSeqs[spos].matchLength != 0) ++ RETURN_ERROR(externalSequences_invalid, "delimiter format error : both matchlength and offset must be == 0"); ++ break; ++ } ++ spos++; ++ } ++ if (!end) ++ RETURN_ERROR(externalSequences_invalid, "Reached end of sequences without finding a block delimiter"); ++ return blockSize; ++} ++ ++/* More a "target" block size */ ++static size_t blockSize_noDelimiter(size_t blockSize, size_t remaining) ++{ ++ int const lastBlock = (remaining <= blockSize); ++ return lastBlock ? 
remaining : blockSize; ++} ++ ++static size_t determine_blockSize(ZSTD_sequenceFormat_e mode, ++ size_t blockSize, size_t remaining, ++ const ZSTD_Sequence* inSeqs, size_t inSeqsSize, ZSTD_sequencePosition seqPos) ++{ ++ DEBUGLOG(6, "determine_blockSize : remainingSize = %zu", remaining); ++ if (mode == ZSTD_sf_noBlockDelimiters) ++ return blockSize_noDelimiter(blockSize, remaining); ++ { size_t const explicitBlockSize = blockSize_explicitDelimiter(inSeqs, inSeqsSize, seqPos); ++ FORWARD_IF_ERROR(explicitBlockSize, "Error while determining block size with explicit delimiters"); ++ if (explicitBlockSize > blockSize) ++ RETURN_ERROR(externalSequences_invalid, "sequences incorrectly define a too large block"); ++ if (explicitBlockSize > remaining) ++ RETURN_ERROR(externalSequences_invalid, "sequences define a frame longer than source"); ++ return explicitBlockSize; ++ } ++} ++ + /* Compress, block-by-block, all of the sequences given. + * + * Returns the cumulative size of all compressed blocks (including their headers), +@@ -5805,9 +6577,6 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, + const void* src, size_t srcSize) + { + size_t cSize = 0; +- U32 lastBlock; +- size_t blockSize; +- size_t compressedSeqsSize; + size_t remaining = srcSize; + ZSTD_sequencePosition seqPos = {0, 0, 0}; + +@@ -5827,22 +6596,29 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, + } + + while (remaining) { ++ size_t compressedSeqsSize; + size_t cBlockSize; + size_t additionalByteAdjustment; +- lastBlock = remaining <= cctx->blockSize; +- blockSize = lastBlock ? (U32)remaining : (U32)cctx->blockSize; ++ size_t blockSize = determine_blockSize(cctx->appliedParams.blockDelimiters, ++ cctx->blockSize, remaining, ++ inSeqs, inSeqsSize, seqPos); ++ U32 const lastBlock = (blockSize == remaining); ++ FORWARD_IF_ERROR(blockSize, "Error while trying to determine block size"); ++ assert(blockSize <= remaining); + ZSTD_resetSeqStore(&cctx->seqStore); +- DEBUGLOG(4, "Working on new block. 
Blocksize: %zu", blockSize); ++ DEBUGLOG(5, "Working on new block. Blocksize: %zu (total:%zu)", blockSize, (ip - (const BYTE*)src) + blockSize); + +- additionalByteAdjustment = sequenceCopier(cctx, &seqPos, inSeqs, inSeqsSize, ip, blockSize); ++ additionalByteAdjustment = sequenceCopier(cctx, &seqPos, inSeqs, inSeqsSize, ip, blockSize, cctx->appliedParams.searchForExternalRepcodes); + FORWARD_IF_ERROR(additionalByteAdjustment, "Bad sequence copy"); + blockSize -= additionalByteAdjustment; + + /* If blocks are too small, emit as a nocompress block */ +- if (blockSize < MIN_CBLOCK_SIZE+ZSTD_blockHeaderSize+1) { ++ /* TODO: See 3090. We reduced MIN_CBLOCK_SIZE from 3 to 2 so to compensate we are adding ++ * additional 1. We need to revisit and change this logic to be more consistent */ ++ if (blockSize < MIN_CBLOCK_SIZE+ZSTD_blockHeaderSize+1+1) { + cBlockSize = ZSTD_noCompressBlock(op, dstCapacity, ip, blockSize, lastBlock); + FORWARD_IF_ERROR(cBlockSize, "Nocompress block failed"); +- DEBUGLOG(4, "Block too small, writing out nocompress block: cSize: %zu", cBlockSize); ++ DEBUGLOG(5, "Block too small, writing out nocompress block: cSize: %zu", cBlockSize); + cSize += cBlockSize; + ip += blockSize; + op += cBlockSize; +@@ -5851,6 +6627,7 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, + continue; + } + ++ RETURN_ERROR_IF(dstCapacity < ZSTD_blockHeaderSize, dstSize_tooSmall, "not enough dstCapacity to write a new compressed block"); + compressedSeqsSize = ZSTD_entropyCompressSeqStore(&cctx->seqStore, + &cctx->blockState.prevCBlock->entropy, &cctx->blockState.nextCBlock->entropy, + &cctx->appliedParams, +@@ -5859,11 +6636,11 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, + cctx->entropyWorkspace, ENTROPY_WORKSPACE_SIZE /* statically allocated in resetCCtx */, + cctx->bmi2); + FORWARD_IF_ERROR(compressedSeqsSize, "Compressing sequences of block failed"); +- DEBUGLOG(4, "Compressed sequences size: %zu", compressedSeqsSize); ++ DEBUGLOG(5, "Compressed sequences 
size: %zu", compressedSeqsSize); + + if (!cctx->isFirstBlock && + ZSTD_maybeRLE(&cctx->seqStore) && +- ZSTD_isRLE((BYTE const*)src, srcSize)) { ++ ZSTD_isRLE(ip, blockSize)) { + /* We don't want to emit our first block as a RLE even if it qualifies because + * doing so will cause the decoder (cli only) to throw a "should consume all input error." + * This is only an issue for zstd <= v1.4.3 +@@ -5874,12 +6651,12 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, + if (compressedSeqsSize == 0) { + /* ZSTD_noCompressBlock writes the block header as well */ + cBlockSize = ZSTD_noCompressBlock(op, dstCapacity, ip, blockSize, lastBlock); +- FORWARD_IF_ERROR(cBlockSize, "Nocompress block failed"); +- DEBUGLOG(4, "Writing out nocompress block, size: %zu", cBlockSize); ++ FORWARD_IF_ERROR(cBlockSize, "ZSTD_noCompressBlock failed"); ++ DEBUGLOG(5, "Writing out nocompress block, size: %zu", cBlockSize); + } else if (compressedSeqsSize == 1) { + cBlockSize = ZSTD_rleCompressBlock(op, dstCapacity, *ip, blockSize, lastBlock); +- FORWARD_IF_ERROR(cBlockSize, "RLE compress block failed"); +- DEBUGLOG(4, "Writing out RLE block, size: %zu", cBlockSize); ++ FORWARD_IF_ERROR(cBlockSize, "ZSTD_rleCompressBlock failed"); ++ DEBUGLOG(5, "Writing out RLE block, size: %zu", cBlockSize); + } else { + U32 cBlockHeader; + /* Error checking and repcodes update */ +@@ -5891,11 +6668,10 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, + cBlockHeader = lastBlock + (((U32)bt_compressed)<<1) + (U32)(compressedSeqsSize << 3); + MEM_writeLE24(op, cBlockHeader); + cBlockSize = ZSTD_blockHeaderSize + compressedSeqsSize; +- DEBUGLOG(4, "Writing out compressed block, size: %zu", cBlockSize); ++ DEBUGLOG(5, "Writing out compressed block, size: %zu", cBlockSize); + } + + cSize += cBlockSize; +- DEBUGLOG(4, "cSize running total: %zu", cSize); + + if (lastBlock) { + break; +@@ -5906,12 +6682,15 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, + dstCapacity -= cBlockSize; + cctx->isFirstBlock = 0; + } 
++ DEBUGLOG(5, "cSize running total: %zu (remaining dstCapacity=%zu)", cSize, dstCapacity); + } + ++ DEBUGLOG(4, "cSize final total: %zu", cSize); + return cSize; + } + +-size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* dst, size_t dstCapacity, ++size_t ZSTD_compressSequences(ZSTD_CCtx* cctx, ++ void* dst, size_t dstCapacity, + const ZSTD_Sequence* inSeqs, size_t inSeqsSize, + const void* src, size_t srcSize) + { +@@ -5921,7 +6700,7 @@ size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* dst, size_t dstCapaci + size_t frameHeaderSize = 0; + + /* Transparent initialization stage, same as compressStream2() */ +- DEBUGLOG(3, "ZSTD_compressSequences()"); ++ DEBUGLOG(4, "ZSTD_compressSequences (dstCapacity=%zu)", dstCapacity); + assert(cctx != NULL); + FORWARD_IF_ERROR(ZSTD_CCtx_init_compressStream2(cctx, ZSTD_e_end, srcSize), "CCtx initialization failed"); + /* Begin writing output, starting with frame header */ +@@ -5949,26 +6728,34 @@ size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* dst, size_t dstCapaci + cSize += 4; + } + +- DEBUGLOG(3, "Final compressed size: %zu", cSize); ++ DEBUGLOG(4, "Final compressed size: %zu", cSize); + return cSize; + } + + /*====== Finalize ======*/ + ++static ZSTD_inBuffer inBuffer_forEndFlush(const ZSTD_CStream* zcs) ++{ ++ const ZSTD_inBuffer nullInput = { NULL, 0, 0 }; ++ const int stableInput = (zcs->appliedParams.inBufferMode == ZSTD_bm_stable); ++ return stableInput ? zcs->expectedInBuffer : nullInput; ++} ++ + /*! 
ZSTD_flushStream() : + * @return : amount of data remaining to flush */ + size_t ZSTD_flushStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output) + { +- ZSTD_inBuffer input = { NULL, 0, 0 }; ++ ZSTD_inBuffer input = inBuffer_forEndFlush(zcs); ++ input.size = input.pos; /* do not ingest more input during flush */ + return ZSTD_compressStream2(zcs, output, &input, ZSTD_e_flush); + } + + + size_t ZSTD_endStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output) + { +- ZSTD_inBuffer input = { NULL, 0, 0 }; ++ ZSTD_inBuffer input = inBuffer_forEndFlush(zcs); + size_t const remainingToFlush = ZSTD_compressStream2(zcs, output, &input, ZSTD_e_end); +- FORWARD_IF_ERROR( remainingToFlush , "ZSTD_compressStream2 failed"); ++ FORWARD_IF_ERROR(remainingToFlush , "ZSTD_compressStream2(,,ZSTD_e_end) failed"); + if (zcs->appliedParams.nbWorkers > 0) return remainingToFlush; /* minimal estimation */ + /* single thread mode : attempt to calculate remaining to flush more precisely */ + { size_t const lastBlockSize = zcs->frameEnded ? 
0 : ZSTD_BLOCKHEADERSIZE; +@@ -6090,7 +6877,7 @@ static ZSTD_compressionParameters ZSTD_getCParams_internal(int compressionLevel, + cp.targetLength = (unsigned)(-clampedCompressionLevel); + } + /* refine parameters based on srcSize & dictSize */ +- return ZSTD_adjustCParams_internal(cp, srcSizeHint, dictSize, mode); ++ return ZSTD_adjustCParams_internal(cp, srcSizeHint, dictSize, mode, ZSTD_ps_auto); + } + } + +@@ -6125,3 +6912,29 @@ ZSTD_parameters ZSTD_getParams(int compressionLevel, unsigned long long srcSizeH + if (srcSizeHint == 0) srcSizeHint = ZSTD_CONTENTSIZE_UNKNOWN; + return ZSTD_getParams_internal(compressionLevel, srcSizeHint, dictSize, ZSTD_cpm_unknown); + } ++ ++void ZSTD_registerSequenceProducer( ++ ZSTD_CCtx* zc, ++ void* extSeqProdState, ++ ZSTD_sequenceProducer_F extSeqProdFunc ++) { ++ assert(zc != NULL); ++ ZSTD_CCtxParams_registerSequenceProducer( ++ &zc->requestedParams, extSeqProdState, extSeqProdFunc ++ ); ++} ++ ++void ZSTD_CCtxParams_registerSequenceProducer( ++ ZSTD_CCtx_params* params, ++ void* extSeqProdState, ++ ZSTD_sequenceProducer_F extSeqProdFunc ++) { ++ assert(params != NULL); ++ if (extSeqProdFunc != NULL) { ++ params->extSeqProdFunc = extSeqProdFunc; ++ params->extSeqProdState = extSeqProdState; ++ } else { ++ params->extSeqProdFunc = NULL; ++ params->extSeqProdState = NULL; ++ } ++} +diff --git a/lib/zstd/compress/zstd_compress_internal.h b/lib/zstd/compress/zstd_compress_internal.h +index 71697a11ae30..53cb582a8d2b 100644 +--- a/lib/zstd/compress/zstd_compress_internal.h ++++ b/lib/zstd/compress/zstd_compress_internal.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -20,6 +21,7 @@ + ***************************************/ + #include "../common/zstd_internal.h" + #include "zstd_cwksp.h" ++#include "../common/bits.h" /* ZSTD_highbit32, ZSTD_NbCommonBytes */ + + + /*-************************************* +@@ -32,7 +34,7 @@ + It's not a big deal though : candidate will just be sorted again. + Additionally, candidate position 1 will be lost. + But candidate 1 cannot hide a large tree of candidates, so it's a minimal loss. +- The benefit is that ZSTD_DUBT_UNSORTED_MARK cannot be mishandled after table re-use with a different strategy. ++ The benefit is that ZSTD_DUBT_UNSORTED_MARK cannot be mishandled after table reuse with a different strategy. + This constant is required by ZSTD_compressBlock_btlazy2() and ZSTD_reduceTable_internal() */ + + +@@ -111,12 +113,13 @@ typedef struct { + /* ZSTD_buildBlockEntropyStats() : + * Builds entropy for the block. + * @return : 0 on success or error code */ +-size_t ZSTD_buildBlockEntropyStats(seqStore_t* seqStorePtr, +- const ZSTD_entropyCTables_t* prevEntropy, +- ZSTD_entropyCTables_t* nextEntropy, +- const ZSTD_CCtx_params* cctxParams, +- ZSTD_entropyCTablesMetadata_t* entropyMetadata, +- void* workspace, size_t wkspSize); ++size_t ZSTD_buildBlockEntropyStats( ++ const seqStore_t* seqStorePtr, ++ const ZSTD_entropyCTables_t* prevEntropy, ++ ZSTD_entropyCTables_t* nextEntropy, ++ const ZSTD_CCtx_params* cctxParams, ++ ZSTD_entropyCTablesMetadata_t* entropyMetadata, ++ void* workspace, size_t wkspSize); + + /* ******************************* + * Compression internals structs * +@@ -142,26 +145,33 @@ typedef struct { + size_t capacity; /* The capacity starting from `seq` pointer */ + } rawSeqStore_t; + ++typedef struct { ++ U32 idx; /* Index in array of ZSTD_Sequence */ ++ U32 posInSequence; /* Position within sequence at idx */ ++ size_t posInSrc; /* Number of bytes given by sequences provided so far */ ++} 
ZSTD_sequencePosition; ++ + UNUSED_ATTR static const rawSeqStore_t kNullRawSeqStore = {NULL, 0, 0, 0, 0}; + + typedef struct { +- int price; +- U32 off; +- U32 mlen; +- U32 litlen; +- U32 rep[ZSTD_REP_NUM]; ++ int price; /* price from beginning of segment to this position */ ++ U32 off; /* offset of previous match */ ++ U32 mlen; /* length of previous match */ ++ U32 litlen; /* nb of literals since previous match */ ++ U32 rep[ZSTD_REP_NUM]; /* offset history after previous match */ + } ZSTD_optimal_t; + + typedef enum { zop_dynamic=0, zop_predef } ZSTD_OptPrice_e; + ++#define ZSTD_OPT_SIZE (ZSTD_OPT_NUM+3) + typedef struct { + /* All tables are allocated inside cctx->workspace by ZSTD_resetCCtx_internal() */ + unsigned* litFreq; /* table of literals statistics, of size 256 */ + unsigned* litLengthFreq; /* table of litLength statistics, of size (MaxLL+1) */ + unsigned* matchLengthFreq; /* table of matchLength statistics, of size (MaxML+1) */ + unsigned* offCodeFreq; /* table of offCode statistics, of size (MaxOff+1) */ +- ZSTD_match_t* matchTable; /* list of found matches, of size ZSTD_OPT_NUM+1 */ +- ZSTD_optimal_t* priceTable; /* All positions tracked by optimal parser, of size ZSTD_OPT_NUM+1 */ ++ ZSTD_match_t* matchTable; /* list of found matches, of size ZSTD_OPT_SIZE */ ++ ZSTD_optimal_t* priceTable; /* All positions tracked by optimal parser, of size ZSTD_OPT_SIZE */ + + U32 litSum; /* nb of literals */ + U32 litLengthSum; /* nb of litLength codes */ +@@ -212,8 +222,10 @@ struct ZSTD_matchState_t { + U32 hashLog3; /* dispatch table for matches of len==3 : larger == faster, more memory */ + + U32 rowHashLog; /* For row-based matchfinder: Hashlog based on nb of rows in the hashTable.*/ +- U16* tagTable; /* For row-based matchFinder: A row-based table containing the hashes and head index. */ ++ BYTE* tagTable; /* For row-based matchFinder: A row-based table containing the hashes and head index. 
*/ + U32 hashCache[ZSTD_ROW_HASH_CACHE_SIZE]; /* For row-based matchFinder: a cache of hashes to improve speed */ ++ U64 hashSalt; /* For row-based matchFinder: salts the hash for reuse of tag table */ ++ U32 hashSaltEntropy; /* For row-based matchFinder: collects entropy for salt generation */ + + U32* hashTable; + U32* hashTable3; +@@ -228,6 +240,18 @@ struct ZSTD_matchState_t { + const ZSTD_matchState_t* dictMatchState; + ZSTD_compressionParameters cParams; + const rawSeqStore_t* ldmSeqStore; ++ ++ /* Controls prefetching in some dictMatchState matchfinders. ++ * This behavior is controlled from the cctx ms. ++ * This parameter has no effect in the cdict ms. */ ++ int prefetchCDictTables; ++ ++ /* When == 0, lazy match finders insert every position. ++ * When != 0, lazy match finders only insert positions they search. ++ * This allows them to skip much faster over incompressible data, ++ * at a small cost to compression ratio. ++ */ ++ int lazySkipping; + }; + + typedef struct { +@@ -324,6 +348,25 @@ struct ZSTD_CCtx_params_s { + + /* Internal use, for createCCtxParams() and freeCCtxParams() only */ + ZSTD_customMem customMem; ++ ++ /* Controls prefetching in some dictMatchState matchfinders */ ++ ZSTD_paramSwitch_e prefetchCDictTables; ++ ++ /* Controls whether zstd will fall back to an internal matchfinder ++ * if the external matchfinder returns an error code. */ ++ int enableMatchFinderFallback; ++ ++ /* Parameters for the external sequence producer API. ++ * Users set these parameters through ZSTD_registerSequenceProducer(). ++ * It is not possible to set these parameters individually through the public API. 
*/ ++ void* extSeqProdState; ++ ZSTD_sequenceProducer_F extSeqProdFunc; ++ ++ /* Adjust the max block size*/ ++ size_t maxBlockSize; ++ ++ /* Controls repcode search in external sequence parsing */ ++ ZSTD_paramSwitch_e searchForExternalRepcodes; + }; /* typedef'd to ZSTD_CCtx_params within "zstd.h" */ + + #define COMPRESS_SEQUENCES_WORKSPACE_SIZE (sizeof(unsigned) * (MaxSeq + 2)) +@@ -404,6 +447,7 @@ struct ZSTD_CCtx_s { + + /* Stable in/out buffer verification */ + ZSTD_inBuffer expectedInBuffer; ++ size_t stableIn_notConsumed; /* nb bytes within stable input buffer that are said to be consumed but are not */ + size_t expectedOutBufferSize; + + /* Dictionary */ +@@ -417,9 +461,14 @@ struct ZSTD_CCtx_s { + + /* Workspace for block splitter */ + ZSTD_blockSplitCtx blockSplitCtx; ++ ++ /* Buffer for output from external sequence producer */ ++ ZSTD_Sequence* extSeqBuf; ++ size_t extSeqBufCapacity; + }; + + typedef enum { ZSTD_dtlm_fast, ZSTD_dtlm_full } ZSTD_dictTableLoadMethod_e; ++typedef enum { ZSTD_tfp_forCCtx, ZSTD_tfp_forCDict } ZSTD_tableFillPurpose_e; + + typedef enum { + ZSTD_noDict = 0, +@@ -441,7 +490,7 @@ typedef enum { + * In this mode we take both the source size and the dictionary size + * into account when selecting and adjusting the parameters. + */ +- ZSTD_cpm_unknown = 3, /* ZSTD_getCParams, ZSTD_getParams, ZSTD_adjustParams. ++ ZSTD_cpm_unknown = 3 /* ZSTD_getCParams, ZSTD_getParams, ZSTD_adjustParams. + * We don't know what these parameters are for. We default to the legacy + * behavior of taking both the source size and the dict size into account + * when selecting and adjusting parameters. +@@ -500,9 +549,11 @@ MEM_STATIC int ZSTD_cParam_withinBounds(ZSTD_cParameter cParam, int value) + /* ZSTD_noCompressBlock() : + * Writes uncompressed block to dst buffer from given src. 
+ * Returns the size of the block */ +-MEM_STATIC size_t ZSTD_noCompressBlock (void* dst, size_t dstCapacity, const void* src, size_t srcSize, U32 lastBlock) ++MEM_STATIC size_t ++ZSTD_noCompressBlock(void* dst, size_t dstCapacity, const void* src, size_t srcSize, U32 lastBlock) + { + U32 const cBlockHeader24 = lastBlock + (((U32)bt_raw)<<1) + (U32)(srcSize << 3); ++ DEBUGLOG(5, "ZSTD_noCompressBlock (srcSize=%zu, dstCapacity=%zu)", srcSize, dstCapacity); + RETURN_ERROR_IF(srcSize + ZSTD_blockHeaderSize > dstCapacity, + dstSize_tooSmall, "dst buf too small for uncompressed block"); + MEM_writeLE24(dst, cBlockHeader24); +@@ -510,7 +561,8 @@ MEM_STATIC size_t ZSTD_noCompressBlock (void* dst, size_t dstCapacity, const voi + return ZSTD_blockHeaderSize + srcSize; + } + +-MEM_STATIC size_t ZSTD_rleCompressBlock (void* dst, size_t dstCapacity, BYTE src, size_t srcSize, U32 lastBlock) ++MEM_STATIC size_t ++ZSTD_rleCompressBlock(void* dst, size_t dstCapacity, BYTE src, size_t srcSize, U32 lastBlock) + { + BYTE* const op = (BYTE*)dst; + U32 const cBlockHeader = lastBlock + (((U32)bt_rle)<<1) + (U32)(srcSize << 3); +@@ -529,7 +581,7 @@ MEM_STATIC size_t ZSTD_minGain(size_t srcSize, ZSTD_strategy strat) + { + U32 const minlog = (strat>=ZSTD_btultra) ? 
(U32)(strat) - 1 : 6; + ZSTD_STATIC_ASSERT(ZSTD_btultra == 8); +- assert(ZSTD_cParam_withinBounds(ZSTD_c_strategy, strat)); ++ assert(ZSTD_cParam_withinBounds(ZSTD_c_strategy, (int)strat)); + return (srcSize >> minlog) + 2; + } + +@@ -565,29 +617,27 @@ ZSTD_safecopyLiterals(BYTE* op, BYTE const* ip, BYTE const* const iend, BYTE con + while (ip < iend) *op++ = *ip++; + } + +-#define ZSTD_REP_MOVE (ZSTD_REP_NUM-1) +-#define STORE_REPCODE_1 STORE_REPCODE(1) +-#define STORE_REPCODE_2 STORE_REPCODE(2) +-#define STORE_REPCODE_3 STORE_REPCODE(3) +-#define STORE_REPCODE(r) (assert((r)>=1), assert((r)<=3), (r)-1) +-#define STORE_OFFSET(o) (assert((o)>0), o + ZSTD_REP_MOVE) +-#define STORED_IS_OFFSET(o) ((o) > ZSTD_REP_MOVE) +-#define STORED_IS_REPCODE(o) ((o) <= ZSTD_REP_MOVE) +-#define STORED_OFFSET(o) (assert(STORED_IS_OFFSET(o)), (o)-ZSTD_REP_MOVE) +-#define STORED_REPCODE(o) (assert(STORED_IS_REPCODE(o)), (o)+1) /* returns ID 1,2,3 */ +-#define STORED_TO_OFFBASE(o) ((o)+1) +-#define OFFBASE_TO_STORED(o) ((o)-1) ++ ++#define REPCODE1_TO_OFFBASE REPCODE_TO_OFFBASE(1) ++#define REPCODE2_TO_OFFBASE REPCODE_TO_OFFBASE(2) ++#define REPCODE3_TO_OFFBASE REPCODE_TO_OFFBASE(3) ++#define REPCODE_TO_OFFBASE(r) (assert((r)>=1), assert((r)<=ZSTD_REP_NUM), (r)) /* accepts IDs 1,2,3 */ ++#define OFFSET_TO_OFFBASE(o) (assert((o)>0), o + ZSTD_REP_NUM) ++#define OFFBASE_IS_OFFSET(o) ((o) > ZSTD_REP_NUM) ++#define OFFBASE_IS_REPCODE(o) ( 1 <= (o) && (o) <= ZSTD_REP_NUM) ++#define OFFBASE_TO_OFFSET(o) (assert(OFFBASE_IS_OFFSET(o)), (o) - ZSTD_REP_NUM) ++#define OFFBASE_TO_REPCODE(o) (assert(OFFBASE_IS_REPCODE(o)), (o)) /* returns ID 1,2,3 */ + + /*! ZSTD_storeSeq() : +- * Store a sequence (litlen, litPtr, offCode and matchLength) into seqStore_t. +- * @offBase_minus1 : Users should use employ macros STORE_REPCODE_X and STORE_OFFSET(). ++ * Store a sequence (litlen, litPtr, offBase and matchLength) into seqStore_t. 
++ * @offBase : Users should employ macros REPCODE_TO_OFFBASE() and OFFSET_TO_OFFBASE(). + * @matchLength : must be >= MINMATCH +- * Allowed to overread literals up to litLimit. ++ * Allowed to over-read literals up to litLimit. + */ + HINT_INLINE UNUSED_ATTR void + ZSTD_storeSeq(seqStore_t* seqStorePtr, + size_t litLength, const BYTE* literals, const BYTE* litLimit, +- U32 offBase_minus1, ++ U32 offBase, + size_t matchLength) + { + BYTE const* const litLimit_w = litLimit - WILDCOPY_OVERLENGTH; +@@ -596,8 +646,8 @@ ZSTD_storeSeq(seqStore_t* seqStorePtr, + static const BYTE* g_start = NULL; + if (g_start==NULL) g_start = (const BYTE*)literals; /* note : index only works for compression within a single segment */ + { U32 const pos = (U32)((const BYTE*)literals - g_start); +- DEBUGLOG(6, "Cpos%7u :%3u literals, match%4u bytes at offCode%7u", +- pos, (U32)litLength, (U32)matchLength, (U32)offBase_minus1); ++ DEBUGLOG(6, "Cpos%7u :%3u literals, match%4u bytes at offBase%7u", ++ pos, (U32)litLength, (U32)matchLength, (U32)offBase); + } + #endif + assert((size_t)(seqStorePtr->sequences - seqStorePtr->sequencesStart) < seqStorePtr->maxNbSeq); +@@ -607,9 +657,9 @@ ZSTD_storeSeq(seqStore_t* seqStorePtr, + assert(literals + litLength <= litLimit); + if (litEnd <= litLimit_w) { + /* Common case we can use wildcopy. +- * First copy 16 bytes, because literals are likely short. +- */ +- assert(WILDCOPY_OVERLENGTH >= 16); ++ * First copy 16 bytes, because literals are likely short. 
++ */ ++ ZSTD_STATIC_ASSERT(WILDCOPY_OVERLENGTH >= 16); + ZSTD_copy16(seqStorePtr->lit, literals); + if (litLength > 16) { + ZSTD_wildcopy(seqStorePtr->lit+16, literals+16, (ptrdiff_t)litLength-16, ZSTD_no_overlap); +@@ -628,7 +678,7 @@ ZSTD_storeSeq(seqStore_t* seqStorePtr, + seqStorePtr->sequences[0].litLength = (U16)litLength; + + /* match offset */ +- seqStorePtr->sequences[0].offBase = STORED_TO_OFFBASE(offBase_minus1); ++ seqStorePtr->sequences[0].offBase = offBase; + + /* match Length */ + assert(matchLength >= MINMATCH); +@@ -646,17 +696,17 @@ ZSTD_storeSeq(seqStore_t* seqStorePtr, + + /* ZSTD_updateRep() : + * updates in-place @rep (array of repeat offsets) +- * @offBase_minus1 : sum-type, with same numeric representation as ZSTD_storeSeq() ++ * @offBase : sum-type, using numeric representation of ZSTD_storeSeq() + */ + MEM_STATIC void +-ZSTD_updateRep(U32 rep[ZSTD_REP_NUM], U32 const offBase_minus1, U32 const ll0) ++ZSTD_updateRep(U32 rep[ZSTD_REP_NUM], U32 const offBase, U32 const ll0) + { +- if (STORED_IS_OFFSET(offBase_minus1)) { /* full offset */ ++ if (OFFBASE_IS_OFFSET(offBase)) { /* full offset */ + rep[2] = rep[1]; + rep[1] = rep[0]; +- rep[0] = STORED_OFFSET(offBase_minus1); ++ rep[0] = OFFBASE_TO_OFFSET(offBase); + } else { /* repcode */ +- U32 const repCode = STORED_REPCODE(offBase_minus1) - 1 + ll0; ++ U32 const repCode = OFFBASE_TO_REPCODE(offBase) - 1 + ll0; + if (repCode > 0) { /* note : if repCode==0, no change */ + U32 const currentOffset = (repCode==ZSTD_REP_NUM) ? (rep[0] - 1) : rep[repCode]; + rep[2] = (repCode >= 2) ? 
rep[1] : rep[2]; +@@ -673,11 +723,11 @@ typedef struct repcodes_s { + } repcodes_t; + + MEM_STATIC repcodes_t +-ZSTD_newRep(U32 const rep[ZSTD_REP_NUM], U32 const offBase_minus1, U32 const ll0) ++ZSTD_newRep(U32 const rep[ZSTD_REP_NUM], U32 const offBase, U32 const ll0) + { + repcodes_t newReps; + ZSTD_memcpy(&newReps, rep, sizeof(newReps)); +- ZSTD_updateRep(newReps.rep, offBase_minus1, ll0); ++ ZSTD_updateRep(newReps.rep, offBase, ll0); + return newReps; + } + +@@ -685,59 +735,6 @@ ZSTD_newRep(U32 const rep[ZSTD_REP_NUM], U32 const offBase_minus1, U32 const ll0 + /*-************************************* + * Match length counter + ***************************************/ +-static unsigned ZSTD_NbCommonBytes (size_t val) +-{ +- if (MEM_isLittleEndian()) { +- if (MEM_64bits()) { +-# if (__GNUC__ >= 4) +- return (__builtin_ctzll((U64)val) >> 3); +-# else +- static const int DeBruijnBytePos[64] = { 0, 0, 0, 0, 0, 1, 1, 2, +- 0, 3, 1, 3, 1, 4, 2, 7, +- 0, 2, 3, 6, 1, 5, 3, 5, +- 1, 3, 4, 4, 2, 5, 6, 7, +- 7, 0, 1, 2, 3, 3, 4, 6, +- 2, 6, 5, 5, 3, 4, 5, 6, +- 7, 1, 2, 4, 6, 4, 4, 5, +- 7, 2, 6, 5, 7, 6, 7, 7 }; +- return DeBruijnBytePos[((U64)((val & -(long long)val) * 0x0218A392CDABBD3FULL)) >> 58]; +-# endif +- } else { /* 32 bits */ +-# if (__GNUC__ >= 3) +- return (__builtin_ctz((U32)val) >> 3); +-# else +- static const int DeBruijnBytePos[32] = { 0, 0, 3, 0, 3, 1, 3, 0, +- 3, 2, 2, 1, 3, 2, 0, 1, +- 3, 3, 1, 2, 2, 2, 2, 0, +- 3, 1, 2, 0, 1, 0, 1, 1 }; +- return DeBruijnBytePos[((U32)((val & -(S32)val) * 0x077CB531U)) >> 27]; +-# endif +- } +- } else { /* Big Endian CPU */ +- if (MEM_64bits()) { +-# if (__GNUC__ >= 4) +- return (__builtin_clzll(val) >> 3); +-# else +- unsigned r; +- const unsigned n32 = sizeof(size_t)*4; /* calculate this way due to compiler complaining in 32-bits mode */ +- if (!(val>>n32)) { r=4; } else { r=0; val>>=n32; } +- if (!(val>>16)) { r+=2; val>>=8; } else { val>>=24; } +- r += (!val); +- return r; +-# endif +- } else { /* 32 bits */ +-# 
if (__GNUC__ >= 3) +- return (__builtin_clz((U32)val) >> 3); +-# else +- unsigned r; +- if (!(val>>16)) { r=2; val>>=8; } else { r=0; val>>=24; } +- r += (!val); +- return r; +-# endif +- } } +-} +- +- + MEM_STATIC size_t ZSTD_count(const BYTE* pIn, const BYTE* pMatch, const BYTE* const pInLimit) + { + const BYTE* const pStart = pIn; +@@ -783,32 +780,43 @@ ZSTD_count_2segments(const BYTE* ip, const BYTE* match, + * Hashes + ***************************************/ + static const U32 prime3bytes = 506832829U; +-static U32 ZSTD_hash3(U32 u, U32 h) { return ((u << (32-24)) * prime3bytes) >> (32-h) ; } +-MEM_STATIC size_t ZSTD_hash3Ptr(const void* ptr, U32 h) { return ZSTD_hash3(MEM_readLE32(ptr), h); } /* only in zstd_opt.h */ ++static U32 ZSTD_hash3(U32 u, U32 h, U32 s) { assert(h <= 32); return (((u << (32-24)) * prime3bytes) ^ s) >> (32-h) ; } ++MEM_STATIC size_t ZSTD_hash3Ptr(const void* ptr, U32 h) { return ZSTD_hash3(MEM_readLE32(ptr), h, 0); } /* only in zstd_opt.h */ ++MEM_STATIC size_t ZSTD_hash3PtrS(const void* ptr, U32 h, U32 s) { return ZSTD_hash3(MEM_readLE32(ptr), h, s); } + + static const U32 prime4bytes = 2654435761U; +-static U32 ZSTD_hash4(U32 u, U32 h) { return (u * prime4bytes) >> (32-h) ; } +-static size_t ZSTD_hash4Ptr(const void* ptr, U32 h) { return ZSTD_hash4(MEM_read32(ptr), h); } ++static U32 ZSTD_hash4(U32 u, U32 h, U32 s) { assert(h <= 32); return ((u * prime4bytes) ^ s) >> (32-h) ; } ++static size_t ZSTD_hash4Ptr(const void* ptr, U32 h) { return ZSTD_hash4(MEM_readLE32(ptr), h, 0); } ++static size_t ZSTD_hash4PtrS(const void* ptr, U32 h, U32 s) { return ZSTD_hash4(MEM_readLE32(ptr), h, s); } + + static const U64 prime5bytes = 889523592379ULL; +-static size_t ZSTD_hash5(U64 u, U32 h) { return (size_t)(((u << (64-40)) * prime5bytes) >> (64-h)) ; } +-static size_t ZSTD_hash5Ptr(const void* p, U32 h) { return ZSTD_hash5(MEM_readLE64(p), h); } ++static size_t ZSTD_hash5(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u << (64-40)) * 
prime5bytes) ^ s) >> (64-h)) ; } ++static size_t ZSTD_hash5Ptr(const void* p, U32 h) { return ZSTD_hash5(MEM_readLE64(p), h, 0); } ++static size_t ZSTD_hash5PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash5(MEM_readLE64(p), h, s); } + + static const U64 prime6bytes = 227718039650203ULL; +-static size_t ZSTD_hash6(U64 u, U32 h) { return (size_t)(((u << (64-48)) * prime6bytes) >> (64-h)) ; } +-static size_t ZSTD_hash6Ptr(const void* p, U32 h) { return ZSTD_hash6(MEM_readLE64(p), h); } ++static size_t ZSTD_hash6(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u << (64-48)) * prime6bytes) ^ s) >> (64-h)) ; } ++static size_t ZSTD_hash6Ptr(const void* p, U32 h) { return ZSTD_hash6(MEM_readLE64(p), h, 0); } ++static size_t ZSTD_hash6PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash6(MEM_readLE64(p), h, s); } + + static const U64 prime7bytes = 58295818150454627ULL; +-static size_t ZSTD_hash7(U64 u, U32 h) { return (size_t)(((u << (64-56)) * prime7bytes) >> (64-h)) ; } +-static size_t ZSTD_hash7Ptr(const void* p, U32 h) { return ZSTD_hash7(MEM_readLE64(p), h); } ++static size_t ZSTD_hash7(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u << (64-56)) * prime7bytes) ^ s) >> (64-h)) ; } ++static size_t ZSTD_hash7Ptr(const void* p, U32 h) { return ZSTD_hash7(MEM_readLE64(p), h, 0); } ++static size_t ZSTD_hash7PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash7(MEM_readLE64(p), h, s); } + + static const U64 prime8bytes = 0xCF1BBCDCB7A56463ULL; +-static size_t ZSTD_hash8(U64 u, U32 h) { return (size_t)(((u) * prime8bytes) >> (64-h)) ; } +-static size_t ZSTD_hash8Ptr(const void* p, U32 h) { return ZSTD_hash8(MEM_readLE64(p), h); } ++static size_t ZSTD_hash8(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u) * prime8bytes) ^ s) >> (64-h)) ; } ++static size_t ZSTD_hash8Ptr(const void* p, U32 h) { return ZSTD_hash8(MEM_readLE64(p), h, 0); } ++static size_t ZSTD_hash8PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash8(MEM_readLE64(p), h, s); 
} ++ + + MEM_STATIC FORCE_INLINE_ATTR + size_t ZSTD_hashPtr(const void* p, U32 hBits, U32 mls) + { ++ /* Although some of these hashes do support hBits up to 64, some do not. ++ * To be on the safe side, always avoid hBits > 32. */ ++ assert(hBits <= 32); ++ + switch(mls) + { + default: +@@ -820,6 +828,24 @@ size_t ZSTD_hashPtr(const void* p, U32 hBits, U32 mls) + } + } + ++MEM_STATIC FORCE_INLINE_ATTR ++size_t ZSTD_hashPtrSalted(const void* p, U32 hBits, U32 mls, const U64 hashSalt) { ++ /* Although some of these hashes do support hBits up to 64, some do not. ++ * To be on the safe side, always avoid hBits > 32. */ ++ assert(hBits <= 32); ++ ++ switch(mls) ++ { ++ default: ++ case 4: return ZSTD_hash4PtrS(p, hBits, (U32)hashSalt); ++ case 5: return ZSTD_hash5PtrS(p, hBits, hashSalt); ++ case 6: return ZSTD_hash6PtrS(p, hBits, hashSalt); ++ case 7: return ZSTD_hash7PtrS(p, hBits, hashSalt); ++ case 8: return ZSTD_hash8PtrS(p, hBits, hashSalt); ++ } ++} ++ ++ + /* ZSTD_ipow() : + * Return base^exponent. + */ +@@ -1011,7 +1037,9 @@ MEM_STATIC U32 ZSTD_window_needOverflowCorrection(ZSTD_window_t const window, + * The least significant cycleLog bits of the indices must remain the same, + * which may be 0. Every index up to maxDist in the past must be valid. + */ +-MEM_STATIC U32 ZSTD_window_correctOverflow(ZSTD_window_t* window, U32 cycleLog, ++MEM_STATIC ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++U32 ZSTD_window_correctOverflow(ZSTD_window_t* window, U32 cycleLog, + U32 maxDist, void const* src) + { + /* preemptive overflow correction: +@@ -1167,10 +1195,15 @@ ZSTD_checkDictValidity(const ZSTD_window_t* window, + (unsigned)blockEndIdx, (unsigned)maxDist, (unsigned)loadedDictEnd); + assert(blockEndIdx >= loadedDictEnd); + +- if (blockEndIdx > loadedDictEnd + maxDist) { ++ if (blockEndIdx > loadedDictEnd + maxDist || loadedDictEnd != window->dictLimit) { + /* On reaching window size, dictionaries are invalidated. 
+ * For simplification, if window size is reached anywhere within next block, + * the dictionary is invalidated for the full block. ++ * ++ * We also have to invalidate the dictionary if ZSTD_window_update() has detected ++ * non-contiguous segments, which means that loadedDictEnd != window->dictLimit. ++ * loadedDictEnd may be 0, if forceWindow is true, but in that case we never use ++ * dictMatchState, so setting it to NULL is not a problem. + */ + DEBUGLOG(6, "invalidating dictionary for current block (distance > windowSize)"); + *loadedDictEndPtr = 0; +@@ -1199,7 +1232,9 @@ MEM_STATIC void ZSTD_window_init(ZSTD_window_t* window) { + * forget about the extDict. Handles overlap of the prefix and extDict. + * Returns non-zero if the segment is contiguous. + */ +-MEM_STATIC U32 ZSTD_window_update(ZSTD_window_t* window, ++MEM_STATIC ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++U32 ZSTD_window_update(ZSTD_window_t* window, + void const* src, size_t srcSize, + int forceNonContiguous) + { +@@ -1302,6 +1337,42 @@ MEM_STATIC void ZSTD_debugTable(const U32* table, U32 max) + + #endif + ++/* Short Cache */ ++ ++/* Normally, zstd matchfinders follow this flow: ++ * 1. Compute hash at ip ++ * 2. Load index from hashTable[hash] ++ * 3. Check if *ip == *(base + index) ++ * In dictionary compression, loading *(base + index) is often an L2 or even L3 miss. ++ * ++ * Short cache is an optimization which allows us to avoid step 3 most of the time ++ * when the data doesn't actually match. With short cache, the flow becomes: ++ * 1. Compute (hash, currentTag) at ip. currentTag is an 8-bit independent hash at ip. ++ * 2. Load (index, matchTag) from hashTable[hash]. See ZSTD_writeTaggedIndex to understand how this works. ++ * 3. Only if currentTag == matchTag, check *ip == *(base + index). Otherwise, continue. ++ * ++ * Currently, short cache is only implemented in CDict hashtables. Thus, its use is limited to ++ * dictMatchState matchfinders. 
++ */ ++#define ZSTD_SHORT_CACHE_TAG_BITS 8 ++#define ZSTD_SHORT_CACHE_TAG_MASK ((1u << ZSTD_SHORT_CACHE_TAG_BITS) - 1) ++ ++/* Helper function for ZSTD_fillHashTable and ZSTD_fillDoubleHashTable. ++ * Unpacks hashAndTag into (hash, tag), then packs (index, tag) into hashTable[hash]. */ ++MEM_STATIC void ZSTD_writeTaggedIndex(U32* const hashTable, size_t hashAndTag, U32 index) { ++ size_t const hash = hashAndTag >> ZSTD_SHORT_CACHE_TAG_BITS; ++ U32 const tag = (U32)(hashAndTag & ZSTD_SHORT_CACHE_TAG_MASK); ++ assert(index >> (32 - ZSTD_SHORT_CACHE_TAG_BITS) == 0); ++ hashTable[hash] = (index << ZSTD_SHORT_CACHE_TAG_BITS) | tag; ++} ++ ++/* Helper function for short cache matchfinders. ++ * Unpacks tag1 and tag2 from lower bits of packedTag1 and packedTag2, then checks if the tags match. */ ++MEM_STATIC int ZSTD_comparePackedTags(size_t packedTag1, size_t packedTag2) { ++ U32 const tag1 = packedTag1 & ZSTD_SHORT_CACHE_TAG_MASK; ++ U32 const tag2 = packedTag2 & ZSTD_SHORT_CACHE_TAG_MASK; ++ return tag1 == tag2; ++} + + + /* =============================================================== +@@ -1381,11 +1452,10 @@ size_t ZSTD_writeLastEmptyBlock(void* dst, size_t dstCapacity); + * This cannot be used when long range matching is enabled. + * Zstd will use these sequences, and pass the literals to a secondary block + * compressor. +- * @return : An error code on failure. + * NOTE: seqs are not verified! Invalid sequences can cause out-of-bounds memory + * access and data corruption. + */ +-size_t ZSTD_referenceExternalSequences(ZSTD_CCtx* cctx, rawSeq* seq, size_t nbSeq); ++void ZSTD_referenceExternalSequences(ZSTD_CCtx* cctx, rawSeq* seq, size_t nbSeq); + + /* ZSTD_cycleLog() : + * condition for correct operation : hashLog > 1 */ +@@ -1396,4 +1466,55 @@ U32 ZSTD_cycleLog(U32 hashLog, ZSTD_strategy strat); + */ + void ZSTD_CCtx_trace(ZSTD_CCtx* cctx, size_t extraCSize); + ++/* Returns 0 on success, and a ZSTD_error otherwise. 
This function scans through an array of ++ * ZSTD_Sequence, storing the sequences it finds, until it reaches a block delimiter. ++ * Note that the block delimiter must include the last literals of the block. ++ */ ++size_t ++ZSTD_copySequencesToSeqStoreExplicitBlockDelim(ZSTD_CCtx* cctx, ++ ZSTD_sequencePosition* seqPos, ++ const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, ++ const void* src, size_t blockSize, ZSTD_paramSwitch_e externalRepSearch); ++ ++/* Returns the number of bytes to move the current read position back by. ++ * Only non-zero if we ended up splitting a sequence. ++ * Otherwise, it may return a ZSTD error if something went wrong. ++ * ++ * This function will attempt to scan through blockSize bytes ++ * represented by the sequences in @inSeqs, ++ * storing any (partial) sequences. ++ * ++ * Occasionally, we may want to change the actual number of bytes we consumed from inSeqs to ++ * avoid splitting a match, or to avoid splitting a match such that it would produce a match ++ * smaller than MINMATCH. In this case, we return the number of bytes that we didn't read from this block. ++ */ ++size_t ++ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* seqPos, ++ const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, ++ const void* src, size_t blockSize, ZSTD_paramSwitch_e externalRepSearch); ++ ++/* Returns 1 if an external sequence producer is registered, otherwise returns 0. */ ++MEM_STATIC int ZSTD_hasExtSeqProd(const ZSTD_CCtx_params* params) { ++ return params->extSeqProdFunc != NULL; ++} ++ ++/* =============================================================== ++ * Deprecated definitions that are still used internally to avoid ++ * deprecation warnings. These functions are exactly equivalent to ++ * their public variants, but avoid the deprecation warnings. 
++ * =============================================================== */ ++ ++size_t ZSTD_compressBegin_usingCDict_deprecated(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict); ++ ++size_t ZSTD_compressContinue_public(ZSTD_CCtx* cctx, ++ void* dst, size_t dstCapacity, ++ const void* src, size_t srcSize); ++ ++size_t ZSTD_compressEnd_public(ZSTD_CCtx* cctx, ++ void* dst, size_t dstCapacity, ++ const void* src, size_t srcSize); ++ ++size_t ZSTD_compressBlock_deprecated(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); ++ ++ + #endif /* ZSTD_COMPRESS_H */ +diff --git a/lib/zstd/compress/zstd_compress_literals.c b/lib/zstd/compress/zstd_compress_literals.c +index 52b0a8059aba..3e9ea46a670a 100644 +--- a/lib/zstd/compress/zstd_compress_literals.c ++++ b/lib/zstd/compress/zstd_compress_literals.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -13,11 +14,36 @@ + ***************************************/ + #include "zstd_compress_literals.h" + ++ ++/* ************************************************************** ++* Debug Traces ++****************************************************************/ ++#if DEBUGLEVEL >= 2 ++ ++static size_t showHexa(const void* src, size_t srcSize) ++{ ++ const BYTE* const ip = (const BYTE*)src; ++ size_t u; ++ for (u=0; u31) + (srcSize>4095); + ++ DEBUGLOG(5, "ZSTD_noCompressLiterals: srcSize=%zu, dstCapacity=%zu", srcSize, dstCapacity); ++ + RETURN_ERROR_IF(srcSize + flSize > dstCapacity, dstSize_tooSmall, ""); + + switch(flSize) +@@ -36,16 +62,30 @@ size_t ZSTD_noCompressLiterals (void* dst, size_t dstCapacity, const void* src, + } + + ZSTD_memcpy(ostart + flSize, src, srcSize); +- DEBUGLOG(5, "Raw literals: %u -> %u", (U32)srcSize, (U32)(srcSize + flSize)); ++ DEBUGLOG(5, "Raw (uncompressed) literals: %u -> %u", (U32)srcSize, (U32)(srcSize + flSize)); + return srcSize + flSize; + } + ++static int allBytesIdentical(const void* src, size_t srcSize) ++{ ++ assert(srcSize >= 1); ++ assert(src != NULL); ++ { const BYTE b = ((const BYTE*)src)[0]; ++ size_t p; ++ for (p=1; p31) + (srcSize>4095); + +- (void)dstCapacity; /* dstCapacity already guaranteed to be >=4, hence large enough */ ++ assert(dstCapacity >= 4); (void)dstCapacity; ++ assert(allBytesIdentical(src, srcSize)); + + switch(flSize) + { +@@ -63,28 +103,51 @@ size_t ZSTD_compressRleLiteralsBlock (void* dst, size_t dstCapacity, const void* + } + + ostart[flSize] = *(const BYTE*)src; +- DEBUGLOG(5, "RLE literals: %u -> %u", (U32)srcSize, (U32)flSize + 1); ++ DEBUGLOG(5, "RLE : Repeated Literal (%02X: %u times) -> %u bytes encoded", ((const BYTE*)src)[0], (U32)srcSize, (U32)flSize + 1); + return flSize+1; + } + +-size_t ZSTD_compressLiterals (ZSTD_hufCTables_t const* prevHuf, +- ZSTD_hufCTables_t* nextHuf, +- ZSTD_strategy strategy, int 
disableLiteralCompression, +- void* dst, size_t dstCapacity, +- const void* src, size_t srcSize, +- void* entropyWorkspace, size_t entropyWorkspaceSize, +- const int bmi2, +- unsigned suspectUncompressible) ++/* ZSTD_minLiteralsToCompress() : ++ * returns minimal amount of literals ++ * for literal compression to even be attempted. ++ * Minimum is made tighter as compression strategy increases. ++ */ ++static size_t ++ZSTD_minLiteralsToCompress(ZSTD_strategy strategy, HUF_repeat huf_repeat) ++{ ++ assert((int)strategy >= 0); ++ assert((int)strategy <= 9); ++ /* btultra2 : min 8 bytes; ++ * then 2x larger for each successive compression strategy ++ * max threshold 64 bytes */ ++ { int const shift = MIN(9-(int)strategy, 3); ++ size_t const mintc = (huf_repeat == HUF_repeat_valid) ? 6 : (size_t)8 << shift; ++ DEBUGLOG(7, "minLiteralsToCompress = %zu", mintc); ++ return mintc; ++ } ++} ++ ++size_t ZSTD_compressLiterals ( ++ void* dst, size_t dstCapacity, ++ const void* src, size_t srcSize, ++ void* entropyWorkspace, size_t entropyWorkspaceSize, ++ const ZSTD_hufCTables_t* prevHuf, ++ ZSTD_hufCTables_t* nextHuf, ++ ZSTD_strategy strategy, ++ int disableLiteralCompression, ++ int suspectUncompressible, ++ int bmi2) + { +- size_t const minGain = ZSTD_minGain(srcSize, strategy); + size_t const lhSize = 3 + (srcSize >= 1 KB) + (srcSize >= 16 KB); + BYTE* const ostart = (BYTE*)dst; + U32 singleStream = srcSize < 256; + symbolEncodingType_e hType = set_compressed; + size_t cLitSize; + +- DEBUGLOG(5,"ZSTD_compressLiterals (disableLiteralCompression=%i srcSize=%u)", +- disableLiteralCompression, (U32)srcSize); ++ DEBUGLOG(5,"ZSTD_compressLiterals (disableLiteralCompression=%i, srcSize=%u, dstCapacity=%zu)", ++ disableLiteralCompression, (U32)srcSize, dstCapacity); ++ ++ DEBUGLOG(6, "Completed literals listing (%zu bytes)", showHexa(src, srcSize)); + + /* Prepare nextEntropy assuming reusing the existing table */ + ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); +@@ -92,40 
+155,51 @@ size_t ZSTD_compressLiterals (ZSTD_hufCTables_t const* prevHuf, + if (disableLiteralCompression) + return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); + +- /* small ? don't even attempt compression (speed opt) */ +-# define COMPRESS_LITERALS_SIZE_MIN 63 +- { size_t const minLitSize = (prevHuf->repeatMode == HUF_repeat_valid) ? 6 : COMPRESS_LITERALS_SIZE_MIN; +- if (srcSize <= minLitSize) return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); +- } ++ /* if too small, don't even attempt compression (speed opt) */ ++ if (srcSize < ZSTD_minLiteralsToCompress(strategy, prevHuf->repeatMode)) ++ return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); + + RETURN_ERROR_IF(dstCapacity < lhSize+1, dstSize_tooSmall, "not enough space for compression"); + { HUF_repeat repeat = prevHuf->repeatMode; +- int const preferRepeat = strategy < ZSTD_lazy ? srcSize <= 1024 : 0; ++ int const flags = 0 ++ | (bmi2 ? HUF_flags_bmi2 : 0) ++ | (strategy < ZSTD_lazy && srcSize <= 1024 ? HUF_flags_preferRepeat : 0) ++ | (strategy >= HUF_OPTIMAL_DEPTH_THRESHOLD ? HUF_flags_optimalDepth : 0) ++ | (suspectUncompressible ? HUF_flags_suspectUncompressible : 0); ++ ++ typedef size_t (*huf_compress_f)(void*, size_t, const void*, size_t, unsigned, unsigned, void*, size_t, HUF_CElt*, HUF_repeat*, int); ++ huf_compress_f huf_compress; + if (repeat == HUF_repeat_valid && lhSize == 3) singleStream = 1; +- cLitSize = singleStream ? +- HUF_compress1X_repeat( +- ostart+lhSize, dstCapacity-lhSize, src, srcSize, +- HUF_SYMBOLVALUE_MAX, HUF_TABLELOG_DEFAULT, entropyWorkspace, entropyWorkspaceSize, +- (HUF_CElt*)nextHuf->CTable, &repeat, preferRepeat, bmi2, suspectUncompressible) : +- HUF_compress4X_repeat( +- ostart+lhSize, dstCapacity-lhSize, src, srcSize, +- HUF_SYMBOLVALUE_MAX, HUF_TABLELOG_DEFAULT, entropyWorkspace, entropyWorkspaceSize, +- (HUF_CElt*)nextHuf->CTable, &repeat, preferRepeat, bmi2, suspectUncompressible); ++ huf_compress = singleStream ? 
HUF_compress1X_repeat : HUF_compress4X_repeat; ++ cLitSize = huf_compress(ostart+lhSize, dstCapacity-lhSize, ++ src, srcSize, ++ HUF_SYMBOLVALUE_MAX, LitHufLog, ++ entropyWorkspace, entropyWorkspaceSize, ++ (HUF_CElt*)nextHuf->CTable, ++ &repeat, flags); ++ DEBUGLOG(5, "%zu literals compressed into %zu bytes (before header)", srcSize, cLitSize); + if (repeat != HUF_repeat_none) { + /* reused the existing table */ +- DEBUGLOG(5, "Reusing previous huffman table"); ++ DEBUGLOG(5, "reusing statistics from previous huffman block"); + hType = set_repeat; + } + } + +- if ((cLitSize==0) || (cLitSize >= srcSize - minGain) || ERR_isError(cLitSize)) { +- ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); +- return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); +- } ++ { size_t const minGain = ZSTD_minGain(srcSize, strategy); ++ if ((cLitSize==0) || (cLitSize >= srcSize - minGain) || ERR_isError(cLitSize)) { ++ ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); ++ return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); ++ } } + if (cLitSize==1) { +- ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); +- return ZSTD_compressRleLiteralsBlock(dst, dstCapacity, src, srcSize); +- } ++ /* A return value of 1 signals that the alphabet consists of a single symbol. ++ * However, in some rare circumstances, it could be the compressed size (a single byte). ++ * For that outcome to have a chance to happen, it's necessary that `srcSize < 8`. ++ * (it's also necessary to not generate statistics). ++ * Therefore, in such a case, actively check that all bytes are identical. 
*/ ++ if ((srcSize >= 8) || allBytesIdentical(src, srcSize)) { ++ ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); ++ return ZSTD_compressRleLiteralsBlock(dst, dstCapacity, src, srcSize); ++ } } + + if (hType == set_compressed) { + /* using a newly constructed table */ +@@ -136,16 +210,19 @@ size_t ZSTD_compressLiterals (ZSTD_hufCTables_t const* prevHuf, + switch(lhSize) + { + case 3: /* 2 - 2 - 10 - 10 */ +- { U32 const lhc = hType + ((!singleStream) << 2) + ((U32)srcSize<<4) + ((U32)cLitSize<<14); ++ if (!singleStream) assert(srcSize >= MIN_LITERALS_FOR_4_STREAMS); ++ { U32 const lhc = hType + ((U32)(!singleStream) << 2) + ((U32)srcSize<<4) + ((U32)cLitSize<<14); + MEM_writeLE24(ostart, lhc); + break; + } + case 4: /* 2 - 2 - 14 - 14 */ ++ assert(srcSize >= MIN_LITERALS_FOR_4_STREAMS); + { U32 const lhc = hType + (2 << 2) + ((U32)srcSize<<4) + ((U32)cLitSize<<18); + MEM_writeLE32(ostart, lhc); + break; + } + case 5: /* 2 - 2 - 18 - 18 */ ++ assert(srcSize >= MIN_LITERALS_FOR_4_STREAMS); + { U32 const lhc = hType + (3 << 2) + ((U32)srcSize<<4) + ((U32)cLitSize<<22); + MEM_writeLE32(ostart, lhc); + ostart[4] = (BYTE)(cLitSize >> 10); +diff --git a/lib/zstd/compress/zstd_compress_literals.h b/lib/zstd/compress/zstd_compress_literals.h +index 9775fb97cb70..a2a85d6b69e5 100644 +--- a/lib/zstd/compress/zstd_compress_literals.h ++++ b/lib/zstd/compress/zstd_compress_literals.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -16,16 +17,24 @@ + + size_t ZSTD_noCompressLiterals (void* dst, size_t dstCapacity, const void* src, size_t srcSize); + ++/* ZSTD_compressRleLiteralsBlock() : ++ * Conditions : ++ * - All bytes in @src are identical ++ * - dstCapacity >= 4 */ + size_t ZSTD_compressRleLiteralsBlock (void* dst, size_t dstCapacity, const void* src, size_t srcSize); + +-/* If suspectUncompressible then some sampling checks will be run to potentially skip huffman coding */ +-size_t ZSTD_compressLiterals (ZSTD_hufCTables_t const* prevHuf, +- ZSTD_hufCTables_t* nextHuf, +- ZSTD_strategy strategy, int disableLiteralCompression, +- void* dst, size_t dstCapacity, ++/* ZSTD_compressLiterals(): ++ * @entropyWorkspace: must be aligned on 4-bytes boundaries ++ * @entropyWorkspaceSize : must be >= HUF_WORKSPACE_SIZE ++ * @suspectUncompressible: sampling checks, to potentially skip huffman coding ++ */ ++size_t ZSTD_compressLiterals (void* dst, size_t dstCapacity, + const void* src, size_t srcSize, + void* entropyWorkspace, size_t entropyWorkspaceSize, +- const int bmi2, +- unsigned suspectUncompressible); ++ const ZSTD_hufCTables_t* prevHuf, ++ ZSTD_hufCTables_t* nextHuf, ++ ZSTD_strategy strategy, int disableLiteralCompression, ++ int suspectUncompressible, ++ int bmi2); + + #endif /* ZSTD_COMPRESS_LITERALS_H */ +diff --git a/lib/zstd/compress/zstd_compress_sequences.c b/lib/zstd/compress/zstd_compress_sequences.c +index 21ddc1b37acf..5c028c78d889 100644 +--- a/lib/zstd/compress/zstd_compress_sequences.c ++++ b/lib/zstd/compress/zstd_compress_sequences.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -58,7 +59,7 @@ static unsigned ZSTD_useLowProbCount(size_t const nbSeq) + { + /* Heuristic: This should cover most blocks <= 16K and + * start to fade out after 16K to about 32K depending on +- * comprssibility. ++ * compressibility. + */ + return nbSeq >= 2048; + } +@@ -166,7 +167,7 @@ ZSTD_selectEncodingType( + if (mostFrequent == nbSeq) { + *repeatMode = FSE_repeat_none; + if (isDefaultAllowed && nbSeq <= 2) { +- /* Prefer set_basic over set_rle when there are 2 or less symbols, ++ /* Prefer set_basic over set_rle when there are 2 or fewer symbols, + * since RLE uses 1 byte, but set_basic uses 5-6 bits per symbol. + * If basic encoding isn't possible, always choose RLE. + */ +diff --git a/lib/zstd/compress/zstd_compress_sequences.h b/lib/zstd/compress/zstd_compress_sequences.h +index 7991364c2f71..7fe6f4ff5cf2 100644 +--- a/lib/zstd/compress/zstd_compress_sequences.h ++++ b/lib/zstd/compress/zstd_compress_sequences.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +diff --git a/lib/zstd/compress/zstd_compress_superblock.c b/lib/zstd/compress/zstd_compress_superblock.c +index 17d836cc84e8..41f6521b27cd 100644 +--- a/lib/zstd/compress/zstd_compress_superblock.c ++++ b/lib/zstd/compress/zstd_compress_superblock.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -36,13 +37,14 @@ + * If it is set_compressed, first sub-block's literals section will be Treeless_Literals_Block + * and the following sub-blocks' literals sections will be Treeless_Literals_Block. + * @return : compressed size of literals section of a sub-block +- * Or 0 if it unable to compress. ++ * Or 0 if unable to compress. + * Or error code */ +-static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, +- const ZSTD_hufCTablesMetadata_t* hufMetadata, +- const BYTE* literals, size_t litSize, +- void* dst, size_t dstSize, +- const int bmi2, int writeEntropy, int* entropyWritten) ++static size_t ++ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, ++ const ZSTD_hufCTablesMetadata_t* hufMetadata, ++ const BYTE* literals, size_t litSize, ++ void* dst, size_t dstSize, ++ const int bmi2, int writeEntropy, int* entropyWritten) + { + size_t const header = writeEntropy ? 200 : 0; + size_t const lhSize = 3 + (litSize >= (1 KB - header)) + (litSize >= (16 KB - header)); +@@ -53,8 +55,6 @@ static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, + symbolEncodingType_e hType = writeEntropy ? hufMetadata->hType : set_repeat; + size_t cLitSize = 0; + +- (void)bmi2; /* TODO bmi2... */ +- + DEBUGLOG(5, "ZSTD_compressSubBlock_literal (litSize=%zu, lhSize=%zu, writeEntropy=%d)", litSize, lhSize, writeEntropy); + + *entropyWritten = 0; +@@ -76,9 +76,9 @@ static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, + DEBUGLOG(5, "ZSTD_compressSubBlock_literal (hSize=%zu)", hufMetadata->hufDesSize); + } + +- /* TODO bmi2 */ +- { const size_t cSize = singleStream ? HUF_compress1X_usingCTable(op, oend-op, literals, litSize, hufTable) +- : HUF_compress4X_usingCTable(op, oend-op, literals, litSize, hufTable); ++ { int const flags = bmi2 ? HUF_flags_bmi2 : 0; ++ const size_t cSize = singleStream ? 
HUF_compress1X_usingCTable(op, (size_t)(oend-op), literals, litSize, hufTable, flags) ++ : HUF_compress4X_usingCTable(op, (size_t)(oend-op), literals, litSize, hufTable, flags); + op += cSize; + cLitSize += cSize; + if (cSize == 0 || ERR_isError(cSize)) { +@@ -103,7 +103,7 @@ static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, + switch(lhSize) + { + case 3: /* 2 - 2 - 10 - 10 */ +- { U32 const lhc = hType + ((!singleStream) << 2) + ((U32)litSize<<4) + ((U32)cLitSize<<14); ++ { U32 const lhc = hType + ((U32)(!singleStream) << 2) + ((U32)litSize<<4) + ((U32)cLitSize<<14); + MEM_writeLE24(ostart, lhc); + break; + } +@@ -123,26 +123,30 @@ static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, + } + *entropyWritten = 1; + DEBUGLOG(5, "Compressed literals: %u -> %u", (U32)litSize, (U32)(op-ostart)); +- return op-ostart; ++ return (size_t)(op-ostart); + } + +-static size_t ZSTD_seqDecompressedSize(seqStore_t const* seqStore, const seqDef* sequences, size_t nbSeq, size_t litSize, int lastSequence) { +- const seqDef* const sstart = sequences; +- const seqDef* const send = sequences + nbSeq; +- const seqDef* sp = sstart; ++static size_t ++ZSTD_seqDecompressedSize(seqStore_t const* seqStore, ++ const seqDef* sequences, size_t nbSeqs, ++ size_t litSize, int lastSubBlock) ++{ + size_t matchLengthSum = 0; + size_t litLengthSum = 0; +- (void)(litLengthSum); /* suppress unused variable warning on some environments */ +- while (send-sp > 0) { +- ZSTD_sequenceLength const seqLen = ZSTD_getSequenceLength(seqStore, sp); ++ size_t n; ++ for (n=0; ncParams.windowLog > STREAM_ACCUMULATOR_MIN; + BYTE* const ostart = (BYTE*)dst; +@@ -176,14 +181,14 @@ static size_t ZSTD_compressSubBlock_sequences(const ZSTD_fseCTables_t* fseTables + /* Sequences Header */ + RETURN_ERROR_IF((oend-op) < 3 /*max nbSeq Size*/ + 1 /*seqHead*/, + dstSize_tooSmall, ""); +- if (nbSeq < 0x7F) ++ if (nbSeq < 128) + *op++ = (BYTE)nbSeq; + else if (nbSeq < LONGNBSEQ) + op[0] = 
(BYTE)((nbSeq>>8) + 0x80), op[1] = (BYTE)nbSeq, op+=2; + else + op[0]=0xFF, MEM_writeLE16(op+1, (U16)(nbSeq - LONGNBSEQ)), op+=3; + if (nbSeq==0) { +- return op - ostart; ++ return (size_t)(op - ostart); + } + + /* seqHead : flags for FSE encoding type */ +@@ -205,7 +210,7 @@ static size_t ZSTD_compressSubBlock_sequences(const ZSTD_fseCTables_t* fseTables + } + + { size_t const bitstreamSize = ZSTD_encodeSequences( +- op, oend - op, ++ op, (size_t)(oend - op), + fseTables->matchlengthCTable, mlCode, + fseTables->offcodeCTable, ofCode, + fseTables->litlengthCTable, llCode, +@@ -249,7 +254,7 @@ static size_t ZSTD_compressSubBlock_sequences(const ZSTD_fseCTables_t* fseTables + #endif + + *entropyWritten = 1; +- return op - ostart; ++ return (size_t)(op - ostart); + } + + /* ZSTD_compressSubBlock() : +@@ -275,7 +280,8 @@ static size_t ZSTD_compressSubBlock(const ZSTD_entropyCTables_t* entropy, + litSize, nbSeq, writeLitEntropy, writeSeqEntropy, lastBlock); + { size_t cLitSize = ZSTD_compressSubBlock_literal((const HUF_CElt*)entropy->huf.CTable, + &entropyMetadata->hufMetadata, literals, litSize, +- op, oend-op, bmi2, writeLitEntropy, litEntropyWritten); ++ op, (size_t)(oend-op), ++ bmi2, writeLitEntropy, litEntropyWritten); + FORWARD_IF_ERROR(cLitSize, "ZSTD_compressSubBlock_literal failed"); + if (cLitSize == 0) return 0; + op += cLitSize; +@@ -285,18 +291,18 @@ static size_t ZSTD_compressSubBlock(const ZSTD_entropyCTables_t* entropy, + sequences, nbSeq, + llCode, mlCode, ofCode, + cctxParams, +- op, oend-op, ++ op, (size_t)(oend-op), + bmi2, writeSeqEntropy, seqEntropyWritten); + FORWARD_IF_ERROR(cSeqSize, "ZSTD_compressSubBlock_sequences failed"); + if (cSeqSize == 0) return 0; + op += cSeqSize; + } + /* Write block header */ +- { size_t cSize = (op-ostart)-ZSTD_blockHeaderSize; ++ { size_t cSize = (size_t)(op-ostart) - ZSTD_blockHeaderSize; + U32 const cBlockHeader24 = lastBlock + (((U32)bt_compressed)<<1) + (U32)(cSize << 3); + MEM_writeLE24(ostart, 
cBlockHeader24); + } +- return op-ostart; ++ return (size_t)(op-ostart); + } + + static size_t ZSTD_estimateSubBlockSize_literal(const BYTE* literals, size_t litSize, +@@ -385,7 +391,11 @@ static size_t ZSTD_estimateSubBlockSize_sequences(const BYTE* ofCodeTable, + return cSeqSizeEstimate + sequencesSectionHeaderSize; + } + +-static size_t ZSTD_estimateSubBlockSize(const BYTE* literals, size_t litSize, ++typedef struct { ++ size_t estLitSize; ++ size_t estBlockSize; ++} EstimatedBlockSize; ++static EstimatedBlockSize ZSTD_estimateSubBlockSize(const BYTE* literals, size_t litSize, + const BYTE* ofCodeTable, + const BYTE* llCodeTable, + const BYTE* mlCodeTable, +@@ -393,15 +403,17 @@ static size_t ZSTD_estimateSubBlockSize(const BYTE* literals, size_t litSize, + const ZSTD_entropyCTables_t* entropy, + const ZSTD_entropyCTablesMetadata_t* entropyMetadata, + void* workspace, size_t wkspSize, +- int writeLitEntropy, int writeSeqEntropy) { +- size_t cSizeEstimate = 0; +- cSizeEstimate += ZSTD_estimateSubBlockSize_literal(literals, litSize, +- &entropy->huf, &entropyMetadata->hufMetadata, +- workspace, wkspSize, writeLitEntropy); +- cSizeEstimate += ZSTD_estimateSubBlockSize_sequences(ofCodeTable, llCodeTable, mlCodeTable, ++ int writeLitEntropy, int writeSeqEntropy) ++{ ++ EstimatedBlockSize ebs; ++ ebs.estLitSize = ZSTD_estimateSubBlockSize_literal(literals, litSize, ++ &entropy->huf, &entropyMetadata->hufMetadata, ++ workspace, wkspSize, writeLitEntropy); ++ ebs.estBlockSize = ZSTD_estimateSubBlockSize_sequences(ofCodeTable, llCodeTable, mlCodeTable, + nbSeq, &entropy->fse, &entropyMetadata->fseMetadata, + workspace, wkspSize, writeSeqEntropy); +- return cSizeEstimate + ZSTD_blockHeaderSize; ++ ebs.estBlockSize += ebs.estLitSize + ZSTD_blockHeaderSize; ++ return ebs; + } + + static int ZSTD_needSequenceEntropyTables(ZSTD_fseCTablesMetadata_t const* fseMetadata) +@@ -415,13 +427,56 @@ static int ZSTD_needSequenceEntropyTables(ZSTD_fseCTablesMetadata_t const* fseMe + 
return 0; + } + ++static size_t countLiterals(seqStore_t const* seqStore, const seqDef* sp, size_t seqCount) ++{ ++ size_t n, total = 0; ++ assert(sp != NULL); ++ for (n=0; n %zu bytes", seqCount, (const void*)sp, total); ++ return total; ++} ++ ++#define BYTESCALE 256 ++ ++static size_t sizeBlockSequences(const seqDef* sp, size_t nbSeqs, ++ size_t targetBudget, size_t avgLitCost, size_t avgSeqCost, ++ int firstSubBlock) ++{ ++ size_t n, budget = 0, inSize=0; ++ /* entropy headers */ ++ size_t const headerSize = (size_t)firstSubBlock * 120 * BYTESCALE; /* generous estimate */ ++ assert(firstSubBlock==0 || firstSubBlock==1); ++ budget += headerSize; ++ ++ /* first sequence => at least one sequence*/ ++ budget += sp[0].litLength * avgLitCost + avgSeqCost; ++ if (budget > targetBudget) return 1; ++ inSize = sp[0].litLength + (sp[0].mlBase+MINMATCH); ++ ++ /* loop over sequences */ ++ for (n=1; n targetBudget) ++ /* though continue to expand until the sub-block is deemed compressible */ ++ && (budget < inSize * BYTESCALE) ) ++ break; ++ } ++ ++ return n; ++} ++ + /* ZSTD_compressSubBlock_multi() : + * Breaks super-block into multiple sub-blocks and compresses them. +- * Entropy will be written to the first block. +- * The following blocks will use repeat mode to compress. +- * All sub-blocks are compressed blocks (no raw or rle blocks). +- * @return : compressed size of the super block (which is multiple ZSTD blocks) +- * Or 0 if it failed to compress. */ ++ * Entropy will be written into the first block. ++ * The following blocks use repeat_mode to compress. ++ * Sub-blocks are all compressed, except the last one when beneficial. ++ * @return : compressed size of the super block (which features multiple ZSTD blocks) ++ * or 0 if it failed to compress. 
*/ + static size_t ZSTD_compressSubBlock_multi(const seqStore_t* seqStorePtr, + const ZSTD_compressedBlockState_t* prevCBlock, + ZSTD_compressedBlockState_t* nextCBlock, +@@ -434,10 +489,12 @@ static size_t ZSTD_compressSubBlock_multi(const seqStore_t* seqStorePtr, + { + const seqDef* const sstart = seqStorePtr->sequencesStart; + const seqDef* const send = seqStorePtr->sequences; +- const seqDef* sp = sstart; ++ const seqDef* sp = sstart; /* tracks progresses within seqStorePtr->sequences */ ++ size_t const nbSeqs = (size_t)(send - sstart); + const BYTE* const lstart = seqStorePtr->litStart; + const BYTE* const lend = seqStorePtr->lit; + const BYTE* lp = lstart; ++ size_t const nbLiterals = (size_t)(lend - lstart); + BYTE const* ip = (BYTE const*)src; + BYTE const* const iend = ip + srcSize; + BYTE* const ostart = (BYTE*)dst; +@@ -446,112 +503,171 @@ static size_t ZSTD_compressSubBlock_multi(const seqStore_t* seqStorePtr, + const BYTE* llCodePtr = seqStorePtr->llCode; + const BYTE* mlCodePtr = seqStorePtr->mlCode; + const BYTE* ofCodePtr = seqStorePtr->ofCode; +- size_t targetCBlockSize = cctxParams->targetCBlockSize; +- size_t litSize, seqCount; +- int writeLitEntropy = entropyMetadata->hufMetadata.hType == set_compressed; ++ size_t const minTarget = ZSTD_TARGETCBLOCKSIZE_MIN; /* enforce minimum size, to reduce undesirable side effects */ ++ size_t const targetCBlockSize = MAX(minTarget, cctxParams->targetCBlockSize); ++ int writeLitEntropy = (entropyMetadata->hufMetadata.hType == set_compressed); + int writeSeqEntropy = 1; +- int lastSequence = 0; +- +- DEBUGLOG(5, "ZSTD_compressSubBlock_multi (litSize=%u, nbSeq=%u)", +- (unsigned)(lend-lp), (unsigned)(send-sstart)); +- +- litSize = 0; +- seqCount = 0; +- do { +- size_t cBlockSizeEstimate = 0; +- if (sstart == send) { +- lastSequence = 1; +- } else { +- const seqDef* const sequence = sp + seqCount; +- lastSequence = sequence == send - 1; +- litSize += ZSTD_getSequenceLength(seqStorePtr, sequence).litLength; +- 
seqCount++; +- } +- if (lastSequence) { +- assert(lp <= lend); +- assert(litSize <= (size_t)(lend - lp)); +- litSize = (size_t)(lend - lp); ++ ++ DEBUGLOG(5, "ZSTD_compressSubBlock_multi (srcSize=%u, litSize=%u, nbSeq=%u)", ++ (unsigned)srcSize, (unsigned)(lend-lstart), (unsigned)(send-sstart)); ++ ++ /* let's start by a general estimation for the full block */ ++ if (nbSeqs > 0) { ++ EstimatedBlockSize const ebs = ++ ZSTD_estimateSubBlockSize(lp, nbLiterals, ++ ofCodePtr, llCodePtr, mlCodePtr, nbSeqs, ++ &nextCBlock->entropy, entropyMetadata, ++ workspace, wkspSize, ++ writeLitEntropy, writeSeqEntropy); ++ /* quick estimation */ ++ size_t const avgLitCost = nbLiterals ? (ebs.estLitSize * BYTESCALE) / nbLiterals : BYTESCALE; ++ size_t const avgSeqCost = ((ebs.estBlockSize - ebs.estLitSize) * BYTESCALE) / nbSeqs; ++ const size_t nbSubBlocks = MAX((ebs.estBlockSize + (targetCBlockSize/2)) / targetCBlockSize, 1); ++ size_t n, avgBlockBudget, blockBudgetSupp=0; ++ avgBlockBudget = (ebs.estBlockSize * BYTESCALE) / nbSubBlocks; ++ DEBUGLOG(5, "estimated fullblock size=%u bytes ; avgLitCost=%.2f ; avgSeqCost=%.2f ; targetCBlockSize=%u, nbSubBlocks=%u ; avgBlockBudget=%.0f bytes", ++ (unsigned)ebs.estBlockSize, (double)avgLitCost/BYTESCALE, (double)avgSeqCost/BYTESCALE, ++ (unsigned)targetCBlockSize, (unsigned)nbSubBlocks, (double)avgBlockBudget/BYTESCALE); ++ /* simplification: if estimates states that the full superblock doesn't compress, just bail out immediately ++ * this will result in the production of a single uncompressed block covering @srcSize.*/ ++ if (ebs.estBlockSize > srcSize) return 0; ++ ++ /* compress and write sub-blocks */ ++ assert(nbSubBlocks>0); ++ for (n=0; n < nbSubBlocks-1; n++) { ++ /* determine nb of sequences for current sub-block + nbLiterals from next sequence */ ++ size_t const seqCount = sizeBlockSequences(sp, (size_t)(send-sp), ++ avgBlockBudget + blockBudgetSupp, avgLitCost, avgSeqCost, n==0); ++ /* if reached last sequence : break to last 
sub-block (simplification) */ ++ assert(seqCount <= (size_t)(send-sp)); ++ if (sp + seqCount == send) break; ++ assert(seqCount > 0); ++ /* compress sub-block */ ++ { int litEntropyWritten = 0; ++ int seqEntropyWritten = 0; ++ size_t litSize = countLiterals(seqStorePtr, sp, seqCount); ++ const size_t decompressedSize = ++ ZSTD_seqDecompressedSize(seqStorePtr, sp, seqCount, litSize, 0); ++ size_t const cSize = ZSTD_compressSubBlock(&nextCBlock->entropy, entropyMetadata, ++ sp, seqCount, ++ lp, litSize, ++ llCodePtr, mlCodePtr, ofCodePtr, ++ cctxParams, ++ op, (size_t)(oend-op), ++ bmi2, writeLitEntropy, writeSeqEntropy, ++ &litEntropyWritten, &seqEntropyWritten, ++ 0); ++ FORWARD_IF_ERROR(cSize, "ZSTD_compressSubBlock failed"); ++ ++ /* check compressibility, update state components */ ++ if (cSize > 0 && cSize < decompressedSize) { ++ DEBUGLOG(5, "Committed sub-block compressing %u bytes => %u bytes", ++ (unsigned)decompressedSize, (unsigned)cSize); ++ assert(ip + decompressedSize <= iend); ++ ip += decompressedSize; ++ lp += litSize; ++ op += cSize; ++ llCodePtr += seqCount; ++ mlCodePtr += seqCount; ++ ofCodePtr += seqCount; ++ /* Entropy only needs to be written once */ ++ if (litEntropyWritten) { ++ writeLitEntropy = 0; ++ } ++ if (seqEntropyWritten) { ++ writeSeqEntropy = 0; ++ } ++ sp += seqCount; ++ blockBudgetSupp = 0; ++ } } ++ /* otherwise : do not compress yet, coalesce current sub-block with following one */ + } +- /* I think there is an optimization opportunity here. +- * Calling ZSTD_estimateSubBlockSize for every sequence can be wasteful +- * since it recalculates estimate from scratch. +- * For example, it would recount literal distribution and symbol codes every time. 
+- */ +- cBlockSizeEstimate = ZSTD_estimateSubBlockSize(lp, litSize, ofCodePtr, llCodePtr, mlCodePtr, seqCount, +- &nextCBlock->entropy, entropyMetadata, +- workspace, wkspSize, writeLitEntropy, writeSeqEntropy); +- if (cBlockSizeEstimate > targetCBlockSize || lastSequence) { +- int litEntropyWritten = 0; +- int seqEntropyWritten = 0; +- const size_t decompressedSize = ZSTD_seqDecompressedSize(seqStorePtr, sp, seqCount, litSize, lastSequence); +- const size_t cSize = ZSTD_compressSubBlock(&nextCBlock->entropy, entropyMetadata, +- sp, seqCount, +- lp, litSize, +- llCodePtr, mlCodePtr, ofCodePtr, +- cctxParams, +- op, oend-op, +- bmi2, writeLitEntropy, writeSeqEntropy, +- &litEntropyWritten, &seqEntropyWritten, +- lastBlock && lastSequence); +- FORWARD_IF_ERROR(cSize, "ZSTD_compressSubBlock failed"); +- if (cSize > 0 && cSize < decompressedSize) { +- DEBUGLOG(5, "Committed the sub-block"); +- assert(ip + decompressedSize <= iend); +- ip += decompressedSize; +- sp += seqCount; +- lp += litSize; +- op += cSize; +- llCodePtr += seqCount; +- mlCodePtr += seqCount; +- ofCodePtr += seqCount; +- litSize = 0; +- seqCount = 0; +- /* Entropy only needs to be written once */ +- if (litEntropyWritten) { +- writeLitEntropy = 0; +- } +- if (seqEntropyWritten) { +- writeSeqEntropy = 0; +- } ++ } /* if (nbSeqs > 0) */ ++ ++ /* write last block */ ++ DEBUGLOG(5, "Generate last sub-block: %u sequences remaining", (unsigned)(send - sp)); ++ { int litEntropyWritten = 0; ++ int seqEntropyWritten = 0; ++ size_t litSize = (size_t)(lend - lp); ++ size_t seqCount = (size_t)(send - sp); ++ const size_t decompressedSize = ++ ZSTD_seqDecompressedSize(seqStorePtr, sp, seqCount, litSize, 1); ++ size_t const cSize = ZSTD_compressSubBlock(&nextCBlock->entropy, entropyMetadata, ++ sp, seqCount, ++ lp, litSize, ++ llCodePtr, mlCodePtr, ofCodePtr, ++ cctxParams, ++ op, (size_t)(oend-op), ++ bmi2, writeLitEntropy, writeSeqEntropy, ++ &litEntropyWritten, &seqEntropyWritten, ++ lastBlock); ++ 
FORWARD_IF_ERROR(cSize, "ZSTD_compressSubBlock failed"); ++ ++ /* update pointers, the nb of literals borrowed from next sequence must be preserved */ ++ if (cSize > 0 && cSize < decompressedSize) { ++ DEBUGLOG(5, "Last sub-block compressed %u bytes => %u bytes", ++ (unsigned)decompressedSize, (unsigned)cSize); ++ assert(ip + decompressedSize <= iend); ++ ip += decompressedSize; ++ lp += litSize; ++ op += cSize; ++ llCodePtr += seqCount; ++ mlCodePtr += seqCount; ++ ofCodePtr += seqCount; ++ /* Entropy only needs to be written once */ ++ if (litEntropyWritten) { ++ writeLitEntropy = 0; + } ++ if (seqEntropyWritten) { ++ writeSeqEntropy = 0; ++ } ++ sp += seqCount; + } +- } while (!lastSequence); ++ } ++ ++ + if (writeLitEntropy) { +- DEBUGLOG(5, "ZSTD_compressSubBlock_multi has literal entropy tables unwritten"); ++ DEBUGLOG(5, "Literal entropy tables were never written"); + ZSTD_memcpy(&nextCBlock->entropy.huf, &prevCBlock->entropy.huf, sizeof(prevCBlock->entropy.huf)); + } + if (writeSeqEntropy && ZSTD_needSequenceEntropyTables(&entropyMetadata->fseMetadata)) { + /* If we haven't written our entropy tables, then we've violated our contract and + * must emit an uncompressed block. 
+ */ +- DEBUGLOG(5, "ZSTD_compressSubBlock_multi has sequence entropy tables unwritten"); ++ DEBUGLOG(5, "Sequence entropy tables were never written => cancel, emit an uncompressed block"); + return 0; + } ++ + if (ip < iend) { +- size_t const cSize = ZSTD_noCompressBlock(op, oend - op, ip, iend - ip, lastBlock); +- DEBUGLOG(5, "ZSTD_compressSubBlock_multi last sub-block uncompressed, %zu bytes", (size_t)(iend - ip)); ++ /* some data left : last part of the block sent uncompressed */ ++ size_t const rSize = (size_t)((iend - ip)); ++ size_t const cSize = ZSTD_noCompressBlock(op, (size_t)(oend - op), ip, rSize, lastBlock); ++ DEBUGLOG(5, "Generate last uncompressed sub-block of %u bytes", (unsigned)(rSize)); + FORWARD_IF_ERROR(cSize, "ZSTD_noCompressBlock failed"); + assert(cSize != 0); + op += cSize; + /* We have to regenerate the repcodes because we've skipped some sequences */ + if (sp < send) { +- seqDef const* seq; ++ const seqDef* seq; + repcodes_t rep; + ZSTD_memcpy(&rep, prevCBlock->rep, sizeof(rep)); + for (seq = sstart; seq < sp; ++seq) { +- ZSTD_updateRep(rep.rep, seq->offBase - 1, ZSTD_getSequenceLength(seqStorePtr, seq).litLength == 0); ++ ZSTD_updateRep(rep.rep, seq->offBase, ZSTD_getSequenceLength(seqStorePtr, seq).litLength == 0); + } + ZSTD_memcpy(nextCBlock->rep, &rep, sizeof(rep)); + } + } +- DEBUGLOG(5, "ZSTD_compressSubBlock_multi compressed"); +- return op-ostart; ++ ++ DEBUGLOG(5, "ZSTD_compressSubBlock_multi compressed all subBlocks: total compressed size = %u", ++ (unsigned)(op-ostart)); ++ return (size_t)(op-ostart); + } + + size_t ZSTD_compressSuperBlock(ZSTD_CCtx* zc, + void* dst, size_t dstCapacity, +- void const* src, size_t srcSize, +- unsigned lastBlock) { ++ const void* src, size_t srcSize, ++ unsigned lastBlock) ++{ + ZSTD_entropyCTablesMetadata_t entropyMetadata; + + FORWARD_IF_ERROR(ZSTD_buildBlockEntropyStats(&zc->seqStore, +diff --git a/lib/zstd/compress/zstd_compress_superblock.h b/lib/zstd/compress/zstd_compress_superblock.h 
+index 224ece79546e..826bbc9e029b 100644 +--- a/lib/zstd/compress/zstd_compress_superblock.h ++++ b/lib/zstd/compress/zstd_compress_superblock.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +diff --git a/lib/zstd/compress/zstd_cwksp.h b/lib/zstd/compress/zstd_cwksp.h +index 349fc923c355..86bc3c2c23c7 100644 +--- a/lib/zstd/compress/zstd_cwksp.h ++++ b/lib/zstd/compress/zstd_cwksp.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -14,7 +15,9 @@ + /*-************************************* + * Dependencies + ***************************************/ ++#include "../common/allocations.h" /* ZSTD_customMalloc, ZSTD_customFree */ + #include "../common/zstd_internal.h" ++#include "../common/portability_macros.h" + + + /*-************************************* +@@ -41,8 +44,9 @@ + ***************************************/ + typedef enum { + ZSTD_cwksp_alloc_objects, +- ZSTD_cwksp_alloc_buffers, +- ZSTD_cwksp_alloc_aligned ++ ZSTD_cwksp_alloc_aligned_init_once, ++ ZSTD_cwksp_alloc_aligned, ++ ZSTD_cwksp_alloc_buffers + } ZSTD_cwksp_alloc_phase_e; + + /* +@@ -95,8 +99,8 @@ typedef enum { + * + * Workspace Layout: + * +- * [ ... workspace ... ] +- * [objects][tables ... ->] free space [<- ... aligned][<- ... buffers] ++ * [ ... workspace ... 
] ++ * [objects][tables ->] free space [<- buffers][<- aligned][<- init once] + * + * The various objects that live in the workspace are divided into the + * following categories, and are allocated separately: +@@ -120,9 +124,18 @@ typedef enum { + * uint32_t arrays, all of whose values are between 0 and (nextSrc - base). + * Their sizes depend on the cparams. These tables are 64-byte aligned. + * +- * - Aligned: these buffers are used for various purposes that require 4 byte +- * alignment, but don't require any initialization before they're used. These +- * buffers are each aligned to 64 bytes. ++ * - Init once: these buffers require to be initialized at least once before ++ * use. They should be used when we want to skip memory initialization ++ * while not triggering memory checkers (like Valgrind) when reading from ++ * from this memory without writing to it first. ++ * These buffers should be used carefully as they might contain data ++ * from previous compressions. ++ * Buffers are aligned to 64 bytes. ++ * ++ * - Aligned: these buffers don't require any initialization before they're ++ * used. The user of the buffer should make sure they write into a buffer ++ * location before reading from it. ++ * Buffers are aligned to 64 bytes. + * + * - Buffers: these buffers are used for various purposes that don't require + * any alignment or initialization before they're used. This means they can +@@ -134,8 +147,9 @@ typedef enum { + * correctly packed into the workspace buffer. That order is: + * + * 1. Objects +- * 2. Buffers +- * 3. Aligned/Tables ++ * 2. Init once / Tables ++ * 3. Aligned / Tables ++ * 4. Buffers / Tables + * + * Attempts to reserve objects of different types out of order will fail. 
+ */ +@@ -147,6 +161,7 @@ typedef struct { + void* tableEnd; + void* tableValidEnd; + void* allocStart; ++ void* initOnceStart; + + BYTE allocFailed; + int workspaceOversizedDuration; +@@ -159,6 +174,7 @@ typedef struct { + ***************************************/ + + MEM_STATIC size_t ZSTD_cwksp_available_space(ZSTD_cwksp* ws); ++MEM_STATIC void* ZSTD_cwksp_initialAllocStart(ZSTD_cwksp* ws); + + MEM_STATIC void ZSTD_cwksp_assert_internal_consistency(ZSTD_cwksp* ws) { + (void)ws; +@@ -168,6 +184,8 @@ MEM_STATIC void ZSTD_cwksp_assert_internal_consistency(ZSTD_cwksp* ws) { + assert(ws->tableEnd <= ws->allocStart); + assert(ws->tableValidEnd <= ws->allocStart); + assert(ws->allocStart <= ws->workspaceEnd); ++ assert(ws->initOnceStart <= ZSTD_cwksp_initialAllocStart(ws)); ++ assert(ws->workspace <= ws->initOnceStart); + } + + /* +@@ -210,14 +228,10 @@ MEM_STATIC size_t ZSTD_cwksp_aligned_alloc_size(size_t size) { + * for internal purposes (currently only alignment). + */ + MEM_STATIC size_t ZSTD_cwksp_slack_space_required(void) { +- /* For alignment, the wksp will always allocate an additional n_1=[1, 64] bytes +- * to align the beginning of tables section, as well as another n_2=[0, 63] bytes +- * to align the beginning of the aligned section. +- * +- * n_1 + n_2 == 64 bytes if the cwksp is freshly allocated, due to tables and +- * aligneds being sized in multiples of 64 bytes. 
++ /* For alignment, the wksp will always allocate an additional 2*ZSTD_CWKSP_ALIGNMENT_BYTES ++ * bytes to align the beginning of tables section and end of buffers; + */ +- size_t const slackSpace = ZSTD_CWKSP_ALIGNMENT_BYTES; ++ size_t const slackSpace = ZSTD_CWKSP_ALIGNMENT_BYTES * 2; + return slackSpace; + } + +@@ -230,10 +244,18 @@ MEM_STATIC size_t ZSTD_cwksp_bytes_to_align_ptr(void* ptr, const size_t alignByt + size_t const alignBytesMask = alignBytes - 1; + size_t const bytes = (alignBytes - ((size_t)ptr & (alignBytesMask))) & alignBytesMask; + assert((alignBytes & alignBytesMask) == 0); +- assert(bytes != ZSTD_CWKSP_ALIGNMENT_BYTES); ++ assert(bytes < alignBytes); + return bytes; + } + ++/* ++ * Returns the initial value for allocStart which is used to determine the position from ++ * which we can allocate from the end of the workspace. ++ */ ++MEM_STATIC void* ZSTD_cwksp_initialAllocStart(ZSTD_cwksp* ws) { ++ return (void*)((size_t)ws->workspaceEnd & ~(ZSTD_CWKSP_ALIGNMENT_BYTES-1)); ++} ++ + /* + * Internal function. Do not use directly. + * Reserves the given number of bytes within the aligned/buffer segment of the wksp, +@@ -274,27 +296,16 @@ ZSTD_cwksp_internal_advance_phase(ZSTD_cwksp* ws, ZSTD_cwksp_alloc_phase_e phase + { + assert(phase >= ws->phase); + if (phase > ws->phase) { +- /* Going from allocating objects to allocating buffers */ +- if (ws->phase < ZSTD_cwksp_alloc_buffers && +- phase >= ZSTD_cwksp_alloc_buffers) { ++ /* Going from allocating objects to allocating initOnce / tables */ ++ if (ws->phase < ZSTD_cwksp_alloc_aligned_init_once && ++ phase >= ZSTD_cwksp_alloc_aligned_init_once) { + ws->tableValidEnd = ws->objectEnd; +- } ++ ws->initOnceStart = ZSTD_cwksp_initialAllocStart(ws); + +- /* Going from allocating buffers to allocating aligneds/tables */ +- if (ws->phase < ZSTD_cwksp_alloc_aligned && +- phase >= ZSTD_cwksp_alloc_aligned) { +- { /* Align the start of the "aligned" to 64 bytes. Use [1, 64] bytes. 
*/ +- size_t const bytesToAlign = +- ZSTD_CWKSP_ALIGNMENT_BYTES - ZSTD_cwksp_bytes_to_align_ptr(ws->allocStart, ZSTD_CWKSP_ALIGNMENT_BYTES); +- DEBUGLOG(5, "reserving aligned alignment addtl space: %zu", bytesToAlign); +- ZSTD_STATIC_ASSERT((ZSTD_CWKSP_ALIGNMENT_BYTES & (ZSTD_CWKSP_ALIGNMENT_BYTES - 1)) == 0); /* power of 2 */ +- RETURN_ERROR_IF(!ZSTD_cwksp_reserve_internal_buffer_space(ws, bytesToAlign), +- memory_allocation, "aligned phase - alignment initial allocation failed!"); +- } + { /* Align the start of the tables to 64 bytes. Use [0, 63] bytes */ +- void* const alloc = ws->objectEnd; ++ void *const alloc = ws->objectEnd; + size_t const bytesToAlign = ZSTD_cwksp_bytes_to_align_ptr(alloc, ZSTD_CWKSP_ALIGNMENT_BYTES); +- void* const objectEnd = (BYTE*)alloc + bytesToAlign; ++ void *const objectEnd = (BYTE *) alloc + bytesToAlign; + DEBUGLOG(5, "reserving table alignment addtl space: %zu", bytesToAlign); + RETURN_ERROR_IF(objectEnd > ws->workspaceEnd, memory_allocation, + "table phase - alignment initial allocation failed!"); +@@ -302,7 +313,9 @@ ZSTD_cwksp_internal_advance_phase(ZSTD_cwksp* ws, ZSTD_cwksp_alloc_phase_e phase + ws->tableEnd = objectEnd; /* table area starts being empty */ + if (ws->tableValidEnd < ws->tableEnd) { + ws->tableValidEnd = ws->tableEnd; +- } } } ++ } ++ } ++ } + ws->phase = phase; + ZSTD_cwksp_assert_internal_consistency(ws); + } +@@ -314,7 +327,7 @@ ZSTD_cwksp_internal_advance_phase(ZSTD_cwksp* ws, ZSTD_cwksp_alloc_phase_e phase + */ + MEM_STATIC int ZSTD_cwksp_owns_buffer(const ZSTD_cwksp* ws, const void* ptr) + { +- return (ptr != NULL) && (ws->workspace <= ptr) && (ptr <= ws->workspaceEnd); ++ return (ptr != NULL) && (ws->workspace <= ptr) && (ptr < ws->workspaceEnd); + } + + /* +@@ -343,6 +356,33 @@ MEM_STATIC BYTE* ZSTD_cwksp_reserve_buffer(ZSTD_cwksp* ws, size_t bytes) + return (BYTE*)ZSTD_cwksp_reserve_internal(ws, bytes, ZSTD_cwksp_alloc_buffers); + } + ++/* ++ * Reserves and returns memory sized on and aligned on 
ZSTD_CWKSP_ALIGNMENT_BYTES (64 bytes). ++ * This memory has been initialized at least once in the past. ++ * This doesn't mean it has been initialized this time, and it might contain data from previous ++ * operations. ++ * The main usage is for algorithms that might need read access into uninitialized memory. ++ * The algorithm must maintain safety under these conditions and must make sure it doesn't ++ * leak any of the past data (directly or in side channels). ++ */ ++MEM_STATIC void* ZSTD_cwksp_reserve_aligned_init_once(ZSTD_cwksp* ws, size_t bytes) ++{ ++ size_t const alignedBytes = ZSTD_cwksp_align(bytes, ZSTD_CWKSP_ALIGNMENT_BYTES); ++ void* ptr = ZSTD_cwksp_reserve_internal(ws, alignedBytes, ZSTD_cwksp_alloc_aligned_init_once); ++ assert(((size_t)ptr & (ZSTD_CWKSP_ALIGNMENT_BYTES-1))== 0); ++ if(ptr && ptr < ws->initOnceStart) { ++ /* We assume the memory following the current allocation is either: ++ * 1. Not usable as initOnce memory (end of workspace) ++ * 2. Another initOnce buffer that has been allocated before (and so was previously memset) ++ * 3. An ASAN redzone, in which case we don't want to write on it ++ * For these reasons it should be fine to not explicitly zero every byte up to ws->initOnceStart. ++ * Note that we assume here that MSAN and ASAN cannot run in the same time. */ ++ ZSTD_memset(ptr, 0, MIN((size_t)((U8*)ws->initOnceStart - (U8*)ptr), alignedBytes)); ++ ws->initOnceStart = ptr; ++ } ++ return ptr; ++} ++ + /* + * Reserves and returns memory sized on and aligned on ZSTD_CWKSP_ALIGNMENT_BYTES (64 bytes). + */ +@@ -356,18 +396,22 @@ MEM_STATIC void* ZSTD_cwksp_reserve_aligned(ZSTD_cwksp* ws, size_t bytes) + + /* + * Aligned on 64 bytes. These buffers have the special property that +- * their values remain constrained, allowing us to re-use them without ++ * their values remain constrained, allowing us to reuse them without + * memset()-ing them. 
+ */ + MEM_STATIC void* ZSTD_cwksp_reserve_table(ZSTD_cwksp* ws, size_t bytes) + { +- const ZSTD_cwksp_alloc_phase_e phase = ZSTD_cwksp_alloc_aligned; ++ const ZSTD_cwksp_alloc_phase_e phase = ZSTD_cwksp_alloc_aligned_init_once; + void* alloc; + void* end; + void* top; + +- if (ZSTD_isError(ZSTD_cwksp_internal_advance_phase(ws, phase))) { +- return NULL; ++ /* We can only start allocating tables after we are done reserving space for objects at the ++ * start of the workspace */ ++ if(ws->phase < phase) { ++ if (ZSTD_isError(ZSTD_cwksp_internal_advance_phase(ws, phase))) { ++ return NULL; ++ } + } + alloc = ws->tableEnd; + end = (BYTE *)alloc + bytes; +@@ -451,7 +495,7 @@ MEM_STATIC void ZSTD_cwksp_clean_tables(ZSTD_cwksp* ws) { + assert(ws->tableValidEnd >= ws->objectEnd); + assert(ws->tableValidEnd <= ws->allocStart); + if (ws->tableValidEnd < ws->tableEnd) { +- ZSTD_memset(ws->tableValidEnd, 0, (BYTE*)ws->tableEnd - (BYTE*)ws->tableValidEnd); ++ ZSTD_memset(ws->tableValidEnd, 0, (size_t)((BYTE*)ws->tableEnd - (BYTE*)ws->tableValidEnd)); + } + ZSTD_cwksp_mark_tables_clean(ws); + } +@@ -478,14 +522,23 @@ MEM_STATIC void ZSTD_cwksp_clear(ZSTD_cwksp* ws) { + + + ws->tableEnd = ws->objectEnd; +- ws->allocStart = ws->workspaceEnd; ++ ws->allocStart = ZSTD_cwksp_initialAllocStart(ws); + ws->allocFailed = 0; +- if (ws->phase > ZSTD_cwksp_alloc_buffers) { +- ws->phase = ZSTD_cwksp_alloc_buffers; ++ if (ws->phase > ZSTD_cwksp_alloc_aligned_init_once) { ++ ws->phase = ZSTD_cwksp_alloc_aligned_init_once; + } + ZSTD_cwksp_assert_internal_consistency(ws); + } + ++MEM_STATIC size_t ZSTD_cwksp_sizeof(const ZSTD_cwksp* ws) { ++ return (size_t)((BYTE*)ws->workspaceEnd - (BYTE*)ws->workspace); ++} ++ ++MEM_STATIC size_t ZSTD_cwksp_used(const ZSTD_cwksp* ws) { ++ return (size_t)((BYTE*)ws->tableEnd - (BYTE*)ws->workspace) ++ + (size_t)((BYTE*)ws->workspaceEnd - (BYTE*)ws->allocStart); ++} ++ + /* + * The provided workspace takes ownership of the buffer [start, start+size). 
+ * Any existing values in the workspace are ignored (the previously managed +@@ -498,6 +551,7 @@ MEM_STATIC void ZSTD_cwksp_init(ZSTD_cwksp* ws, void* start, size_t size, ZSTD_c + ws->workspaceEnd = (BYTE*)start + size; + ws->objectEnd = ws->workspace; + ws->tableValidEnd = ws->objectEnd; ++ ws->initOnceStart = ZSTD_cwksp_initialAllocStart(ws); + ws->phase = ZSTD_cwksp_alloc_objects; + ws->isStatic = isStatic; + ZSTD_cwksp_clear(ws); +@@ -529,15 +583,6 @@ MEM_STATIC void ZSTD_cwksp_move(ZSTD_cwksp* dst, ZSTD_cwksp* src) { + ZSTD_memset(src, 0, sizeof(ZSTD_cwksp)); + } + +-MEM_STATIC size_t ZSTD_cwksp_sizeof(const ZSTD_cwksp* ws) { +- return (size_t)((BYTE*)ws->workspaceEnd - (BYTE*)ws->workspace); +-} +- +-MEM_STATIC size_t ZSTD_cwksp_used(const ZSTD_cwksp* ws) { +- return (size_t)((BYTE*)ws->tableEnd - (BYTE*)ws->workspace) +- + (size_t)((BYTE*)ws->workspaceEnd - (BYTE*)ws->allocStart); +-} +- + MEM_STATIC int ZSTD_cwksp_reserve_failed(const ZSTD_cwksp* ws) { + return ws->allocFailed; + } +@@ -550,17 +595,11 @@ MEM_STATIC int ZSTD_cwksp_reserve_failed(const ZSTD_cwksp* ws) { + * Returns if the estimated space needed for a wksp is within an acceptable limit of the + * actual amount of space used. + */ +-MEM_STATIC int ZSTD_cwksp_estimated_space_within_bounds(const ZSTD_cwksp* const ws, +- size_t const estimatedSpace, int resizedWorkspace) { +- if (resizedWorkspace) { +- /* Resized/newly allocated wksp should have exact bounds */ +- return ZSTD_cwksp_used(ws) == estimatedSpace; +- } else { +- /* Due to alignment, when reusing a workspace, we can actually consume 63 fewer or more bytes +- * than estimatedSpace. See the comments in zstd_cwksp.h for details. 
+- */ +- return (ZSTD_cwksp_used(ws) >= estimatedSpace - 63) && (ZSTD_cwksp_used(ws) <= estimatedSpace + 63); +- } ++MEM_STATIC int ZSTD_cwksp_estimated_space_within_bounds(const ZSTD_cwksp *const ws, size_t const estimatedSpace) { ++ /* We have an alignment space between objects and tables between tables and buffers, so we can have up to twice ++ * the alignment bytes difference between estimation and actual usage */ ++ return (estimatedSpace - ZSTD_cwksp_slack_space_required()) <= ZSTD_cwksp_used(ws) && ++ ZSTD_cwksp_used(ws) <= estimatedSpace; + } + + +diff --git a/lib/zstd/compress/zstd_double_fast.c b/lib/zstd/compress/zstd_double_fast.c +index 76933dea2624..5ff54f17d92f 100644 +--- a/lib/zstd/compress/zstd_double_fast.c ++++ b/lib/zstd/compress/zstd_double_fast.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -11,8 +12,49 @@ + #include "zstd_compress_internal.h" + #include "zstd_double_fast.h" + ++#ifndef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR + +-void ZSTD_fillDoubleHashTable(ZSTD_matchState_t* ms, ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++void ZSTD_fillDoubleHashTableForCDict(ZSTD_matchState_t* ms, ++ void const* end, ZSTD_dictTableLoadMethod_e dtlm) ++{ ++ const ZSTD_compressionParameters* const cParams = &ms->cParams; ++ U32* const hashLarge = ms->hashTable; ++ U32 const hBitsL = cParams->hashLog + ZSTD_SHORT_CACHE_TAG_BITS; ++ U32 const mls = cParams->minMatch; ++ U32* const hashSmall = ms->chainTable; ++ U32 const hBitsS = cParams->chainLog + ZSTD_SHORT_CACHE_TAG_BITS; ++ const BYTE* const base = ms->window.base; ++ const BYTE* ip = base + ms->nextToUpdate; ++ const BYTE* const iend = ((const BYTE*)end) - HASH_READ_SIZE; ++ const U32 fastHashFillStep = 3; ++ ++ /* Always insert every fastHashFillStep position 
into the hash tables. ++ * Insert the other positions into the large hash table if their entry ++ * is empty. ++ */ ++ for (; ip + fastHashFillStep - 1 <= iend; ip += fastHashFillStep) { ++ U32 const curr = (U32)(ip - base); ++ U32 i; ++ for (i = 0; i < fastHashFillStep; ++i) { ++ size_t const smHashAndTag = ZSTD_hashPtr(ip + i, hBitsS, mls); ++ size_t const lgHashAndTag = ZSTD_hashPtr(ip + i, hBitsL, 8); ++ if (i == 0) { ++ ZSTD_writeTaggedIndex(hashSmall, smHashAndTag, curr + i); ++ } ++ if (i == 0 || hashLarge[lgHashAndTag >> ZSTD_SHORT_CACHE_TAG_BITS] == 0) { ++ ZSTD_writeTaggedIndex(hashLarge, lgHashAndTag, curr + i); ++ } ++ /* Only load extra positions for ZSTD_dtlm_full */ ++ if (dtlm == ZSTD_dtlm_fast) ++ break; ++ } } ++} ++ ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++void ZSTD_fillDoubleHashTableForCCtx(ZSTD_matchState_t* ms, + void const* end, ZSTD_dictTableLoadMethod_e dtlm) + { + const ZSTD_compressionParameters* const cParams = &ms->cParams; +@@ -43,11 +85,24 @@ void ZSTD_fillDoubleHashTable(ZSTD_matchState_t* ms, + /* Only load extra positions for ZSTD_dtlm_full */ + if (dtlm == ZSTD_dtlm_fast) + break; +- } } ++ } } ++} ++ ++void ZSTD_fillDoubleHashTable(ZSTD_matchState_t* ms, ++ const void* const end, ++ ZSTD_dictTableLoadMethod_e dtlm, ++ ZSTD_tableFillPurpose_e tfp) ++{ ++ if (tfp == ZSTD_tfp_forCDict) { ++ ZSTD_fillDoubleHashTableForCDict(ms, end, dtlm); ++ } else { ++ ZSTD_fillDoubleHashTableForCCtx(ms, end, dtlm); ++ } + } + + + FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR + size_t ZSTD_compressBlock_doubleFast_noDict_generic( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize, U32 const mls /* template */) +@@ -67,7 +122,7 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( + const BYTE* const iend = istart + srcSize; + const BYTE* const ilimit = iend - HASH_READ_SIZE; + U32 offset_1=rep[0], offset_2=rep[1]; +- U32 offsetSaved = 0; ++ U32 offsetSaved1 = 0, offsetSaved2 = 
0; + + size_t mLength; + U32 offset; +@@ -100,8 +155,8 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( + U32 const current = (U32)(ip - base); + U32 const windowLow = ZSTD_getLowestPrefixIndex(ms, current, cParams->windowLog); + U32 const maxRep = current - windowLow; +- if (offset_2 > maxRep) offsetSaved = offset_2, offset_2 = 0; +- if (offset_1 > maxRep) offsetSaved = offset_1, offset_1 = 0; ++ if (offset_2 > maxRep) offsetSaved2 = offset_2, offset_2 = 0; ++ if (offset_1 > maxRep) offsetSaved1 = offset_1, offset_1 = 0; + } + + /* Outer Loop: one iteration per match found and stored */ +@@ -131,7 +186,7 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( + if ((offset_1 > 0) & (MEM_read32(ip+1-offset_1) == MEM_read32(ip+1))) { + mLength = ZSTD_count(ip+1+4, ip+1+4-offset_1, iend) + 4; + ip++; +- ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, mLength); ++ ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, REPCODE1_TO_OFFBASE, mLength); + goto _match_stored; + } + +@@ -175,9 +230,13 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( + } while (ip1 <= ilimit); + + _cleanup: ++ /* If offset_1 started invalid (offsetSaved1 != 0) and became valid (offset_1 != 0), ++ * rotate saved offsets. See comment in ZSTD_compressBlock_fast_noDict for more context. */ ++ offsetSaved2 = ((offsetSaved1 != 0) && (offset_1 != 0)) ? offsetSaved1 : offsetSaved2; ++ + /* save reps for next block */ +- rep[0] = offset_1 ? offset_1 : offsetSaved; +- rep[1] = offset_2 ? offset_2 : offsetSaved; ++ rep[0] = offset_1 ? offset_1 : offsetSaved1; ++ rep[1] = offset_2 ? 
offset_2 : offsetSaved2; + + /* Return the last literals size */ + return (size_t)(iend - anchor); +@@ -217,7 +276,7 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( + hashLong[hl1] = (U32)(ip1 - base); + } + +- ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); ++ ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); + + _match_stored: + /* match found */ +@@ -243,7 +302,7 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( + U32 const tmpOff = offset_2; offset_2 = offset_1; offset_1 = tmpOff; /* swap offset_2 <=> offset_1 */ + hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = (U32)(ip-base); + hashLong[ZSTD_hashPtr(ip, hBitsL, 8)] = (U32)(ip-base); +- ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, rLength); ++ ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, rLength); + ip += rLength; + anchor = ip; + continue; /* faster when present ... (?) */ +@@ -254,6 +313,7 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( + + + FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR + size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize, +@@ -275,7 +335,6 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( + const BYTE* const iend = istart + srcSize; + const BYTE* const ilimit = iend - HASH_READ_SIZE; + U32 offset_1=rep[0], offset_2=rep[1]; +- U32 offsetSaved = 0; + + const ZSTD_matchState_t* const dms = ms->dictMatchState; + const ZSTD_compressionParameters* const dictCParams = &dms->cParams; +@@ -286,8 +345,8 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( + const BYTE* const dictStart = dictBase + dictStartIndex; + const BYTE* const dictEnd = dms->window.nextSrc; + const U32 dictIndexDelta = prefixLowestIndex - (U32)(dictEnd - dictBase); +- const U32 dictHBitsL = dictCParams->hashLog; +- const U32 dictHBitsS = 
dictCParams->chainLog; ++ const U32 dictHBitsL = dictCParams->hashLog + ZSTD_SHORT_CACHE_TAG_BITS; ++ const U32 dictHBitsS = dictCParams->chainLog + ZSTD_SHORT_CACHE_TAG_BITS; + const U32 dictAndPrefixLength = (U32)((ip - prefixLowest) + (dictEnd - dictStart)); + + DEBUGLOG(5, "ZSTD_compressBlock_doubleFast_dictMatchState_generic"); +@@ -295,6 +354,13 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( + /* if a dictionary is attached, it must be within window range */ + assert(ms->window.dictLimit + (1U << cParams->windowLog) >= endIndex); + ++ if (ms->prefetchCDictTables) { ++ size_t const hashTableBytes = (((size_t)1) << dictCParams->hashLog) * sizeof(U32); ++ size_t const chainTableBytes = (((size_t)1) << dictCParams->chainLog) * sizeof(U32); ++ PREFETCH_AREA(dictHashLong, hashTableBytes); ++ PREFETCH_AREA(dictHashSmall, chainTableBytes); ++ } ++ + /* init */ + ip += (dictAndPrefixLength == 0); + +@@ -309,8 +375,12 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( + U32 offset; + size_t const h2 = ZSTD_hashPtr(ip, hBitsL, 8); + size_t const h = ZSTD_hashPtr(ip, hBitsS, mls); +- size_t const dictHL = ZSTD_hashPtr(ip, dictHBitsL, 8); +- size_t const dictHS = ZSTD_hashPtr(ip, dictHBitsS, mls); ++ size_t const dictHashAndTagL = ZSTD_hashPtr(ip, dictHBitsL, 8); ++ size_t const dictHashAndTagS = ZSTD_hashPtr(ip, dictHBitsS, mls); ++ U32 const dictMatchIndexAndTagL = dictHashLong[dictHashAndTagL >> ZSTD_SHORT_CACHE_TAG_BITS]; ++ U32 const dictMatchIndexAndTagS = dictHashSmall[dictHashAndTagS >> ZSTD_SHORT_CACHE_TAG_BITS]; ++ int const dictTagsMatchL = ZSTD_comparePackedTags(dictMatchIndexAndTagL, dictHashAndTagL); ++ int const dictTagsMatchS = ZSTD_comparePackedTags(dictMatchIndexAndTagS, dictHashAndTagS); + U32 const curr = (U32)(ip-base); + U32 const matchIndexL = hashLong[h2]; + U32 matchIndexS = hashSmall[h]; +@@ -328,7 +398,7 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( + const BYTE* repMatchEnd = repIndex < 
prefixLowestIndex ? dictEnd : iend; + mLength = ZSTD_count_2segments(ip+1+4, repMatch+4, iend, repMatchEnd, prefixLowest) + 4; + ip++; +- ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, mLength); ++ ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, REPCODE1_TO_OFFBASE, mLength); + goto _match_stored; + } + +@@ -340,9 +410,9 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( + while (((ip>anchor) & (matchLong>prefixLowest)) && (ip[-1] == matchLong[-1])) { ip--; matchLong--; mLength++; } /* catch up */ + goto _match_found; + } +- } else { ++ } else if (dictTagsMatchL) { + /* check dictMatchState long match */ +- U32 const dictMatchIndexL = dictHashLong[dictHL]; ++ U32 const dictMatchIndexL = dictMatchIndexAndTagL >> ZSTD_SHORT_CACHE_TAG_BITS; + const BYTE* dictMatchL = dictBase + dictMatchIndexL; + assert(dictMatchL < dictEnd); + +@@ -358,9 +428,9 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( + if (MEM_read32(match) == MEM_read32(ip)) { + goto _search_next_long; + } +- } else { ++ } else if (dictTagsMatchS) { + /* check dictMatchState short match */ +- U32 const dictMatchIndexS = dictHashSmall[dictHS]; ++ U32 const dictMatchIndexS = dictMatchIndexAndTagS >> ZSTD_SHORT_CACHE_TAG_BITS; + match = dictBase + dictMatchIndexS; + matchIndexS = dictMatchIndexS + dictIndexDelta; + +@@ -375,10 +445,11 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( + continue; + + _search_next_long: +- + { size_t const hl3 = ZSTD_hashPtr(ip+1, hBitsL, 8); +- size_t const dictHLNext = ZSTD_hashPtr(ip+1, dictHBitsL, 8); ++ size_t const dictHashAndTagL3 = ZSTD_hashPtr(ip+1, dictHBitsL, 8); + U32 const matchIndexL3 = hashLong[hl3]; ++ U32 const dictMatchIndexAndTagL3 = dictHashLong[dictHashAndTagL3 >> ZSTD_SHORT_CACHE_TAG_BITS]; ++ int const dictTagsMatchL3 = ZSTD_comparePackedTags(dictMatchIndexAndTagL3, dictHashAndTagL3); + const BYTE* matchL3 = base + matchIndexL3; + hashLong[hl3] = curr + 1; + +@@ -391,9 +462,9 
@@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( + while (((ip>anchor) & (matchL3>prefixLowest)) && (ip[-1] == matchL3[-1])) { ip--; matchL3--; mLength++; } /* catch up */ + goto _match_found; + } +- } else { ++ } else if (dictTagsMatchL3) { + /* check dict long +1 match */ +- U32 const dictMatchIndexL3 = dictHashLong[dictHLNext]; ++ U32 const dictMatchIndexL3 = dictMatchIndexAndTagL3 >> ZSTD_SHORT_CACHE_TAG_BITS; + const BYTE* dictMatchL3 = dictBase + dictMatchIndexL3; + assert(dictMatchL3 < dictEnd); + if (dictMatchL3 > dictStart && MEM_read64(dictMatchL3) == MEM_read64(ip+1)) { +@@ -419,7 +490,7 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( + offset_2 = offset_1; + offset_1 = offset; + +- ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); ++ ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); + + _match_stored: + /* match found */ +@@ -448,7 +519,7 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( + const BYTE* const repEnd2 = repIndex2 < prefixLowestIndex ? dictEnd : iend; + size_t const repLength2 = ZSTD_count_2segments(ip+4, repMatch2+4, iend, repEnd2, prefixLowest) + 4; + U32 tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; /* swap offset_2 <=> offset_1 */ +- ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, repLength2); ++ ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, repLength2); + hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = current2; + hashLong[ZSTD_hashPtr(ip, hBitsL, 8)] = current2; + ip += repLength2; +@@ -461,8 +532,8 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( + } /* while (ip < ilimit) */ + + /* save reps for next block */ +- rep[0] = offset_1 ? offset_1 : offsetSaved; +- rep[1] = offset_2 ? 
offset_2 : offsetSaved; ++ rep[0] = offset_1; ++ rep[1] = offset_2; + + /* Return the last literals size */ + return (size_t)(iend - anchor); +@@ -527,7 +598,9 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState( + } + + +-static size_t ZSTD_compressBlock_doubleFast_extDict_generic( ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++size_t ZSTD_compressBlock_doubleFast_extDict_generic( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize, + U32 const mls /* template */) +@@ -585,7 +658,7 @@ static size_t ZSTD_compressBlock_doubleFast_extDict_generic( + const BYTE* repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend; + mLength = ZSTD_count_2segments(ip+1+4, repMatch+4, iend, repMatchEnd, prefixStart) + 4; + ip++; +- ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, mLength); ++ ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, REPCODE1_TO_OFFBASE, mLength); + } else { + if ((matchLongIndex > dictStartIndex) && (MEM_read64(matchLong) == MEM_read64(ip))) { + const BYTE* const matchEnd = matchLongIndex < prefixStartIndex ? 
dictEnd : iend; +@@ -596,7 +669,7 @@ static size_t ZSTD_compressBlock_doubleFast_extDict_generic( + while (((ip>anchor) & (matchLong>lowMatchPtr)) && (ip[-1] == matchLong[-1])) { ip--; matchLong--; mLength++; } /* catch up */ + offset_2 = offset_1; + offset_1 = offset; +- ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); ++ ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); + + } else if ((matchIndex > dictStartIndex) && (MEM_read32(match) == MEM_read32(ip))) { + size_t const h3 = ZSTD_hashPtr(ip+1, hBitsL, 8); +@@ -621,7 +694,7 @@ static size_t ZSTD_compressBlock_doubleFast_extDict_generic( + } + offset_2 = offset_1; + offset_1 = offset; +- ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); ++ ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); + + } else { + ip += ((ip-anchor) >> kSearchStrength) + 1; +@@ -653,7 +726,7 @@ static size_t ZSTD_compressBlock_doubleFast_extDict_generic( + const BYTE* const repEnd2 = repIndex2 < prefixStartIndex ? 
dictEnd : iend; + size_t const repLength2 = ZSTD_count_2segments(ip+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; + U32 const tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; /* swap offset_2 <=> offset_1 */ +- ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, repLength2); ++ ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, repLength2); + hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = current2; + hashLong[ZSTD_hashPtr(ip, hBitsL, 8)] = current2; + ip += repLength2; +@@ -694,3 +767,5 @@ size_t ZSTD_compressBlock_doubleFast_extDict( + return ZSTD_compressBlock_doubleFast_extDict_7(ms, seqStore, rep, src, srcSize); + } + } ++ ++#endif /* ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR */ +diff --git a/lib/zstd/compress/zstd_double_fast.h b/lib/zstd/compress/zstd_double_fast.h +index 6822bde65a1d..b7ddc714f13e 100644 +--- a/lib/zstd/compress/zstd_double_fast.h ++++ b/lib/zstd/compress/zstd_double_fast.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -15,8 +16,12 @@ + #include "../common/mem.h" /* U32 */ + #include "zstd_compress_internal.h" /* ZSTD_CCtx, size_t */ + ++#ifndef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR ++ + void ZSTD_fillDoubleHashTable(ZSTD_matchState_t* ms, +- void const* end, ZSTD_dictTableLoadMethod_e dtlm); ++ void const* end, ZSTD_dictTableLoadMethod_e dtlm, ++ ZSTD_tableFillPurpose_e tfp); ++ + size_t ZSTD_compressBlock_doubleFast( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +@@ -27,6 +32,14 @@ size_t ZSTD_compressBlock_doubleFast_extDict( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); + ++#define ZSTD_COMPRESSBLOCK_DOUBLEFAST ZSTD_compressBlock_doubleFast ++#define ZSTD_COMPRESSBLOCK_DOUBLEFAST_DICTMATCHSTATE ZSTD_compressBlock_doubleFast_dictMatchState ++#define ZSTD_COMPRESSBLOCK_DOUBLEFAST_EXTDICT ZSTD_compressBlock_doubleFast_extDict ++#else ++#define ZSTD_COMPRESSBLOCK_DOUBLEFAST NULL ++#define ZSTD_COMPRESSBLOCK_DOUBLEFAST_DICTMATCHSTATE NULL ++#define ZSTD_COMPRESSBLOCK_DOUBLEFAST_EXTDICT NULL ++#endif /* ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR */ + + + #endif /* ZSTD_DOUBLE_FAST_H */ +diff --git a/lib/zstd/compress/zstd_fast.c b/lib/zstd/compress/zstd_fast.c +index a752e6beab52..b7a63ba4ce56 100644 +--- a/lib/zstd/compress/zstd_fast.c ++++ b/lib/zstd/compress/zstd_fast.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -11,8 +12,46 @@ + #include "zstd_compress_internal.h" /* ZSTD_hashPtr, ZSTD_count, ZSTD_storeSeq */ + #include "zstd_fast.h" + ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++void ZSTD_fillHashTableForCDict(ZSTD_matchState_t* ms, ++ const void* const end, ++ ZSTD_dictTableLoadMethod_e dtlm) ++{ ++ const ZSTD_compressionParameters* const cParams = &ms->cParams; ++ U32* const hashTable = ms->hashTable; ++ U32 const hBits = cParams->hashLog + ZSTD_SHORT_CACHE_TAG_BITS; ++ U32 const mls = cParams->minMatch; ++ const BYTE* const base = ms->window.base; ++ const BYTE* ip = base + ms->nextToUpdate; ++ const BYTE* const iend = ((const BYTE*)end) - HASH_READ_SIZE; ++ const U32 fastHashFillStep = 3; + +-void ZSTD_fillHashTable(ZSTD_matchState_t* ms, ++ /* Currently, we always use ZSTD_dtlm_full for filling CDict tables. ++ * Feel free to remove this assert if there's a good reason! */ ++ assert(dtlm == ZSTD_dtlm_full); ++ ++ /* Always insert every fastHashFillStep position into the hash table. ++ * Insert the other positions if their hash entry is empty. 
++ */ ++ for ( ; ip + fastHashFillStep < iend + 2; ip += fastHashFillStep) { ++ U32 const curr = (U32)(ip - base); ++ { size_t const hashAndTag = ZSTD_hashPtr(ip, hBits, mls); ++ ZSTD_writeTaggedIndex(hashTable, hashAndTag, curr); } ++ ++ if (dtlm == ZSTD_dtlm_fast) continue; ++ /* Only load extra positions for ZSTD_dtlm_full */ ++ { U32 p; ++ for (p = 1; p < fastHashFillStep; ++p) { ++ size_t const hashAndTag = ZSTD_hashPtr(ip + p, hBits, mls); ++ if (hashTable[hashAndTag >> ZSTD_SHORT_CACHE_TAG_BITS] == 0) { /* not yet filled */ ++ ZSTD_writeTaggedIndex(hashTable, hashAndTag, curr + p); ++ } } } } ++} ++ ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++void ZSTD_fillHashTableForCCtx(ZSTD_matchState_t* ms, + const void* const end, + ZSTD_dictTableLoadMethod_e dtlm) + { +@@ -25,6 +64,10 @@ void ZSTD_fillHashTable(ZSTD_matchState_t* ms, + const BYTE* const iend = ((const BYTE*)end) - HASH_READ_SIZE; + const U32 fastHashFillStep = 3; + ++ /* Currently, we always use ZSTD_dtlm_fast for filling CCtx tables. ++ * Feel free to remove this assert if there's a good reason! */ ++ assert(dtlm == ZSTD_dtlm_fast); ++ + /* Always insert every fastHashFillStep position into the hash table. + * Insert the other positions if their hash entry is empty. + */ +@@ -42,6 +85,18 @@ void ZSTD_fillHashTable(ZSTD_matchState_t* ms, + } } } } + } + ++void ZSTD_fillHashTable(ZSTD_matchState_t* ms, ++ const void* const end, ++ ZSTD_dictTableLoadMethod_e dtlm, ++ ZSTD_tableFillPurpose_e tfp) ++{ ++ if (tfp == ZSTD_tfp_forCDict) { ++ ZSTD_fillHashTableForCDict(ms, end, dtlm); ++ } else { ++ ZSTD_fillHashTableForCCtx(ms, end, dtlm); ++ } ++} ++ + + /* + * If you squint hard enough (and ignore repcodes), the search operation at any +@@ -89,8 +144,9 @@ void ZSTD_fillHashTable(ZSTD_matchState_t* ms, + * + * This is also the work we do at the beginning to enter the loop initially. 
+ */ +-FORCE_INLINE_TEMPLATE size_t +-ZSTD_compressBlock_fast_noDict_generic( ++FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++size_t ZSTD_compressBlock_fast_noDict_generic( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize, + U32 const mls, U32 const hasStep) +@@ -117,7 +173,7 @@ ZSTD_compressBlock_fast_noDict_generic( + + U32 rep_offset1 = rep[0]; + U32 rep_offset2 = rep[1]; +- U32 offsetSaved = 0; ++ U32 offsetSaved1 = 0, offsetSaved2 = 0; + + size_t hash0; /* hash for ip0 */ + size_t hash1; /* hash for ip1 */ +@@ -141,8 +197,8 @@ ZSTD_compressBlock_fast_noDict_generic( + { U32 const curr = (U32)(ip0 - base); + U32 const windowLow = ZSTD_getLowestPrefixIndex(ms, curr, cParams->windowLog); + U32 const maxRep = curr - windowLow; +- if (rep_offset2 > maxRep) offsetSaved = rep_offset2, rep_offset2 = 0; +- if (rep_offset1 > maxRep) offsetSaved = rep_offset1, rep_offset1 = 0; ++ if (rep_offset2 > maxRep) offsetSaved2 = rep_offset2, rep_offset2 = 0; ++ if (rep_offset1 > maxRep) offsetSaved1 = rep_offset1, rep_offset1 = 0; + } + + /* start each op */ +@@ -180,8 +236,14 @@ ZSTD_compressBlock_fast_noDict_generic( + mLength = ip0[-1] == match0[-1]; + ip0 -= mLength; + match0 -= mLength; +- offcode = STORE_REPCODE_1; ++ offcode = REPCODE1_TO_OFFBASE; + mLength += 4; ++ ++ /* First write next hash table entry; we've already calculated it. ++ * This write is known to be safe because the ip1 is before the ++ * repcode (ip2). */ ++ hashTable[hash1] = (U32)(ip1 - base); ++ + goto _match; + } + +@@ -195,6 +257,12 @@ ZSTD_compressBlock_fast_noDict_generic( + /* check match at ip[0] */ + if (MEM_read32(ip0) == mval) { + /* found a match! */ ++ ++ /* First write next hash table entry; we've already calculated it. 
++ * This write is known to be safe because the ip1 == ip0 + 1, so ++ * we know we will resume searching after ip1 */ ++ hashTable[hash1] = (U32)(ip1 - base); ++ + goto _offset; + } + +@@ -224,6 +292,21 @@ ZSTD_compressBlock_fast_noDict_generic( + /* check match at ip[0] */ + if (MEM_read32(ip0) == mval) { + /* found a match! */ ++ ++ /* first write next hash table entry; we've already calculated it */ ++ if (step <= 4) { ++ /* We need to avoid writing an index into the hash table >= the ++ * position at which we will pick up our searching after we've ++ * taken this match. ++ * ++ * The minimum possible match has length 4, so the earliest ip0 ++ * can be after we take this match will be the current ip0 + 4. ++ * ip1 is ip0 + step - 1. If ip1 is >= ip0 + 4, we can't safely ++ * write this position. ++ */ ++ hashTable[hash1] = (U32)(ip1 - base); ++ } ++ + goto _offset; + } + +@@ -254,9 +337,24 @@ ZSTD_compressBlock_fast_noDict_generic( + * However, it seems to be a meaningful performance hit to try to search + * them. So let's not. */ + ++ /* When the repcodes are outside of the prefix, we set them to zero before the loop. ++ * When the offsets are still zero, we need to restore them after the block to have a correct ++ * repcode history. If only one offset was invalid, it is easy. The tricky case is when both ++ * offsets were invalid. We need to figure out which offset to refill with. ++ * - If both offsets are zero they are in the same order. ++ * - If both offsets are non-zero, we won't restore the offsets from `offsetSaved[12]`. ++ * - If only one is zero, we need to decide which offset to restore. ++ * - If rep_offset1 is non-zero, then rep_offset2 must be offsetSaved1. ++ * - It is impossible for rep_offset2 to be non-zero. ++ * ++ * So if rep_offset1 started invalid (offsetSaved1 != 0) and became valid (rep_offset1 != 0), then ++ * set rep[0] = rep_offset1 and rep[1] = offsetSaved1. ++ */ ++ offsetSaved2 = ((offsetSaved1 != 0) && (rep_offset1 != 0)) ? 
offsetSaved1 : offsetSaved2; ++ + /* save reps for next block */ +- rep[0] = rep_offset1 ? rep_offset1 : offsetSaved; +- rep[1] = rep_offset2 ? rep_offset2 : offsetSaved; ++ rep[0] = rep_offset1 ? rep_offset1 : offsetSaved1; ++ rep[1] = rep_offset2 ? rep_offset2 : offsetSaved2; + + /* Return the last literals size */ + return (size_t)(iend - anchor); +@@ -267,7 +365,7 @@ ZSTD_compressBlock_fast_noDict_generic( + match0 = base + idx; + rep_offset2 = rep_offset1; + rep_offset1 = (U32)(ip0-match0); +- offcode = STORE_OFFSET(rep_offset1); ++ offcode = OFFSET_TO_OFFBASE(rep_offset1); + mLength = 4; + + /* Count the backwards match length. */ +@@ -287,11 +385,6 @@ ZSTD_compressBlock_fast_noDict_generic( + ip0 += mLength; + anchor = ip0; + +- /* write next hash table entry */ +- if (ip1 < ip0) { +- hashTable[hash1] = (U32)(ip1 - base); +- } +- + /* Fill table and check for immediate repcode. */ + if (ip0 <= ilimit) { + /* Fill Table */ +@@ -306,7 +399,7 @@ ZSTD_compressBlock_fast_noDict_generic( + { U32 const tmpOff = rep_offset2; rep_offset2 = rep_offset1; rep_offset1 = tmpOff; } /* swap rep_offset2 <=> rep_offset1 */ + hashTable[ZSTD_hashPtr(ip0, hlog, mls)] = (U32)(ip0-base); + ip0 += rLength; +- ZSTD_storeSeq(seqStore, 0 /*litLen*/, anchor, iend, STORE_REPCODE_1, rLength); ++ ZSTD_storeSeq(seqStore, 0 /*litLen*/, anchor, iend, REPCODE1_TO_OFFBASE, rLength); + anchor = ip0; + continue; /* faster when present (confirmed on gcc-8) ... (?) 
*/ + } } } +@@ -369,6 +462,7 @@ size_t ZSTD_compressBlock_fast( + } + + FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR + size_t ZSTD_compressBlock_fast_dictMatchState_generic( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize, U32 const mls, U32 const hasStep) +@@ -380,14 +474,14 @@ size_t ZSTD_compressBlock_fast_dictMatchState_generic( + U32 const stepSize = cParams->targetLength + !(cParams->targetLength); + const BYTE* const base = ms->window.base; + const BYTE* const istart = (const BYTE*)src; +- const BYTE* ip = istart; ++ const BYTE* ip0 = istart; ++ const BYTE* ip1 = ip0 + stepSize; /* we assert below that stepSize >= 1 */ + const BYTE* anchor = istart; + const U32 prefixStartIndex = ms->window.dictLimit; + const BYTE* const prefixStart = base + prefixStartIndex; + const BYTE* const iend = istart + srcSize; + const BYTE* const ilimit = iend - HASH_READ_SIZE; + U32 offset_1=rep[0], offset_2=rep[1]; +- U32 offsetSaved = 0; + + const ZSTD_matchState_t* const dms = ms->dictMatchState; + const ZSTD_compressionParameters* const dictCParams = &dms->cParams ; +@@ -397,13 +491,13 @@ size_t ZSTD_compressBlock_fast_dictMatchState_generic( + const BYTE* const dictStart = dictBase + dictStartIndex; + const BYTE* const dictEnd = dms->window.nextSrc; + const U32 dictIndexDelta = prefixStartIndex - (U32)(dictEnd - dictBase); +- const U32 dictAndPrefixLength = (U32)(ip - prefixStart + dictEnd - dictStart); +- const U32 dictHLog = dictCParams->hashLog; ++ const U32 dictAndPrefixLength = (U32)(istart - prefixStart + dictEnd - dictStart); ++ const U32 dictHBits = dictCParams->hashLog + ZSTD_SHORT_CACHE_TAG_BITS; + + /* if a dictionary is still attached, it necessarily means that + * it is within window size. So we just check it. 
*/ + const U32 maxDistance = 1U << cParams->windowLog; +- const U32 endIndex = (U32)((size_t)(ip - base) + srcSize); ++ const U32 endIndex = (U32)((size_t)(istart - base) + srcSize); + assert(endIndex - prefixStartIndex <= maxDistance); + (void)maxDistance; (void)endIndex; /* these variables are not used when assert() is disabled */ + +@@ -413,106 +507,155 @@ size_t ZSTD_compressBlock_fast_dictMatchState_generic( + * when translating a dict index into a local index */ + assert(prefixStartIndex >= (U32)(dictEnd - dictBase)); + ++ if (ms->prefetchCDictTables) { ++ size_t const hashTableBytes = (((size_t)1) << dictCParams->hashLog) * sizeof(U32); ++ PREFETCH_AREA(dictHashTable, hashTableBytes); ++ } ++ + /* init */ + DEBUGLOG(5, "ZSTD_compressBlock_fast_dictMatchState_generic"); +- ip += (dictAndPrefixLength == 0); ++ ip0 += (dictAndPrefixLength == 0); + /* dictMatchState repCode checks don't currently handle repCode == 0 + * disabling. */ + assert(offset_1 <= dictAndPrefixLength); + assert(offset_2 <= dictAndPrefixLength); + +- /* Main Search Loop */ +- while (ip < ilimit) { /* < instead of <=, because repcode check at (ip+1) */ ++ /* Outer search loop */ ++ assert(stepSize >= 1); ++ while (ip1 <= ilimit) { /* repcode check at (ip0 + 1) is safe because ip0 < ip1 */ + size_t mLength; +- size_t const h = ZSTD_hashPtr(ip, hlog, mls); +- U32 const curr = (U32)(ip-base); +- U32 const matchIndex = hashTable[h]; +- const BYTE* match = base + matchIndex; +- const U32 repIndex = curr + 1 - offset_1; +- const BYTE* repMatch = (repIndex < prefixStartIndex) ? +- dictBase + (repIndex - dictIndexDelta) : +- base + repIndex; +- hashTable[h] = curr; /* update hash table */ +- +- if ( ((U32)((prefixStartIndex-1) - repIndex) >= 3) /* intentional underflow : ensure repIndex isn't overlapping dict + prefix */ +- && (MEM_read32(repMatch) == MEM_read32(ip+1)) ) { +- const BYTE* const repMatchEnd = repIndex < prefixStartIndex ? 
dictEnd : iend; +- mLength = ZSTD_count_2segments(ip+1+4, repMatch+4, iend, repMatchEnd, prefixStart) + 4; +- ip++; +- ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, mLength); +- } else if ( (matchIndex <= prefixStartIndex) ) { +- size_t const dictHash = ZSTD_hashPtr(ip, dictHLog, mls); +- U32 const dictMatchIndex = dictHashTable[dictHash]; +- const BYTE* dictMatch = dictBase + dictMatchIndex; +- if (dictMatchIndex <= dictStartIndex || +- MEM_read32(dictMatch) != MEM_read32(ip)) { +- assert(stepSize >= 1); +- ip += ((ip-anchor) >> kSearchStrength) + stepSize; +- continue; +- } else { +- /* found a dict match */ +- U32 const offset = (U32)(curr-dictMatchIndex-dictIndexDelta); +- mLength = ZSTD_count_2segments(ip+4, dictMatch+4, iend, dictEnd, prefixStart) + 4; +- while (((ip>anchor) & (dictMatch>dictStart)) +- && (ip[-1] == dictMatch[-1])) { +- ip--; dictMatch--; mLength++; ++ size_t hash0 = ZSTD_hashPtr(ip0, hlog, mls); ++ ++ size_t const dictHashAndTag0 = ZSTD_hashPtr(ip0, dictHBits, mls); ++ U32 dictMatchIndexAndTag = dictHashTable[dictHashAndTag0 >> ZSTD_SHORT_CACHE_TAG_BITS]; ++ int dictTagsMatch = ZSTD_comparePackedTags(dictMatchIndexAndTag, dictHashAndTag0); ++ ++ U32 matchIndex = hashTable[hash0]; ++ U32 curr = (U32)(ip0 - base); ++ size_t step = stepSize; ++ const size_t kStepIncr = 1 << kSearchStrength; ++ const BYTE* nextStep = ip0 + kStepIncr; ++ ++ /* Inner search loop */ ++ while (1) { ++ const BYTE* match = base + matchIndex; ++ const U32 repIndex = curr + 1 - offset_1; ++ const BYTE* repMatch = (repIndex < prefixStartIndex) ? 
++ dictBase + (repIndex - dictIndexDelta) : ++ base + repIndex; ++ const size_t hash1 = ZSTD_hashPtr(ip1, hlog, mls); ++ size_t const dictHashAndTag1 = ZSTD_hashPtr(ip1, dictHBits, mls); ++ hashTable[hash0] = curr; /* update hash table */ ++ ++ if (((U32) ((prefixStartIndex - 1) - repIndex) >= ++ 3) /* intentional underflow : ensure repIndex isn't overlapping dict + prefix */ ++ && (MEM_read32(repMatch) == MEM_read32(ip0 + 1))) { ++ const BYTE* const repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend; ++ mLength = ZSTD_count_2segments(ip0 + 1 + 4, repMatch + 4, iend, repMatchEnd, prefixStart) + 4; ++ ip0++; ++ ZSTD_storeSeq(seqStore, (size_t) (ip0 - anchor), anchor, iend, REPCODE1_TO_OFFBASE, mLength); ++ break; ++ } ++ ++ if (dictTagsMatch) { ++ /* Found a possible dict match */ ++ const U32 dictMatchIndex = dictMatchIndexAndTag >> ZSTD_SHORT_CACHE_TAG_BITS; ++ const BYTE* dictMatch = dictBase + dictMatchIndex; ++ if (dictMatchIndex > dictStartIndex && ++ MEM_read32(dictMatch) == MEM_read32(ip0)) { ++ /* To replicate extDict parse behavior, we only use dict matches when the normal matchIndex is invalid */ ++ if (matchIndex <= prefixStartIndex) { ++ U32 const offset = (U32) (curr - dictMatchIndex - dictIndexDelta); ++ mLength = ZSTD_count_2segments(ip0 + 4, dictMatch + 4, iend, dictEnd, prefixStart) + 4; ++ while (((ip0 > anchor) & (dictMatch > dictStart)) ++ && (ip0[-1] == dictMatch[-1])) { ++ ip0--; ++ dictMatch--; ++ mLength++; ++ } /* catch up */ ++ offset_2 = offset_1; ++ offset_1 = offset; ++ ZSTD_storeSeq(seqStore, (size_t) (ip0 - anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); ++ break; ++ } ++ } ++ } ++ ++ if (matchIndex > prefixStartIndex && MEM_read32(match) == MEM_read32(ip0)) { ++ /* found a regular match */ ++ U32 const offset = (U32) (ip0 - match); ++ mLength = ZSTD_count(ip0 + 4, match + 4, iend) + 4; ++ while (((ip0 > anchor) & (match > prefixStart)) ++ && (ip0[-1] == match[-1])) { ++ ip0--; ++ match--; ++ mLength++; + } /* 
catch up */ + offset_2 = offset_1; + offset_1 = offset; +- ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); ++ ZSTD_storeSeq(seqStore, (size_t) (ip0 - anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); ++ break; + } +- } else if (MEM_read32(match) != MEM_read32(ip)) { +- /* it's not a match, and we're not going to check the dictionary */ +- assert(stepSize >= 1); +- ip += ((ip-anchor) >> kSearchStrength) + stepSize; +- continue; +- } else { +- /* found a regular match */ +- U32 const offset = (U32)(ip-match); +- mLength = ZSTD_count(ip+4, match+4, iend) + 4; +- while (((ip>anchor) & (match>prefixStart)) +- && (ip[-1] == match[-1])) { ip--; match--; mLength++; } /* catch up */ +- offset_2 = offset_1; +- offset_1 = offset; +- ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); +- } ++ ++ /* Prepare for next iteration */ ++ dictMatchIndexAndTag = dictHashTable[dictHashAndTag1 >> ZSTD_SHORT_CACHE_TAG_BITS]; ++ dictTagsMatch = ZSTD_comparePackedTags(dictMatchIndexAndTag, dictHashAndTag1); ++ matchIndex = hashTable[hash1]; ++ ++ if (ip1 >= nextStep) { ++ step++; ++ nextStep += kStepIncr; ++ } ++ ip0 = ip1; ++ ip1 = ip1 + step; ++ if (ip1 > ilimit) goto _cleanup; ++ ++ curr = (U32)(ip0 - base); ++ hash0 = hash1; ++ } /* end inner search loop */ + + /* match found */ +- ip += mLength; +- anchor = ip; ++ assert(mLength); ++ ip0 += mLength; ++ anchor = ip0; + +- if (ip <= ilimit) { ++ if (ip0 <= ilimit) { + /* Fill Table */ + assert(base+curr+2 > istart); /* check base overflow */ + hashTable[ZSTD_hashPtr(base+curr+2, hlog, mls)] = curr+2; /* here because curr+2 could be > iend-8 */ +- hashTable[ZSTD_hashPtr(ip-2, hlog, mls)] = (U32)(ip-2-base); ++ hashTable[ZSTD_hashPtr(ip0-2, hlog, mls)] = (U32)(ip0-2-base); + + /* check immediate repcode */ +- while (ip <= ilimit) { +- U32 const current2 = (U32)(ip-base); ++ while (ip0 <= ilimit) { ++ U32 const current2 = (U32)(ip0-base); + U32 const 
repIndex2 = current2 - offset_2; + const BYTE* repMatch2 = repIndex2 < prefixStartIndex ? + dictBase - dictIndexDelta + repIndex2 : + base + repIndex2; + if ( ((U32)((prefixStartIndex-1) - (U32)repIndex2) >= 3 /* intentional overflow */) +- && (MEM_read32(repMatch2) == MEM_read32(ip)) ) { ++ && (MEM_read32(repMatch2) == MEM_read32(ip0))) { + const BYTE* const repEnd2 = repIndex2 < prefixStartIndex ? dictEnd : iend; +- size_t const repLength2 = ZSTD_count_2segments(ip+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; ++ size_t const repLength2 = ZSTD_count_2segments(ip0+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; + U32 tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; /* swap offset_2 <=> offset_1 */ +- ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, repLength2); +- hashTable[ZSTD_hashPtr(ip, hlog, mls)] = current2; +- ip += repLength2; +- anchor = ip; ++ ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, repLength2); ++ hashTable[ZSTD_hashPtr(ip0, hlog, mls)] = current2; ++ ip0 += repLength2; ++ anchor = ip0; + continue; + } + break; + } + } ++ ++ /* Prepare for next iteration */ ++ assert(ip0 == anchor); ++ ip1 = ip0 + stepSize; + } + ++_cleanup: + /* save reps for next block */ +- rep[0] = offset_1 ? offset_1 : offsetSaved; +- rep[1] = offset_2 ? 
offset_2 : offsetSaved; ++ rep[0] = offset_1; ++ rep[1] = offset_2; + + /* Return the last literals size */ + return (size_t)(iend - anchor); +@@ -545,7 +688,9 @@ size_t ZSTD_compressBlock_fast_dictMatchState( + } + + +-static size_t ZSTD_compressBlock_fast_extDict_generic( ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++size_t ZSTD_compressBlock_fast_extDict_generic( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize, U32 const mls, U32 const hasStep) + { +@@ -553,11 +698,10 @@ static size_t ZSTD_compressBlock_fast_extDict_generic( + U32* const hashTable = ms->hashTable; + U32 const hlog = cParams->hashLog; + /* support stepSize of 0 */ +- U32 const stepSize = cParams->targetLength + !(cParams->targetLength); ++ size_t const stepSize = cParams->targetLength + !(cParams->targetLength) + 1; + const BYTE* const base = ms->window.base; + const BYTE* const dictBase = ms->window.dictBase; + const BYTE* const istart = (const BYTE*)src; +- const BYTE* ip = istart; + const BYTE* anchor = istart; + const U32 endIndex = (U32)((size_t)(istart - base) + srcSize); + const U32 lowLimit = ZSTD_getLowestMatchIndex(ms, endIndex, cParams->windowLog); +@@ -570,6 +714,28 @@ static size_t ZSTD_compressBlock_fast_extDict_generic( + const BYTE* const iend = istart + srcSize; + const BYTE* const ilimit = iend - 8; + U32 offset_1=rep[0], offset_2=rep[1]; ++ U32 offsetSaved1 = 0, offsetSaved2 = 0; ++ ++ const BYTE* ip0 = istart; ++ const BYTE* ip1; ++ const BYTE* ip2; ++ const BYTE* ip3; ++ U32 current0; ++ ++ ++ size_t hash0; /* hash for ip0 */ ++ size_t hash1; /* hash for ip1 */ ++ U32 idx; /* match idx for ip0 */ ++ const BYTE* idxBase; /* base pointer for idx */ ++ ++ U32 offcode; ++ const BYTE* match0; ++ size_t mLength; ++ const BYTE* matchEnd = 0; /* initialize to avoid warning, assert != 0 later */ ++ ++ size_t step; ++ const BYTE* nextStep; ++ const size_t kStepIncr = (1 << (kSearchStrength - 1)); + + (void)hasStep; /* not 
currently specialized on whether it's accelerated */ + +@@ -579,75 +745,202 @@ static size_t ZSTD_compressBlock_fast_extDict_generic( + if (prefixStartIndex == dictStartIndex) + return ZSTD_compressBlock_fast(ms, seqStore, rep, src, srcSize); + +- /* Search Loop */ +- while (ip < ilimit) { /* < instead of <=, because (ip+1) */ +- const size_t h = ZSTD_hashPtr(ip, hlog, mls); +- const U32 matchIndex = hashTable[h]; +- const BYTE* const matchBase = matchIndex < prefixStartIndex ? dictBase : base; +- const BYTE* match = matchBase + matchIndex; +- const U32 curr = (U32)(ip-base); +- const U32 repIndex = curr + 1 - offset_1; +- const BYTE* const repBase = repIndex < prefixStartIndex ? dictBase : base; +- const BYTE* const repMatch = repBase + repIndex; +- hashTable[h] = curr; /* update hash table */ +- DEBUGLOG(7, "offset_1 = %u , curr = %u", offset_1, curr); +- +- if ( ( ((U32)((prefixStartIndex-1) - repIndex) >= 3) /* intentional underflow */ +- & (offset_1 <= curr+1 - dictStartIndex) ) /* note: we are searching at curr+1 */ +- && (MEM_read32(repMatch) == MEM_read32(ip+1)) ) { +- const BYTE* const repMatchEnd = repIndex < prefixStartIndex ? 
dictEnd : iend; +- size_t const rLength = ZSTD_count_2segments(ip+1 +4, repMatch +4, iend, repMatchEnd, prefixStart) + 4; +- ip++; +- ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, rLength); +- ip += rLength; +- anchor = ip; +- } else { +- if ( (matchIndex < dictStartIndex) || +- (MEM_read32(match) != MEM_read32(ip)) ) { +- assert(stepSize >= 1); +- ip += ((ip-anchor) >> kSearchStrength) + stepSize; +- continue; ++ { U32 const curr = (U32)(ip0 - base); ++ U32 const maxRep = curr - dictStartIndex; ++ if (offset_2 >= maxRep) offsetSaved2 = offset_2, offset_2 = 0; ++ if (offset_1 >= maxRep) offsetSaved1 = offset_1, offset_1 = 0; ++ } ++ ++ /* start each op */ ++_start: /* Requires: ip0 */ ++ ++ step = stepSize; ++ nextStep = ip0 + kStepIncr; ++ ++ /* calculate positions, ip0 - anchor == 0, so we skip step calc */ ++ ip1 = ip0 + 1; ++ ip2 = ip0 + step; ++ ip3 = ip2 + 1; ++ ++ if (ip3 >= ilimit) { ++ goto _cleanup; ++ } ++ ++ hash0 = ZSTD_hashPtr(ip0, hlog, mls); ++ hash1 = ZSTD_hashPtr(ip1, hlog, mls); ++ ++ idx = hashTable[hash0]; ++ idxBase = idx < prefixStartIndex ? dictBase : base; ++ ++ do { ++ { /* load repcode match for ip[2] */ ++ U32 const current2 = (U32)(ip2 - base); ++ U32 const repIndex = current2 - offset_1; ++ const BYTE* const repBase = repIndex < prefixStartIndex ? dictBase : base; ++ U32 rval; ++ if ( ((U32)(prefixStartIndex - repIndex) >= 4) /* intentional underflow */ ++ & (offset_1 > 0) ) { ++ rval = MEM_read32(repBase + repIndex); ++ } else { ++ rval = MEM_read32(ip2) ^ 1; /* guaranteed to not match. */ + } +- { const BYTE* const matchEnd = matchIndex < prefixStartIndex ? dictEnd : iend; +- const BYTE* const lowMatchPtr = matchIndex < prefixStartIndex ? 
dictStart : prefixStart; +- U32 const offset = curr - matchIndex; +- size_t mLength = ZSTD_count_2segments(ip+4, match+4, iend, matchEnd, prefixStart) + 4; +- while (((ip>anchor) & (match>lowMatchPtr)) && (ip[-1] == match[-1])) { ip--; match--; mLength++; } /* catch up */ +- offset_2 = offset_1; offset_1 = offset; /* update offset history */ +- ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); +- ip += mLength; +- anchor = ip; ++ ++ /* write back hash table entry */ ++ current0 = (U32)(ip0 - base); ++ hashTable[hash0] = current0; ++ ++ /* check repcode at ip[2] */ ++ if (MEM_read32(ip2) == rval) { ++ ip0 = ip2; ++ match0 = repBase + repIndex; ++ matchEnd = repIndex < prefixStartIndex ? dictEnd : iend; ++ assert((match0 != prefixStart) & (match0 != dictStart)); ++ mLength = ip0[-1] == match0[-1]; ++ ip0 -= mLength; ++ match0 -= mLength; ++ offcode = REPCODE1_TO_OFFBASE; ++ mLength += 4; ++ goto _match; + } } + +- if (ip <= ilimit) { +- /* Fill Table */ +- hashTable[ZSTD_hashPtr(base+curr+2, hlog, mls)] = curr+2; +- hashTable[ZSTD_hashPtr(ip-2, hlog, mls)] = (U32)(ip-2-base); +- /* check immediate repcode */ +- while (ip <= ilimit) { +- U32 const current2 = (U32)(ip-base); +- U32 const repIndex2 = current2 - offset_2; +- const BYTE* const repMatch2 = repIndex2 < prefixStartIndex ? dictBase + repIndex2 : base + repIndex2; +- if ( (((U32)((prefixStartIndex-1) - repIndex2) >= 3) & (offset_2 <= curr - dictStartIndex)) /* intentional overflow */ +- && (MEM_read32(repMatch2) == MEM_read32(ip)) ) { +- const BYTE* const repEnd2 = repIndex2 < prefixStartIndex ? 
dictEnd : iend; +- size_t const repLength2 = ZSTD_count_2segments(ip+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; +- { U32 const tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; } /* swap offset_2 <=> offset_1 */ +- ZSTD_storeSeq(seqStore, 0 /*litlen*/, anchor, iend, STORE_REPCODE_1, repLength2); +- hashTable[ZSTD_hashPtr(ip, hlog, mls)] = current2; +- ip += repLength2; +- anchor = ip; +- continue; +- } +- break; +- } } } ++ { /* load match for ip[0] */ ++ U32 const mval = idx >= dictStartIndex ? ++ MEM_read32(idxBase + idx) : ++ MEM_read32(ip0) ^ 1; /* guaranteed not to match */ ++ ++ /* check match at ip[0] */ ++ if (MEM_read32(ip0) == mval) { ++ /* found a match! */ ++ goto _offset; ++ } } ++ ++ /* lookup ip[1] */ ++ idx = hashTable[hash1]; ++ idxBase = idx < prefixStartIndex ? dictBase : base; ++ ++ /* hash ip[2] */ ++ hash0 = hash1; ++ hash1 = ZSTD_hashPtr(ip2, hlog, mls); ++ ++ /* advance to next positions */ ++ ip0 = ip1; ++ ip1 = ip2; ++ ip2 = ip3; ++ ++ /* write back hash table entry */ ++ current0 = (U32)(ip0 - base); ++ hashTable[hash0] = current0; ++ ++ { /* load match for ip[0] */ ++ U32 const mval = idx >= dictStartIndex ? ++ MEM_read32(idxBase + idx) : ++ MEM_read32(ip0) ^ 1; /* guaranteed not to match */ ++ ++ /* check match at ip[0] */ ++ if (MEM_read32(ip0) == mval) { ++ /* found a match! */ ++ goto _offset; ++ } } ++ ++ /* lookup ip[1] */ ++ idx = hashTable[hash1]; ++ idxBase = idx < prefixStartIndex ? dictBase : base; ++ ++ /* hash ip[2] */ ++ hash0 = hash1; ++ hash1 = ZSTD_hashPtr(ip2, hlog, mls); ++ ++ /* advance to next positions */ ++ ip0 = ip1; ++ ip1 = ip2; ++ ip2 = ip0 + step; ++ ip3 = ip1 + step; ++ ++ /* calculate step */ ++ if (ip2 >= nextStep) { ++ step++; ++ PREFETCH_L1(ip1 + 64); ++ PREFETCH_L1(ip1 + 128); ++ nextStep += kStepIncr; ++ } ++ } while (ip3 < ilimit); ++ ++_cleanup: ++ /* Note that there are probably still a couple positions we could search. 
++ * However, it seems to be a meaningful performance hit to try to search ++ * them. So let's not. */ ++ ++ /* If offset_1 started invalid (offsetSaved1 != 0) and became valid (offset_1 != 0), ++ * rotate saved offsets. See comment in ZSTD_compressBlock_fast_noDict for more context. */ ++ offsetSaved2 = ((offsetSaved1 != 0) && (offset_1 != 0)) ? offsetSaved1 : offsetSaved2; + + /* save reps for next block */ +- rep[0] = offset_1; +- rep[1] = offset_2; ++ rep[0] = offset_1 ? offset_1 : offsetSaved1; ++ rep[1] = offset_2 ? offset_2 : offsetSaved2; + + /* Return the last literals size */ + return (size_t)(iend - anchor); ++ ++_offset: /* Requires: ip0, idx, idxBase */ ++ ++ /* Compute the offset code. */ ++ { U32 const offset = current0 - idx; ++ const BYTE* const lowMatchPtr = idx < prefixStartIndex ? dictStart : prefixStart; ++ matchEnd = idx < prefixStartIndex ? dictEnd : iend; ++ match0 = idxBase + idx; ++ offset_2 = offset_1; ++ offset_1 = offset; ++ offcode = OFFSET_TO_OFFBASE(offset); ++ mLength = 4; ++ ++ /* Count the backwards match length. */ ++ while (((ip0>anchor) & (match0>lowMatchPtr)) && (ip0[-1] == match0[-1])) { ++ ip0--; ++ match0--; ++ mLength++; ++ } } ++ ++_match: /* Requires: ip0, match0, offcode, matchEnd */ ++ ++ /* Count the forward length. */ ++ assert(matchEnd != 0); ++ mLength += ZSTD_count_2segments(ip0 + mLength, match0 + mLength, iend, matchEnd, prefixStart); ++ ++ ZSTD_storeSeq(seqStore, (size_t)(ip0 - anchor), anchor, iend, offcode, mLength); ++ ++ ip0 += mLength; ++ anchor = ip0; ++ ++ /* write next hash table entry */ ++ if (ip1 < ip0) { ++ hashTable[hash1] = (U32)(ip1 - base); ++ } ++ ++ /* Fill table and check for immediate repcode. 
*/ ++ if (ip0 <= ilimit) { ++ /* Fill Table */ ++ assert(base+current0+2 > istart); /* check base overflow */ ++ hashTable[ZSTD_hashPtr(base+current0+2, hlog, mls)] = current0+2; /* here because current+2 could be > iend-8 */ ++ hashTable[ZSTD_hashPtr(ip0-2, hlog, mls)] = (U32)(ip0-2-base); ++ ++ while (ip0 <= ilimit) { ++ U32 const repIndex2 = (U32)(ip0-base) - offset_2; ++ const BYTE* const repMatch2 = repIndex2 < prefixStartIndex ? dictBase + repIndex2 : base + repIndex2; ++ if ( (((U32)((prefixStartIndex-1) - repIndex2) >= 3) & (offset_2 > 0)) /* intentional underflow */ ++ && (MEM_read32(repMatch2) == MEM_read32(ip0)) ) { ++ const BYTE* const repEnd2 = repIndex2 < prefixStartIndex ? dictEnd : iend; ++ size_t const repLength2 = ZSTD_count_2segments(ip0+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; ++ { U32 const tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; } /* swap offset_2 <=> offset_1 */ ++ ZSTD_storeSeq(seqStore, 0 /*litlen*/, anchor, iend, REPCODE1_TO_OFFBASE, repLength2); ++ hashTable[ZSTD_hashPtr(ip0, hlog, mls)] = (U32)(ip0-base); ++ ip0 += repLength2; ++ anchor = ip0; ++ continue; ++ } ++ break; ++ } } ++ ++ goto _start; + } + + ZSTD_GEN_FAST_FN(extDict, 4, 0) +@@ -660,6 +953,7 @@ size_t ZSTD_compressBlock_fast_extDict( + void const* src, size_t srcSize) + { + U32 const mls = ms->cParams.minMatch; ++ assert(ms->dictMatchState == NULL); + switch(mls) + { + default: /* includes case 3 */ +diff --git a/lib/zstd/compress/zstd_fast.h b/lib/zstd/compress/zstd_fast.h +index fddc2f532d21..e64d9e1b2d39 100644 +--- a/lib/zstd/compress/zstd_fast.h ++++ b/lib/zstd/compress/zstd_fast.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -16,7 +17,8 @@ + #include "zstd_compress_internal.h" + + void ZSTD_fillHashTable(ZSTD_matchState_t* ms, +- void const* end, ZSTD_dictTableLoadMethod_e dtlm); ++ void const* end, ZSTD_dictTableLoadMethod_e dtlm, ++ ZSTD_tableFillPurpose_e tfp); + size_t ZSTD_compressBlock_fast( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +diff --git a/lib/zstd/compress/zstd_lazy.c b/lib/zstd/compress/zstd_lazy.c +index 0298a01a7504..3e88d8a1a136 100644 +--- a/lib/zstd/compress/zstd_lazy.c ++++ b/lib/zstd/compress/zstd_lazy.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -10,14 +11,23 @@ + + #include "zstd_compress_internal.h" + #include "zstd_lazy.h" ++#include "../common/bits.h" /* ZSTD_countTrailingZeros64 */ ++ ++#if !defined(ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) ++ ++#define kLazySkippingStep 8 + + + /*-************************************* + * Binary Tree search + ***************************************/ + +-static void +-ZSTD_updateDUBT(ZSTD_matchState_t* ms, ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++void ZSTD_updateDUBT(ZSTD_matchState_t* ms, + const BYTE* ip, const BYTE* iend, + U32 mls) + { +@@ -60,8 +70,9 @@ ZSTD_updateDUBT(ZSTD_matchState_t* ms, + * sort one already inserted but unsorted position + * assumption : curr >= btlow == (curr - btmask) + * doesn't fail */ +-static void +-ZSTD_insertDUBT1(const ZSTD_matchState_t* ms, ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++void ZSTD_insertDUBT1(const ZSTD_matchState_t* 
ms, + U32 curr, const BYTE* inputEnd, + U32 nbCompares, U32 btLow, + const ZSTD_dictMode_e dictMode) +@@ -149,8 +160,9 @@ ZSTD_insertDUBT1(const ZSTD_matchState_t* ms, + } + + +-static size_t +-ZSTD_DUBT_findBetterDictMatch ( ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++size_t ZSTD_DUBT_findBetterDictMatch ( + const ZSTD_matchState_t* ms, + const BYTE* const ip, const BYTE* const iend, + size_t* offsetPtr, +@@ -197,8 +209,8 @@ ZSTD_DUBT_findBetterDictMatch ( + U32 matchIndex = dictMatchIndex + dictIndexDelta; + if ( (4*(int)(matchLength-bestLength)) > (int)(ZSTD_highbit32(curr-matchIndex+1) - ZSTD_highbit32((U32)offsetPtr[0]+1)) ) { + DEBUGLOG(9, "ZSTD_DUBT_findBetterDictMatch(%u) : found better match length %u -> %u and offsetCode %u -> %u (dictMatchIndex %u, matchIndex %u)", +- curr, (U32)bestLength, (U32)matchLength, (U32)*offsetPtr, STORE_OFFSET(curr - matchIndex), dictMatchIndex, matchIndex); +- bestLength = matchLength, *offsetPtr = STORE_OFFSET(curr - matchIndex); ++ curr, (U32)bestLength, (U32)matchLength, (U32)*offsetPtr, OFFSET_TO_OFFBASE(curr - matchIndex), dictMatchIndex, matchIndex); ++ bestLength = matchLength, *offsetPtr = OFFSET_TO_OFFBASE(curr - matchIndex); + } + if (ip+matchLength == iend) { /* reached end of input : ip[matchLength] is not valid, no way to know if it's larger or smaller than match */ + break; /* drop, to guarantee consistency (miss a little bit of compression) */ +@@ -218,7 +230,7 @@ ZSTD_DUBT_findBetterDictMatch ( + } + + if (bestLength >= MINMATCH) { +- U32 const mIndex = curr - (U32)STORED_OFFSET(*offsetPtr); (void)mIndex; ++ U32 const mIndex = curr - (U32)OFFBASE_TO_OFFSET(*offsetPtr); (void)mIndex; + DEBUGLOG(8, "ZSTD_DUBT_findBetterDictMatch(%u) : found match of length %u and offsetCode %u (pos %u)", + curr, (U32)bestLength, (U32)*offsetPtr, mIndex); + } +@@ -227,10 +239,11 @@ ZSTD_DUBT_findBetterDictMatch ( + } + + +-static size_t +-ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR 
++size_t ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, + const BYTE* const ip, const BYTE* const iend, +- size_t* offsetPtr, ++ size_t* offBasePtr, + U32 const mls, + const ZSTD_dictMode_e dictMode) + { +@@ -327,8 +340,8 @@ ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, + if (matchLength > bestLength) { + if (matchLength > matchEndIdx - matchIndex) + matchEndIdx = matchIndex + (U32)matchLength; +- if ( (4*(int)(matchLength-bestLength)) > (int)(ZSTD_highbit32(curr-matchIndex+1) - ZSTD_highbit32((U32)offsetPtr[0]+1)) ) +- bestLength = matchLength, *offsetPtr = STORE_OFFSET(curr - matchIndex); ++ if ( (4*(int)(matchLength-bestLength)) > (int)(ZSTD_highbit32(curr - matchIndex + 1) - ZSTD_highbit32((U32)*offBasePtr)) ) ++ bestLength = matchLength, *offBasePtr = OFFSET_TO_OFFBASE(curr - matchIndex); + if (ip+matchLength == iend) { /* equal : no way to know if inf or sup */ + if (dictMode == ZSTD_dictMatchState) { + nbCompares = 0; /* in addition to avoiding checking any +@@ -361,16 +374,16 @@ ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, + if (dictMode == ZSTD_dictMatchState && nbCompares) { + bestLength = ZSTD_DUBT_findBetterDictMatch( + ms, ip, iend, +- offsetPtr, bestLength, nbCompares, ++ offBasePtr, bestLength, nbCompares, + mls, dictMode); + } + + assert(matchEndIdx > curr+8); /* ensure nextToUpdate is increased */ + ms->nextToUpdate = matchEndIdx - 8; /* skip repetitive patterns */ + if (bestLength >= MINMATCH) { +- U32 const mIndex = curr - (U32)STORED_OFFSET(*offsetPtr); (void)mIndex; ++ U32 const mIndex = curr - (U32)OFFBASE_TO_OFFSET(*offBasePtr); (void)mIndex; + DEBUGLOG(8, "ZSTD_DUBT_findBestMatch(%u) : found match of length %u and offsetCode %u (pos %u)", +- curr, (U32)bestLength, (U32)*offsetPtr, mIndex); ++ curr, (U32)bestLength, (U32)*offBasePtr, mIndex); + } + return bestLength; + } +@@ -378,17 +391,18 @@ ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, + + + /* ZSTD_BtFindBestMatch() : Tree updater, providing best match */ +-FORCE_INLINE_TEMPLATE 
size_t +-ZSTD_BtFindBestMatch( ZSTD_matchState_t* ms, ++FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++size_t ZSTD_BtFindBestMatch( ZSTD_matchState_t* ms, + const BYTE* const ip, const BYTE* const iLimit, +- size_t* offsetPtr, ++ size_t* offBasePtr, + const U32 mls /* template */, + const ZSTD_dictMode_e dictMode) + { + DEBUGLOG(7, "ZSTD_BtFindBestMatch"); + if (ip < ms->window.base + ms->nextToUpdate) return 0; /* skipped area */ + ZSTD_updateDUBT(ms, ip, iLimit, mls); +- return ZSTD_DUBT_findBestMatch(ms, ip, iLimit, offsetPtr, mls, dictMode); ++ return ZSTD_DUBT_findBestMatch(ms, ip, iLimit, offBasePtr, mls, dictMode); + } + + /* ********************************* +@@ -561,7 +575,7 @@ size_t ZSTD_dedicatedDictSearch_lazy_search(size_t* offsetPtr, size_t ml, U32 nb + /* save best solution */ + if (currentMl > ml) { + ml = currentMl; +- *offsetPtr = STORE_OFFSET(curr - (matchIndex + ddsIndexDelta)); ++ *offsetPtr = OFFSET_TO_OFFBASE(curr - (matchIndex + ddsIndexDelta)); + if (ip+currentMl == iLimit) { + /* best possible, avoids read overflow on next attempt */ + return ml; +@@ -598,7 +612,7 @@ size_t ZSTD_dedicatedDictSearch_lazy_search(size_t* offsetPtr, size_t ml, U32 nb + /* save best solution */ + if (currentMl > ml) { + ml = currentMl; +- *offsetPtr = STORE_OFFSET(curr - (matchIndex + ddsIndexDelta)); ++ *offsetPtr = OFFSET_TO_OFFBASE(curr - (matchIndex + ddsIndexDelta)); + if (ip+currentMl == iLimit) break; /* best possible, avoids read overflow on next attempt */ + } + } +@@ -614,10 +628,12 @@ size_t ZSTD_dedicatedDictSearch_lazy_search(size_t* offsetPtr, size_t ml, U32 nb + + /* Update chains up to ip (excluded) + Assumption : always within prefix (i.e. 
not within extDict) */ +-FORCE_INLINE_TEMPLATE U32 ZSTD_insertAndFindFirstIndex_internal( ++FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++U32 ZSTD_insertAndFindFirstIndex_internal( + ZSTD_matchState_t* ms, + const ZSTD_compressionParameters* const cParams, +- const BYTE* ip, U32 const mls) ++ const BYTE* ip, U32 const mls, U32 const lazySkipping) + { + U32* const hashTable = ms->hashTable; + const U32 hashLog = cParams->hashLog; +@@ -632,6 +648,9 @@ FORCE_INLINE_TEMPLATE U32 ZSTD_insertAndFindFirstIndex_internal( + NEXT_IN_CHAIN(idx, chainMask) = hashTable[h]; + hashTable[h] = idx; + idx++; ++ /* Stop inserting every position when in the lazy skipping mode. */ ++ if (lazySkipping) ++ break; + } + + ms->nextToUpdate = target; +@@ -640,11 +659,12 @@ FORCE_INLINE_TEMPLATE U32 ZSTD_insertAndFindFirstIndex_internal( + + U32 ZSTD_insertAndFindFirstIndex(ZSTD_matchState_t* ms, const BYTE* ip) { + const ZSTD_compressionParameters* const cParams = &ms->cParams; +- return ZSTD_insertAndFindFirstIndex_internal(ms, cParams, ip, ms->cParams.minMatch); ++ return ZSTD_insertAndFindFirstIndex_internal(ms, cParams, ip, ms->cParams.minMatch, /* lazySkipping*/ 0); + } + + /* inlining is important to hardwire a hot branch (template emulation) */ + FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR + size_t ZSTD_HcFindBestMatch( + ZSTD_matchState_t* ms, + const BYTE* const ip, const BYTE* const iLimit, +@@ -684,14 +704,15 @@ size_t ZSTD_HcFindBestMatch( + } + + /* HC4 match finder */ +- matchIndex = ZSTD_insertAndFindFirstIndex_internal(ms, cParams, ip, mls); ++ matchIndex = ZSTD_insertAndFindFirstIndex_internal(ms, cParams, ip, mls, ms->lazySkipping); + + for ( ; (matchIndex>=lowLimit) & (nbAttempts>0) ; nbAttempts--) { + size_t currentMl=0; + if ((dictMode != ZSTD_extDict) || matchIndex >= dictLimit) { + const BYTE* const match = base + matchIndex; + assert(matchIndex >= dictLimit); /* ensures this is true if dictMode != ZSTD_extDict */ +- if (match[ml] == ip[ml]) 
/* potentially better */ ++ /* read 4B starting from (match + ml + 1 - sizeof(U32)) */ ++ if (MEM_read32(match + ml - 3) == MEM_read32(ip + ml - 3)) /* potentially better */ + currentMl = ZSTD_count(ip, match, iLimit); + } else { + const BYTE* const match = dictBase + matchIndex; +@@ -703,7 +724,7 @@ size_t ZSTD_HcFindBestMatch( + /* save best solution */ + if (currentMl > ml) { + ml = currentMl; +- *offsetPtr = STORE_OFFSET(curr - matchIndex); ++ *offsetPtr = OFFSET_TO_OFFBASE(curr - matchIndex); + if (ip+currentMl == iLimit) break; /* best possible, avoids read overflow on next attempt */ + } + +@@ -739,7 +760,7 @@ size_t ZSTD_HcFindBestMatch( + if (currentMl > ml) { + ml = currentMl; + assert(curr > matchIndex + dmsIndexDelta); +- *offsetPtr = STORE_OFFSET(curr - (matchIndex + dmsIndexDelta)); ++ *offsetPtr = OFFSET_TO_OFFBASE(curr - (matchIndex + dmsIndexDelta)); + if (ip+currentMl == iLimit) break; /* best possible, avoids read overflow on next attempt */ + } + +@@ -756,8 +777,6 @@ size_t ZSTD_HcFindBestMatch( + * (SIMD) Row-based matchfinder + ***********************************/ + /* Constants for row-based hash */ +-#define ZSTD_ROW_HASH_TAG_OFFSET 16 /* byte offset of hashes in the match state's tagTable from the beginning of a row */ +-#define ZSTD_ROW_HASH_TAG_BITS 8 /* nb bits to use for the tag */ + #define ZSTD_ROW_HASH_TAG_MASK ((1u << ZSTD_ROW_HASH_TAG_BITS) - 1) + #define ZSTD_ROW_HASH_MAX_ENTRIES 64 /* absolute maximum number of entries per row, for all configurations */ + +@@ -769,64 +788,19 @@ typedef U64 ZSTD_VecMask; /* Clarifies when we are interacting with a U64 repr + * Starting from the LSB, returns the idx of the next non-zero bit. + * Basically counting the nb of trailing zeroes. 
+ */ +-static U32 ZSTD_VecMask_next(ZSTD_VecMask val) { +- assert(val != 0); +-# if (defined(__GNUC__) && ((__GNUC__ > 3) || ((__GNUC__ == 3) && (__GNUC_MINOR__ >= 4)))) +- if (sizeof(size_t) == 4) { +- U32 mostSignificantWord = (U32)(val >> 32); +- U32 leastSignificantWord = (U32)val; +- if (leastSignificantWord == 0) { +- return 32 + (U32)__builtin_ctz(mostSignificantWord); +- } else { +- return (U32)__builtin_ctz(leastSignificantWord); +- } +- } else { +- return (U32)__builtin_ctzll(val); +- } +-# else +- /* Software ctz version: http://aggregate.org/MAGIC/#Trailing%20Zero%20Count +- * and: https://stackoverflow.com/questions/2709430/count-number-of-bits-in-a-64-bit-long-big-integer +- */ +- val = ~val & (val - 1ULL); /* Lowest set bit mask */ +- val = val - ((val >> 1) & 0x5555555555555555); +- val = (val & 0x3333333333333333ULL) + ((val >> 2) & 0x3333333333333333ULL); +- return (U32)((((val + (val >> 4)) & 0xF0F0F0F0F0F0F0FULL) * 0x101010101010101ULL) >> 56); +-# endif +-} +- +-/* ZSTD_rotateRight_*(): +- * Rotates a bitfield to the right by "count" bits. 
+- * https://en.wikipedia.org/w/index.php?title=Circular_shift&oldid=991635599#Implementing_circular_shifts +- */ +-FORCE_INLINE_TEMPLATE +-U64 ZSTD_rotateRight_U64(U64 const value, U32 count) { +- assert(count < 64); +- count &= 0x3F; /* for fickle pattern recognition */ +- return (value >> count) | (U64)(value << ((0U - count) & 0x3F)); +-} +- +-FORCE_INLINE_TEMPLATE +-U32 ZSTD_rotateRight_U32(U32 const value, U32 count) { +- assert(count < 32); +- count &= 0x1F; /* for fickle pattern recognition */ +- return (value >> count) | (U32)(value << ((0U - count) & 0x1F)); +-} +- +-FORCE_INLINE_TEMPLATE +-U16 ZSTD_rotateRight_U16(U16 const value, U32 count) { +- assert(count < 16); +- count &= 0x0F; /* for fickle pattern recognition */ +- return (value >> count) | (U16)(value << ((0U - count) & 0x0F)); ++MEM_STATIC U32 ZSTD_VecMask_next(ZSTD_VecMask val) { ++ return ZSTD_countTrailingZeros64(val); + } + + /* ZSTD_row_nextIndex(): + * Returns the next index to insert at within a tagTable row, and updates the "head" +- * value to reflect the update. Essentially cycles backwards from [0, {entries per row}) ++ * value to reflect the update. Essentially cycles backwards from [1, {entries per row}) + */ + FORCE_INLINE_TEMPLATE U32 ZSTD_row_nextIndex(BYTE* const tagRow, U32 const rowMask) { +- U32 const next = (*tagRow - 1) & rowMask; +- *tagRow = (BYTE)next; +- return next; ++ U32 next = (*tagRow-1) & rowMask; ++ next += (next == 0) ? rowMask : 0; /* skip first position */ ++ *tagRow = (BYTE)next; ++ return next; + } + + /* ZSTD_isAligned(): +@@ -840,7 +814,7 @@ MEM_STATIC int ZSTD_isAligned(void const* ptr, size_t align) { + /* ZSTD_row_prefetch(): + * Performs prefetching for the hashTable and tagTable at a given row. 
+ */ +-FORCE_INLINE_TEMPLATE void ZSTD_row_prefetch(U32 const* hashTable, U16 const* tagTable, U32 const relRow, U32 const rowLog) { ++FORCE_INLINE_TEMPLATE void ZSTD_row_prefetch(U32 const* hashTable, BYTE const* tagTable, U32 const relRow, U32 const rowLog) { + PREFETCH_L1(hashTable + relRow); + if (rowLog >= 5) { + PREFETCH_L1(hashTable + relRow + 16); +@@ -859,18 +833,20 @@ FORCE_INLINE_TEMPLATE void ZSTD_row_prefetch(U32 const* hashTable, U16 const* ta + * Fill up the hash cache starting at idx, prefetching up to ZSTD_ROW_HASH_CACHE_SIZE entries, + * but not beyond iLimit. + */ +-FORCE_INLINE_TEMPLATE void ZSTD_row_fillHashCache(ZSTD_matchState_t* ms, const BYTE* base, ++FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++void ZSTD_row_fillHashCache(ZSTD_matchState_t* ms, const BYTE* base, + U32 const rowLog, U32 const mls, + U32 idx, const BYTE* const iLimit) + { + U32 const* const hashTable = ms->hashTable; +- U16 const* const tagTable = ms->tagTable; ++ BYTE const* const tagTable = ms->tagTable; + U32 const hashLog = ms->rowHashLog; + U32 const maxElemsToPrefetch = (base + idx) > iLimit ? 0 : (U32)(iLimit - (base + idx) + 1); + U32 const lim = idx + MIN(ZSTD_ROW_HASH_CACHE_SIZE, maxElemsToPrefetch); + + for (; idx < lim; ++idx) { +- U32 const hash = (U32)ZSTD_hashPtr(base + idx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls); ++ U32 const hash = (U32)ZSTD_hashPtrSalted(base + idx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, ms->hashSalt); + U32 const row = (hash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; + ZSTD_row_prefetch(hashTable, tagTable, row, rowLog); + ms->hashCache[idx & ZSTD_ROW_HASH_CACHE_MASK] = hash; +@@ -885,12 +861,15 @@ FORCE_INLINE_TEMPLATE void ZSTD_row_fillHashCache(ZSTD_matchState_t* ms, const B + * Returns the hash of base + idx, and replaces the hash in the hash cache with the byte at + * base + idx + ZSTD_ROW_HASH_CACHE_SIZE. Also prefetches the appropriate rows from hashTable and tagTable. 
+ */ +-FORCE_INLINE_TEMPLATE U32 ZSTD_row_nextCachedHash(U32* cache, U32 const* hashTable, +- U16 const* tagTable, BYTE const* base, ++FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++U32 ZSTD_row_nextCachedHash(U32* cache, U32 const* hashTable, ++ BYTE const* tagTable, BYTE const* base, + U32 idx, U32 const hashLog, +- U32 const rowLog, U32 const mls) ++ U32 const rowLog, U32 const mls, ++ U64 const hashSalt) + { +- U32 const newHash = (U32)ZSTD_hashPtr(base+idx+ZSTD_ROW_HASH_CACHE_SIZE, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls); ++ U32 const newHash = (U32)ZSTD_hashPtrSalted(base+idx+ZSTD_ROW_HASH_CACHE_SIZE, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, hashSalt); + U32 const row = (newHash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; + ZSTD_row_prefetch(hashTable, tagTable, row, rowLog); + { U32 const hash = cache[idx & ZSTD_ROW_HASH_CACHE_MASK]; +@@ -902,28 +881,29 @@ FORCE_INLINE_TEMPLATE U32 ZSTD_row_nextCachedHash(U32* cache, U32 const* hashTab + /* ZSTD_row_update_internalImpl(): + * Updates the hash table with positions starting from updateStartIdx until updateEndIdx. + */ +-FORCE_INLINE_TEMPLATE void ZSTD_row_update_internalImpl(ZSTD_matchState_t* ms, +- U32 updateStartIdx, U32 const updateEndIdx, +- U32 const mls, U32 const rowLog, +- U32 const rowMask, U32 const useCache) ++FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++void ZSTD_row_update_internalImpl(ZSTD_matchState_t* ms, ++ U32 updateStartIdx, U32 const updateEndIdx, ++ U32 const mls, U32 const rowLog, ++ U32 const rowMask, U32 const useCache) + { + U32* const hashTable = ms->hashTable; +- U16* const tagTable = ms->tagTable; ++ BYTE* const tagTable = ms->tagTable; + U32 const hashLog = ms->rowHashLog; + const BYTE* const base = ms->window.base; + + DEBUGLOG(6, "ZSTD_row_update_internalImpl(): updateStartIdx=%u, updateEndIdx=%u", updateStartIdx, updateEndIdx); + for (; updateStartIdx < updateEndIdx; ++updateStartIdx) { +- U32 const hash = useCache ? 
ZSTD_row_nextCachedHash(ms->hashCache, hashTable, tagTable, base, updateStartIdx, hashLog, rowLog, mls) +- : (U32)ZSTD_hashPtr(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls); ++ U32 const hash = useCache ? ZSTD_row_nextCachedHash(ms->hashCache, hashTable, tagTable, base, updateStartIdx, hashLog, rowLog, mls, ms->hashSalt) ++ : (U32)ZSTD_hashPtrSalted(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, ms->hashSalt); + U32 const relRow = (hash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; + U32* const row = hashTable + relRow; +- BYTE* tagRow = (BYTE*)(tagTable + relRow); /* Though tagTable is laid out as a table of U16, each tag is only 1 byte. +- Explicit cast allows us to get exact desired position within each row */ ++ BYTE* tagRow = tagTable + relRow; + U32 const pos = ZSTD_row_nextIndex(tagRow, rowMask); + +- assert(hash == ZSTD_hashPtr(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls)); +- ((BYTE*)tagRow)[pos + ZSTD_ROW_HASH_TAG_OFFSET] = hash & ZSTD_ROW_HASH_TAG_MASK; ++ assert(hash == ZSTD_hashPtrSalted(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, ms->hashSalt)); ++ tagRow[pos] = hash & ZSTD_ROW_HASH_TAG_MASK; + row[pos] = updateStartIdx; + } + } +@@ -932,9 +912,11 @@ FORCE_INLINE_TEMPLATE void ZSTD_row_update_internalImpl(ZSTD_matchState_t* ms, + * Inserts the byte at ip into the appropriate position in the hash table, and updates ms->nextToUpdate. + * Skips sections of long matches as is necessary. 
+ */ +-FORCE_INLINE_TEMPLATE void ZSTD_row_update_internal(ZSTD_matchState_t* ms, const BYTE* ip, +- U32 const mls, U32 const rowLog, +- U32 const rowMask, U32 const useCache) ++FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++void ZSTD_row_update_internal(ZSTD_matchState_t* ms, const BYTE* ip, ++ U32 const mls, U32 const rowLog, ++ U32 const rowMask, U32 const useCache) + { + U32 idx = ms->nextToUpdate; + const BYTE* const base = ms->window.base; +@@ -971,7 +953,35 @@ void ZSTD_row_update(ZSTD_matchState_t* const ms, const BYTE* ip) { + const U32 mls = MIN(ms->cParams.minMatch, 6 /* mls caps out at 6 */); + + DEBUGLOG(5, "ZSTD_row_update(), rowLog=%u", rowLog); +- ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 0 /* dont use cache */); ++ ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 0 /* don't use cache */); ++} ++ ++/* Returns the mask width of bits group of which will be set to 1. Given not all ++ * architectures have easy movemask instruction, this helps to iterate over ++ * groups of bits easier and faster. ++ */ ++FORCE_INLINE_TEMPLATE U32 ++ZSTD_row_matchMaskGroupWidth(const U32 rowEntries) ++{ ++ assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64); ++ assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES); ++ (void)rowEntries; ++#if defined(ZSTD_ARCH_ARM_NEON) ++ /* NEON path only works for little endian */ ++ if (!MEM_isLittleEndian()) { ++ return 1; ++ } ++ if (rowEntries == 16) { ++ return 4; ++ } ++ if (rowEntries == 32) { ++ return 2; ++ } ++ if (rowEntries == 64) { ++ return 1; ++ } ++#endif ++ return 1; + } + + #if defined(ZSTD_ARCH_X86_SSE2) +@@ -994,71 +1004,82 @@ ZSTD_row_getSSEMask(int nbChunks, const BYTE* const src, const BYTE tag, const U + } + #endif + +-/* Returns a ZSTD_VecMask (U32) that has the nth bit set to 1 if the newly-computed "tag" matches +- * the hash at the nth position in a row of the tagTable. +- * Each row is a circular buffer beginning at the value of "head". 
So we must rotate the "matches" bitfield +- * to match up with the actual layout of the entries within the hashTable */ ++#if defined(ZSTD_ARCH_ARM_NEON) ++FORCE_INLINE_TEMPLATE ZSTD_VecMask ++ZSTD_row_getNEONMask(const U32 rowEntries, const BYTE* const src, const BYTE tag, const U32 headGrouped) ++{ ++ assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64); ++ if (rowEntries == 16) { ++ /* vshrn_n_u16 shifts by 4 every u16 and narrows to 8 lower bits. ++ * After that groups of 4 bits represent the equalMask. We lower ++ * all bits except the highest in these groups by doing AND with ++ * 0x88 = 0b10001000. ++ */ ++ const uint8x16_t chunk = vld1q_u8(src); ++ const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag))); ++ const uint8x8_t res = vshrn_n_u16(equalMask, 4); ++ const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0); ++ return ZSTD_rotateRight_U64(matches, headGrouped) & 0x8888888888888888ull; ++ } else if (rowEntries == 32) { ++ /* Same idea as with rowEntries == 16 but doing AND with ++ * 0x55 = 0b01010101. 
++ */ ++ const uint16x8x2_t chunk = vld2q_u16((const uint16_t*)(const void*)src); ++ const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]); ++ const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]); ++ const uint8x16_t dup = vdupq_n_u8(tag); ++ const uint8x8_t t0 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk0, dup)), 6); ++ const uint8x8_t t1 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk1, dup)), 6); ++ const uint8x8_t res = vsli_n_u8(t0, t1, 4); ++ const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0) ; ++ return ZSTD_rotateRight_U64(matches, headGrouped) & 0x5555555555555555ull; ++ } else { /* rowEntries == 64 */ ++ const uint8x16x4_t chunk = vld4q_u8(src); ++ const uint8x16_t dup = vdupq_n_u8(tag); ++ const uint8x16_t cmp0 = vceqq_u8(chunk.val[0], dup); ++ const uint8x16_t cmp1 = vceqq_u8(chunk.val[1], dup); ++ const uint8x16_t cmp2 = vceqq_u8(chunk.val[2], dup); ++ const uint8x16_t cmp3 = vceqq_u8(chunk.val[3], dup); ++ ++ const uint8x16_t t0 = vsriq_n_u8(cmp1, cmp0, 1); ++ const uint8x16_t t1 = vsriq_n_u8(cmp3, cmp2, 1); ++ const uint8x16_t t2 = vsriq_n_u8(t1, t0, 2); ++ const uint8x16_t t3 = vsriq_n_u8(t2, t2, 4); ++ const uint8x8_t t4 = vshrn_n_u16(vreinterpretq_u16_u8(t3), 4); ++ const U64 matches = vget_lane_u64(vreinterpret_u64_u8(t4), 0); ++ return ZSTD_rotateRight_U64(matches, headGrouped); ++ } ++} ++#endif ++ ++/* Returns a ZSTD_VecMask (U64) that has the nth group (determined by ++ * ZSTD_row_matchMaskGroupWidth) of bits set to 1 if the newly-computed "tag" ++ * matches the hash at the nth position in a row of the tagTable. ++ * Each row is a circular buffer beginning at the value of "headGrouped". 
So we ++ * must rotate the "matches" bitfield to match up with the actual layout of the ++ * entries within the hashTable */ + FORCE_INLINE_TEMPLATE ZSTD_VecMask +-ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, const U32 rowEntries) ++ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 headGrouped, const U32 rowEntries) + { +- const BYTE* const src = tagRow + ZSTD_ROW_HASH_TAG_OFFSET; ++ const BYTE* const src = tagRow; + assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64); + assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES); ++ assert(ZSTD_row_matchMaskGroupWidth(rowEntries) * rowEntries <= sizeof(ZSTD_VecMask) * 8); + + #if defined(ZSTD_ARCH_X86_SSE2) + +- return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, head); ++ return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, headGrouped); + + #else /* SW or NEON-LE */ + + # if defined(ZSTD_ARCH_ARM_NEON) + /* This NEON path only works for little endian - otherwise use SWAR below */ + if (MEM_isLittleEndian()) { +- if (rowEntries == 16) { +- const uint8x16_t chunk = vld1q_u8(src); +- const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag))); +- const uint16x8_t t0 = vshlq_n_u16(equalMask, 7); +- const uint32x4_t t1 = vreinterpretq_u32_u16(vsriq_n_u16(t0, t0, 14)); +- const uint64x2_t t2 = vreinterpretq_u64_u32(vshrq_n_u32(t1, 14)); +- const uint8x16_t t3 = vreinterpretq_u8_u64(vsraq_n_u64(t2, t2, 28)); +- const U16 hi = (U16)vgetq_lane_u8(t3, 8); +- const U16 lo = (U16)vgetq_lane_u8(t3, 0); +- return ZSTD_rotateRight_U16((hi << 8) | lo, head); +- } else if (rowEntries == 32) { +- const uint16x8x2_t chunk = vld2q_u16((const U16*)(const void*)src); +- const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]); +- const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]); +- const uint8x16_t equalMask0 = vceqq_u8(chunk0, vdupq_n_u8(tag)); +- const uint8x16_t equalMask1 = vceqq_u8(chunk1, vdupq_n_u8(tag)); +- const 
int8x8_t pack0 = vqmovn_s16(vreinterpretq_s16_u8(equalMask0)); +- const int8x8_t pack1 = vqmovn_s16(vreinterpretq_s16_u8(equalMask1)); +- const uint8x8_t t0 = vreinterpret_u8_s8(pack0); +- const uint8x8_t t1 = vreinterpret_u8_s8(pack1); +- const uint8x8_t t2 = vsri_n_u8(t1, t0, 2); +- const uint8x8x2_t t3 = vuzp_u8(t2, t0); +- const uint8x8_t t4 = vsri_n_u8(t3.val[1], t3.val[0], 4); +- const U32 matches = vget_lane_u32(vreinterpret_u32_u8(t4), 0); +- return ZSTD_rotateRight_U32(matches, head); +- } else { /* rowEntries == 64 */ +- const uint8x16x4_t chunk = vld4q_u8(src); +- const uint8x16_t dup = vdupq_n_u8(tag); +- const uint8x16_t cmp0 = vceqq_u8(chunk.val[0], dup); +- const uint8x16_t cmp1 = vceqq_u8(chunk.val[1], dup); +- const uint8x16_t cmp2 = vceqq_u8(chunk.val[2], dup); +- const uint8x16_t cmp3 = vceqq_u8(chunk.val[3], dup); +- +- const uint8x16_t t0 = vsriq_n_u8(cmp1, cmp0, 1); +- const uint8x16_t t1 = vsriq_n_u8(cmp3, cmp2, 1); +- const uint8x16_t t2 = vsriq_n_u8(t1, t0, 2); +- const uint8x16_t t3 = vsriq_n_u8(t2, t2, 4); +- const uint8x8_t t4 = vshrn_n_u16(vreinterpretq_u16_u8(t3), 4); +- const U64 matches = vget_lane_u64(vreinterpret_u64_u8(t4), 0); +- return ZSTD_rotateRight_U64(matches, head); +- } ++ return ZSTD_row_getNEONMask(rowEntries, src, tag, headGrouped); + } + # endif /* ZSTD_ARCH_ARM_NEON */ + /* SWAR */ +- { const size_t chunkSize = sizeof(size_t); ++ { const int chunkSize = sizeof(size_t); + const size_t shiftAmount = ((chunkSize * 8) - chunkSize); + const size_t xFF = ~((size_t)0); + const size_t x01 = xFF / 0xFF; +@@ -1091,11 +1112,11 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, + } + matches = ~matches; + if (rowEntries == 16) { +- return ZSTD_rotateRight_U16((U16)matches, head); ++ return ZSTD_rotateRight_U16((U16)matches, headGrouped); + } else if (rowEntries == 32) { +- return ZSTD_rotateRight_U32((U32)matches, head); ++ return ZSTD_rotateRight_U32((U32)matches, headGrouped); + } else { +- 
return ZSTD_rotateRight_U64((U64)matches, head); ++ return ZSTD_rotateRight_U64((U64)matches, headGrouped); + } + } + #endif +@@ -1103,20 +1124,21 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, + + /* The high-level approach of the SIMD row based match finder is as follows: + * - Figure out where to insert the new entry: +- * - Generate a hash from a byte along with an additional 1-byte "short hash". The additional byte is our "tag" +- * - The hashTable is effectively split into groups or "rows" of 16 or 32 entries of U32, and the hash determines ++ * - Generate a hash for current input posistion and split it into a one byte of tag and `rowHashLog` bits of index. ++ * - The hash is salted by a value that changes on every contex reset, so when the same table is used ++ * we will avoid collisions that would otherwise slow us down by intorducing phantom matches. ++ * - The hashTable is effectively split into groups or "rows" of 15 or 31 entries of U32, and the index determines + * which row to insert into. +- * - Determine the correct position within the row to insert the entry into. Each row of 16 or 32 can +- * be considered as a circular buffer with a "head" index that resides in the tagTable. +- * - Also insert the "tag" into the equivalent row and position in the tagTable. +- * - Note: The tagTable has 17 or 33 1-byte entries per row, due to 16 or 32 tags, and 1 "head" entry. +- * The 17 or 33 entry rows are spaced out to occur every 32 or 64 bytes, respectively, +- * for alignment/performance reasons, leaving some bytes unused. +- * - Use SIMD to efficiently compare the tags in the tagTable to the 1-byte "short hash" and ++ * - Determine the correct position within the row to insert the entry into. Each row of 15 or 31 can ++ * be considered as a circular buffer with a "head" index that resides in the tagTable (overall 16 or 32 bytes ++ * per row). 
++ * - Use SIMD to efficiently compare the tags in the tagTable to the 1-byte tag calculated for the position and + * generate a bitfield that we can cycle through to check the collisions in the hash table. + * - Pick the longest match. ++ * - Insert the tag into the equivalent row and position in the tagTable. + */ + FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR + size_t ZSTD_RowFindBestMatch( + ZSTD_matchState_t* ms, + const BYTE* const ip, const BYTE* const iLimit, +@@ -1125,7 +1147,7 @@ size_t ZSTD_RowFindBestMatch( + const U32 rowLog) + { + U32* const hashTable = ms->hashTable; +- U16* const tagTable = ms->tagTable; ++ BYTE* const tagTable = ms->tagTable; + U32* const hashCache = ms->hashCache; + const U32 hashLog = ms->rowHashLog; + const ZSTD_compressionParameters* const cParams = &ms->cParams; +@@ -1143,8 +1165,11 @@ size_t ZSTD_RowFindBestMatch( + const U32 rowEntries = (1U << rowLog); + const U32 rowMask = rowEntries - 1; + const U32 cappedSearchLog = MIN(cParams->searchLog, rowLog); /* nb of searches is capped at nb entries per row */ ++ const U32 groupWidth = ZSTD_row_matchMaskGroupWidth(rowEntries); ++ const U64 hashSalt = ms->hashSalt; + U32 nbAttempts = 1U << cappedSearchLog; + size_t ml=4-1; ++ U32 hash; + + /* DMS/DDS variables that may be referenced laster */ + const ZSTD_matchState_t* const dms = ms->dictMatchState; +@@ -1168,7 +1193,7 @@ size_t ZSTD_RowFindBestMatch( + if (dictMode == ZSTD_dictMatchState) { + /* Prefetch DMS rows */ + U32* const dmsHashTable = dms->hashTable; +- U16* const dmsTagTable = dms->tagTable; ++ BYTE* const dmsTagTable = dms->tagTable; + U32 const dmsHash = (U32)ZSTD_hashPtr(ip, dms->rowHashLog + ZSTD_ROW_HASH_TAG_BITS, mls); + U32 const dmsRelRow = (dmsHash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; + dmsTag = dmsHash & ZSTD_ROW_HASH_TAG_MASK; +@@ -1178,23 +1203,34 @@ size_t ZSTD_RowFindBestMatch( + } + + /* Update the hashTable and tagTable up to (but not including) ip */ +- ZSTD_row_update_internal(ms, ip, mls, 
rowLog, rowMask, 1 /* useCache */); ++ if (!ms->lazySkipping) { ++ ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 1 /* useCache */); ++ hash = ZSTD_row_nextCachedHash(hashCache, hashTable, tagTable, base, curr, hashLog, rowLog, mls, hashSalt); ++ } else { ++ /* Stop inserting every position when in the lazy skipping mode. ++ * The hash cache is also not kept up to date in this mode. ++ */ ++ hash = (U32)ZSTD_hashPtrSalted(ip, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, hashSalt); ++ ms->nextToUpdate = curr; ++ } ++ ms->hashSaltEntropy += hash; /* collect salt entropy */ ++ + { /* Get the hash for ip, compute the appropriate row */ +- U32 const hash = ZSTD_row_nextCachedHash(hashCache, hashTable, tagTable, base, curr, hashLog, rowLog, mls); + U32 const relRow = (hash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; + U32 const tag = hash & ZSTD_ROW_HASH_TAG_MASK; + U32* const row = hashTable + relRow; + BYTE* tagRow = (BYTE*)(tagTable + relRow); +- U32 const head = *tagRow & rowMask; ++ U32 const headGrouped = (*tagRow & rowMask) * groupWidth; + U32 matchBuffer[ZSTD_ROW_HASH_MAX_ENTRIES]; + size_t numMatches = 0; + size_t currMatch = 0; +- ZSTD_VecMask matches = ZSTD_row_getMatchMask(tagRow, (BYTE)tag, head, rowEntries); ++ ZSTD_VecMask matches = ZSTD_row_getMatchMask(tagRow, (BYTE)tag, headGrouped, rowEntries); + + /* Cycle through the matches and prefetch */ +- for (; (matches > 0) && (nbAttempts > 0); --nbAttempts, matches &= (matches - 1)) { +- U32 const matchPos = (head + ZSTD_VecMask_next(matches)) & rowMask; ++ for (; (matches > 0) && (nbAttempts > 0); matches &= (matches - 1)) { ++ U32 const matchPos = ((headGrouped + ZSTD_VecMask_next(matches)) / groupWidth) & rowMask; + U32 const matchIndex = row[matchPos]; ++ if(matchPos == 0) continue; + assert(numMatches < rowEntries); + if (matchIndex < lowLimit) + break; +@@ -1204,13 +1240,14 @@ size_t ZSTD_RowFindBestMatch( + PREFETCH_L1(dictBase + matchIndex); + } + matchBuffer[numMatches++] = matchIndex; ++ --nbAttempts; + 
} + + /* Speed opt: insert current byte into hashtable too. This allows us to avoid one iteration of the loop + in ZSTD_row_update_internal() at the next search. */ + { + U32 const pos = ZSTD_row_nextIndex(tagRow, rowMask); +- tagRow[pos + ZSTD_ROW_HASH_TAG_OFFSET] = (BYTE)tag; ++ tagRow[pos] = (BYTE)tag; + row[pos] = ms->nextToUpdate++; + } + +@@ -1224,7 +1261,8 @@ size_t ZSTD_RowFindBestMatch( + if ((dictMode != ZSTD_extDict) || matchIndex >= dictLimit) { + const BYTE* const match = base + matchIndex; + assert(matchIndex >= dictLimit); /* ensures this is true if dictMode != ZSTD_extDict */ +- if (match[ml] == ip[ml]) /* potentially better */ ++ /* read 4B starting from (match + ml + 1 - sizeof(U32)) */ ++ if (MEM_read32(match + ml - 3) == MEM_read32(ip + ml - 3)) /* potentially better */ + currentMl = ZSTD_count(ip, match, iLimit); + } else { + const BYTE* const match = dictBase + matchIndex; +@@ -1236,7 +1274,7 @@ size_t ZSTD_RowFindBestMatch( + /* Save best solution */ + if (currentMl > ml) { + ml = currentMl; +- *offsetPtr = STORE_OFFSET(curr - matchIndex); ++ *offsetPtr = OFFSET_TO_OFFBASE(curr - matchIndex); + if (ip+currentMl == iLimit) break; /* best possible, avoids read overflow on next attempt */ + } + } +@@ -1254,19 +1292,21 @@ size_t ZSTD_RowFindBestMatch( + const U32 dmsSize = (U32)(dmsEnd - dmsBase); + const U32 dmsIndexDelta = dictLimit - dmsSize; + +- { U32 const head = *dmsTagRow & rowMask; ++ { U32 const headGrouped = (*dmsTagRow & rowMask) * groupWidth; + U32 matchBuffer[ZSTD_ROW_HASH_MAX_ENTRIES]; + size_t numMatches = 0; + size_t currMatch = 0; +- ZSTD_VecMask matches = ZSTD_row_getMatchMask(dmsTagRow, (BYTE)dmsTag, head, rowEntries); ++ ZSTD_VecMask matches = ZSTD_row_getMatchMask(dmsTagRow, (BYTE)dmsTag, headGrouped, rowEntries); + +- for (; (matches > 0) && (nbAttempts > 0); --nbAttempts, matches &= (matches - 1)) { +- U32 const matchPos = (head + ZSTD_VecMask_next(matches)) & rowMask; ++ for (; (matches > 0) && (nbAttempts > 0); matches 
&= (matches - 1)) { ++ U32 const matchPos = ((headGrouped + ZSTD_VecMask_next(matches)) / groupWidth) & rowMask; + U32 const matchIndex = dmsRow[matchPos]; ++ if(matchPos == 0) continue; + if (matchIndex < dmsLowestIndex) + break; + PREFETCH_L1(dmsBase + matchIndex); + matchBuffer[numMatches++] = matchIndex; ++ --nbAttempts; + } + + /* Return the longest match */ +@@ -1285,7 +1325,7 @@ size_t ZSTD_RowFindBestMatch( + if (currentMl > ml) { + ml = currentMl; + assert(curr > matchIndex + dmsIndexDelta); +- *offsetPtr = STORE_OFFSET(curr - (matchIndex + dmsIndexDelta)); ++ *offsetPtr = OFFSET_TO_OFFBASE(curr - (matchIndex + dmsIndexDelta)); + if (ip+currentMl == iLimit) break; + } + } +@@ -1472,8 +1512,9 @@ FORCE_INLINE_TEMPLATE size_t ZSTD_searchMax( + * Common parser - lazy strategy + *********************************/ + +-FORCE_INLINE_TEMPLATE size_t +-ZSTD_compressBlock_lazy_generic( ++FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++size_t ZSTD_compressBlock_lazy_generic( + ZSTD_matchState_t* ms, seqStore_t* seqStore, + U32 rep[ZSTD_REP_NUM], + const void* src, size_t srcSize, +@@ -1491,7 +1532,8 @@ ZSTD_compressBlock_lazy_generic( + const U32 mls = BOUNDED(4, ms->cParams.minMatch, 6); + const U32 rowLog = BOUNDED(4, ms->cParams.searchLog, 6); + +- U32 offset_1 = rep[0], offset_2 = rep[1], savedOffset=0; ++ U32 offset_1 = rep[0], offset_2 = rep[1]; ++ U32 offsetSaved1 = 0, offsetSaved2 = 0; + + const int isDMS = dictMode == ZSTD_dictMatchState; + const int isDDS = dictMode == ZSTD_dedicatedDictSearch; +@@ -1512,8 +1554,8 @@ ZSTD_compressBlock_lazy_generic( + U32 const curr = (U32)(ip - base); + U32 const windowLow = ZSTD_getLowestPrefixIndex(ms, curr, ms->cParams.windowLog); + U32 const maxRep = curr - windowLow; +- if (offset_2 > maxRep) savedOffset = offset_2, offset_2 = 0; +- if (offset_1 > maxRep) savedOffset = offset_1, offset_1 = 0; ++ if (offset_2 > maxRep) offsetSaved2 = offset_2, offset_2 = 0; ++ if (offset_1 > maxRep) offsetSaved1 = offset_1, 
offset_1 = 0; + } + if (isDxS) { + /* dictMatchState repCode checks don't currently handle repCode == 0 +@@ -1522,10 +1564,11 @@ ZSTD_compressBlock_lazy_generic( + assert(offset_2 <= dictAndPrefixLength); + } + ++ /* Reset the lazy skipping state */ ++ ms->lazySkipping = 0; ++ + if (searchMethod == search_rowHash) { +- ZSTD_row_fillHashCache(ms, base, rowLog, +- MIN(ms->cParams.minMatch, 6 /* mls caps out at 6 */), +- ms->nextToUpdate, ilimit); ++ ZSTD_row_fillHashCache(ms, base, rowLog, mls, ms->nextToUpdate, ilimit); + } + + /* Match Loop */ +@@ -1537,7 +1580,7 @@ ZSTD_compressBlock_lazy_generic( + #endif + while (ip < ilimit) { + size_t matchLength=0; +- size_t offcode=STORE_REPCODE_1; ++ size_t offBase = REPCODE1_TO_OFFBASE; + const BYTE* start=ip+1; + DEBUGLOG(7, "search baseline (depth 0)"); + +@@ -1562,14 +1605,23 @@ ZSTD_compressBlock_lazy_generic( + } + + /* first search (depth 0) */ +- { size_t offsetFound = 999999999; +- size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offsetFound, mls, rowLog, searchMethod, dictMode); ++ { size_t offbaseFound = 999999999; ++ size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offbaseFound, mls, rowLog, searchMethod, dictMode); + if (ml2 > matchLength) +- matchLength = ml2, start = ip, offcode=offsetFound; ++ matchLength = ml2, start = ip, offBase = offbaseFound; + } + + if (matchLength < 4) { +- ip += ((ip-anchor) >> kSearchStrength) + 1; /* jump faster over incompressible sections */ ++ size_t const step = ((size_t)(ip-anchor) >> kSearchStrength) + 1; /* jump faster over incompressible sections */; ++ ip += step; ++ /* Enter the lazy skipping mode once we are skipping more than 8 bytes at a time. ++ * In this mode we stop inserting every position into our tables, and only insert ++ * positions that we search, which is one in step positions. ++ * The exact cutoff is flexible, I've just chosen a number that is reasonably high, ++ * so we minimize the compression ratio loss in "normal" scenarios. 
This mode gets ++ * triggered once we've gone 2KB without finding any matches. ++ */ ++ ms->lazySkipping = step > kLazySkippingStep; + continue; + } + +@@ -1579,12 +1631,12 @@ ZSTD_compressBlock_lazy_generic( + DEBUGLOG(7, "search depth 1"); + ip ++; + if ( (dictMode == ZSTD_noDict) +- && (offcode) && ((offset_1>0) & (MEM_read32(ip) == MEM_read32(ip - offset_1)))) { ++ && (offBase) && ((offset_1>0) & (MEM_read32(ip) == MEM_read32(ip - offset_1)))) { + size_t const mlRep = ZSTD_count(ip+4, ip+4-offset_1, iend) + 4; + int const gain2 = (int)(mlRep * 3); +- int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); ++ int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)offBase) + 1); + if ((mlRep >= 4) && (gain2 > gain1)) +- matchLength = mlRep, offcode = STORE_REPCODE_1, start = ip; ++ matchLength = mlRep, offBase = REPCODE1_TO_OFFBASE, start = ip; + } + if (isDxS) { + const U32 repIndex = (U32)(ip - base) - offset_1; +@@ -1596,17 +1648,17 @@ ZSTD_compressBlock_lazy_generic( + const BYTE* repMatchEnd = repIndex < prefixLowestIndex ? 
dictEnd : iend; + size_t const mlRep = ZSTD_count_2segments(ip+4, repMatch+4, iend, repMatchEnd, prefixLowest) + 4; + int const gain2 = (int)(mlRep * 3); +- int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); ++ int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)offBase) + 1); + if ((mlRep >= 4) && (gain2 > gain1)) +- matchLength = mlRep, offcode = STORE_REPCODE_1, start = ip; ++ matchLength = mlRep, offBase = REPCODE1_TO_OFFBASE, start = ip; + } + } +- { size_t offset2=999999999; +- size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offset2, mls, rowLog, searchMethod, dictMode); +- int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offset2))); /* raw approx */ +- int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 4); ++ { size_t ofbCandidate=999999999; ++ size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, dictMode); ++ int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)ofbCandidate)); /* raw approx */ ++ int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 4); + if ((ml2 >= 4) && (gain2 > gain1)) { +- matchLength = ml2, offcode = offset2, start = ip; ++ matchLength = ml2, offBase = ofbCandidate, start = ip; + continue; /* search a better one */ + } } + +@@ -1615,12 +1667,12 @@ ZSTD_compressBlock_lazy_generic( + DEBUGLOG(7, "search depth 2"); + ip ++; + if ( (dictMode == ZSTD_noDict) +- && (offcode) && ((offset_1>0) & (MEM_read32(ip) == MEM_read32(ip - offset_1)))) { ++ && (offBase) && ((offset_1>0) & (MEM_read32(ip) == MEM_read32(ip - offset_1)))) { + size_t const mlRep = ZSTD_count(ip+4, ip+4-offset_1, iend) + 4; + int const gain2 = (int)(mlRep * 4); +- int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); ++ int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 1); + if ((mlRep >= 4) && (gain2 > gain1)) +- matchLength = mlRep, offcode = 
STORE_REPCODE_1, start = ip; ++ matchLength = mlRep, offBase = REPCODE1_TO_OFFBASE, start = ip; + } + if (isDxS) { + const U32 repIndex = (U32)(ip - base) - offset_1; +@@ -1632,17 +1684,17 @@ ZSTD_compressBlock_lazy_generic( + const BYTE* repMatchEnd = repIndex < prefixLowestIndex ? dictEnd : iend; + size_t const mlRep = ZSTD_count_2segments(ip+4, repMatch+4, iend, repMatchEnd, prefixLowest) + 4; + int const gain2 = (int)(mlRep * 4); +- int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); ++ int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 1); + if ((mlRep >= 4) && (gain2 > gain1)) +- matchLength = mlRep, offcode = STORE_REPCODE_1, start = ip; ++ matchLength = mlRep, offBase = REPCODE1_TO_OFFBASE, start = ip; + } + } +- { size_t offset2=999999999; +- size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offset2, mls, rowLog, searchMethod, dictMode); +- int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offset2))); /* raw approx */ +- int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 7); ++ { size_t ofbCandidate=999999999; ++ size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, dictMode); ++ int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)ofbCandidate)); /* raw approx */ ++ int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 7); + if ((ml2 >= 4) && (gain2 > gain1)) { +- matchLength = ml2, offcode = offset2, start = ip; ++ matchLength = ml2, offBase = ofbCandidate, start = ip; + continue; + } } } + break; /* nothing found : store previous solution */ +@@ -1653,26 +1705,33 @@ ZSTD_compressBlock_lazy_generic( + * notably if `value` is unsigned, resulting in a large positive `-value`. 
+ */ + /* catch up */ +- if (STORED_IS_OFFSET(offcode)) { ++ if (OFFBASE_IS_OFFSET(offBase)) { + if (dictMode == ZSTD_noDict) { +- while ( ((start > anchor) & (start - STORED_OFFSET(offcode) > prefixLowest)) +- && (start[-1] == (start-STORED_OFFSET(offcode))[-1]) ) /* only search for offset within prefix */ ++ while ( ((start > anchor) & (start - OFFBASE_TO_OFFSET(offBase) > prefixLowest)) ++ && (start[-1] == (start-OFFBASE_TO_OFFSET(offBase))[-1]) ) /* only search for offset within prefix */ + { start--; matchLength++; } + } + if (isDxS) { +- U32 const matchIndex = (U32)((size_t)(start-base) - STORED_OFFSET(offcode)); ++ U32 const matchIndex = (U32)((size_t)(start-base) - OFFBASE_TO_OFFSET(offBase)); + const BYTE* match = (matchIndex < prefixLowestIndex) ? dictBase + matchIndex - dictIndexDelta : base + matchIndex; + const BYTE* const mStart = (matchIndex < prefixLowestIndex) ? dictLowest : prefixLowest; + while ((start>anchor) && (match>mStart) && (start[-1] == match[-1])) { start--; match--; matchLength++; } /* catch up */ + } +- offset_2 = offset_1; offset_1 = (U32)STORED_OFFSET(offcode); ++ offset_2 = offset_1; offset_1 = (U32)OFFBASE_TO_OFFSET(offBase); + } + /* store sequence */ + _storeSequence: + { size_t const litLength = (size_t)(start - anchor); +- ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offcode, matchLength); ++ ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offBase, matchLength); + anchor = ip = start + matchLength; + } ++ if (ms->lazySkipping) { ++ /* We've found a match, disable lazy skipping mode, and refill the hash cache. */ ++ if (searchMethod == search_rowHash) { ++ ZSTD_row_fillHashCache(ms, base, rowLog, mls, ms->nextToUpdate, ilimit); ++ } ++ ms->lazySkipping = 0; ++ } + + /* check immediate repcode */ + if (isDxS) { +@@ -1686,8 +1745,8 @@ ZSTD_compressBlock_lazy_generic( + && (MEM_read32(repMatch) == MEM_read32(ip)) ) { + const BYTE* const repEnd2 = repIndex < prefixLowestIndex ? 
dictEnd : iend; + matchLength = ZSTD_count_2segments(ip+4, repMatch+4, iend, repEnd2, prefixLowest) + 4; +- offcode = offset_2; offset_2 = offset_1; offset_1 = (U32)offcode; /* swap offset_2 <=> offset_1 */ +- ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, matchLength); ++ offBase = offset_2; offset_2 = offset_1; offset_1 = (U32)offBase; /* swap offset_2 <=> offset_1 */ ++ ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, matchLength); + ip += matchLength; + anchor = ip; + continue; +@@ -1701,166 +1760,181 @@ ZSTD_compressBlock_lazy_generic( + && (MEM_read32(ip) == MEM_read32(ip - offset_2)) ) { + /* store sequence */ + matchLength = ZSTD_count(ip+4, ip+4-offset_2, iend) + 4; +- offcode = offset_2; offset_2 = offset_1; offset_1 = (U32)offcode; /* swap repcodes */ +- ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, matchLength); ++ offBase = offset_2; offset_2 = offset_1; offset_1 = (U32)offBase; /* swap repcodes */ ++ ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, matchLength); + ip += matchLength; + anchor = ip; + continue; /* faster when present ... (?) */ + } } } + +- /* Save reps for next block */ +- rep[0] = offset_1 ? offset_1 : savedOffset; +- rep[1] = offset_2 ? offset_2 : savedOffset; ++ /* If offset_1 started invalid (offsetSaved1 != 0) and became valid (offset_1 != 0), ++ * rotate saved offsets. See comment in ZSTD_compressBlock_fast_noDict for more context. */ ++ offsetSaved2 = ((offsetSaved1 != 0) && (offset_1 != 0)) ? offsetSaved1 : offsetSaved2; ++ ++ /* save reps for next block */ ++ rep[0] = offset_1 ? offset_1 : offsetSaved1; ++ rep[1] = offset_2 ? 
offset_2 : offsetSaved2; + + /* Return the last literals size */ + return (size_t)(iend - anchor); + } ++#endif /* build exclusions */ + + +-size_t ZSTD_compressBlock_btlazy2( ++#ifndef ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR ++size_t ZSTD_compressBlock_greedy( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2, ZSTD_noDict); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_noDict); + } + +-size_t ZSTD_compressBlock_lazy2( ++size_t ZSTD_compressBlock_greedy_dictMatchState( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_noDict); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_dictMatchState); + } + +-size_t ZSTD_compressBlock_lazy( ++size_t ZSTD_compressBlock_greedy_dedicatedDictSearch( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_noDict); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_dedicatedDictSearch); + } + +-size_t ZSTD_compressBlock_greedy( ++size_t ZSTD_compressBlock_greedy_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_noDict); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_noDict); + } + +-size_t ZSTD_compressBlock_btlazy2_dictMatchState( ++size_t ZSTD_compressBlock_greedy_dictMatchState_row( + ZSTD_matchState_t* 
ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2, ZSTD_dictMatchState); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_dictMatchState); + } + +-size_t ZSTD_compressBlock_lazy2_dictMatchState( ++size_t ZSTD_compressBlock_greedy_dedicatedDictSearch_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_dictMatchState); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_dedicatedDictSearch); + } ++#endif + +-size_t ZSTD_compressBlock_lazy_dictMatchState( ++#ifndef ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR ++size_t ZSTD_compressBlock_lazy( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_dictMatchState); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_noDict); + } + +-size_t ZSTD_compressBlock_greedy_dictMatchState( ++size_t ZSTD_compressBlock_lazy_dictMatchState( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_dictMatchState); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_dictMatchState); + } + +- +-size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch( ++size_t ZSTD_compressBlock_lazy_dedicatedDictSearch( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, 
seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_dedicatedDictSearch); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_dedicatedDictSearch); + } + +-size_t ZSTD_compressBlock_lazy_dedicatedDictSearch( ++size_t ZSTD_compressBlock_lazy_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_dedicatedDictSearch); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_noDict); + } + +-size_t ZSTD_compressBlock_greedy_dedicatedDictSearch( ++size_t ZSTD_compressBlock_lazy_dictMatchState_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_dedicatedDictSearch); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_dictMatchState); + } + +-/* Row-based matchfinder */ +-size_t ZSTD_compressBlock_lazy2_row( ++size_t ZSTD_compressBlock_lazy_dedicatedDictSearch_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_noDict); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_dedicatedDictSearch); + } ++#endif + +-size_t ZSTD_compressBlock_lazy_row( ++#ifndef ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR ++size_t ZSTD_compressBlock_lazy2( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_noDict); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, 
search_hashChain, 2, ZSTD_noDict); + } + +-size_t ZSTD_compressBlock_greedy_row( ++size_t ZSTD_compressBlock_lazy2_dictMatchState( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_noDict); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_dictMatchState); + } + +-size_t ZSTD_compressBlock_lazy2_dictMatchState_row( ++size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_dictMatchState); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_dedicatedDictSearch); + } + +-size_t ZSTD_compressBlock_lazy_dictMatchState_row( ++size_t ZSTD_compressBlock_lazy2_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_dictMatchState); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_noDict); + } + +-size_t ZSTD_compressBlock_greedy_dictMatchState_row( ++size_t ZSTD_compressBlock_lazy2_dictMatchState_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_dictMatchState); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_dictMatchState); + } + +- + size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { + 
return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_dedicatedDictSearch); + } ++#endif + +-size_t ZSTD_compressBlock_lazy_dedicatedDictSearch_row( ++#ifndef ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR ++size_t ZSTD_compressBlock_btlazy2( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_dedicatedDictSearch); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2, ZSTD_noDict); + } + +-size_t ZSTD_compressBlock_greedy_dedicatedDictSearch_row( ++size_t ZSTD_compressBlock_btlazy2_dictMatchState( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + { +- return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_dedicatedDictSearch); ++ return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2, ZSTD_dictMatchState); + } ++#endif + ++#if !defined(ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) + FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR + size_t ZSTD_compressBlock_lazy_extDict_generic( + ZSTD_matchState_t* ms, seqStore_t* seqStore, + U32 rep[ZSTD_REP_NUM], +@@ -1886,12 +1960,13 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( + + DEBUGLOG(5, "ZSTD_compressBlock_lazy_extDict_generic (searchFunc=%u)", (U32)searchMethod); + ++ /* Reset the lazy skipping state */ ++ ms->lazySkipping = 0; ++ + /* init */ + ip += (ip == prefixStart); + if (searchMethod == search_rowHash) { +- ZSTD_row_fillHashCache(ms, base, rowLog, +- MIN(ms->cParams.minMatch, 6 /* mls caps out at 6 */), +- ms->nextToUpdate, ilimit); ++ ZSTD_row_fillHashCache(ms, base, rowLog, mls, 
ms->nextToUpdate, ilimit); + } + + /* Match Loop */ +@@ -1903,7 +1978,7 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( + #endif + while (ip < ilimit) { + size_t matchLength=0; +- size_t offcode=STORE_REPCODE_1; ++ size_t offBase = REPCODE1_TO_OFFBASE; + const BYTE* start=ip+1; + U32 curr = (U32)(ip-base); + +@@ -1922,14 +1997,23 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( + } } + + /* first search (depth 0) */ +- { size_t offsetFound = 999999999; +- size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offsetFound, mls, rowLog, searchMethod, ZSTD_extDict); ++ { size_t ofbCandidate = 999999999; ++ size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, ZSTD_extDict); + if (ml2 > matchLength) +- matchLength = ml2, start = ip, offcode=offsetFound; ++ matchLength = ml2, start = ip, offBase = ofbCandidate; + } + + if (matchLength < 4) { +- ip += ((ip-anchor) >> kSearchStrength) + 1; /* jump faster over incompressible sections */ ++ size_t const step = ((size_t)(ip-anchor) >> kSearchStrength); ++ ip += step + 1; /* jump faster over incompressible sections */ ++ /* Enter the lazy skipping mode once we are skipping more than 8 bytes at a time. ++ * In this mode we stop inserting every position into our tables, and only insert ++ * positions that we search, which is one in step positions. ++ * The exact cutoff is flexible, I've just chosen a number that is reasonably high, ++ * so we minimize the compression ratio loss in "normal" scenarios. This mode gets ++ * triggered once we've gone 2KB without finding any matches. ++ */ ++ ms->lazySkipping = step > kLazySkippingStep; + continue; + } + +@@ -1939,7 +2023,7 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( + ip ++; + curr++; + /* check repCode */ +- if (offcode) { ++ if (offBase) { + const U32 windowLow = ZSTD_getLowestMatchIndex(ms, curr, windowLog); + const U32 repIndex = (U32)(curr - offset_1); + const BYTE* const repBase = repIndex < dictLimit ? 
dictBase : base; +@@ -1951,18 +2035,18 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( + const BYTE* const repEnd = repIndex < dictLimit ? dictEnd : iend; + size_t const repLength = ZSTD_count_2segments(ip+4, repMatch+4, iend, repEnd, prefixStart) + 4; + int const gain2 = (int)(repLength * 3); +- int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); ++ int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)offBase) + 1); + if ((repLength >= 4) && (gain2 > gain1)) +- matchLength = repLength, offcode = STORE_REPCODE_1, start = ip; ++ matchLength = repLength, offBase = REPCODE1_TO_OFFBASE, start = ip; + } } + + /* search match, depth 1 */ +- { size_t offset2=999999999; +- size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offset2, mls, rowLog, searchMethod, ZSTD_extDict); +- int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offset2))); /* raw approx */ +- int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 4); ++ { size_t ofbCandidate = 999999999; ++ size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, ZSTD_extDict); ++ int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)ofbCandidate)); /* raw approx */ ++ int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 4); + if ((ml2 >= 4) && (gain2 > gain1)) { +- matchLength = ml2, offcode = offset2, start = ip; ++ matchLength = ml2, offBase = ofbCandidate, start = ip; + continue; /* search a better one */ + } } + +@@ -1971,7 +2055,7 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( + ip ++; + curr++; + /* check repCode */ +- if (offcode) { ++ if (offBase) { + const U32 windowLow = ZSTD_getLowestMatchIndex(ms, curr, windowLog); + const U32 repIndex = (U32)(curr - offset_1); + const BYTE* const repBase = repIndex < dictLimit ? dictBase : base; +@@ -1983,38 +2067,45 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( + const BYTE* const repEnd = repIndex < dictLimit ? 
dictEnd : iend; + size_t const repLength = ZSTD_count_2segments(ip+4, repMatch+4, iend, repEnd, prefixStart) + 4; + int const gain2 = (int)(repLength * 4); +- int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); ++ int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 1); + if ((repLength >= 4) && (gain2 > gain1)) +- matchLength = repLength, offcode = STORE_REPCODE_1, start = ip; ++ matchLength = repLength, offBase = REPCODE1_TO_OFFBASE, start = ip; + } } + + /* search match, depth 2 */ +- { size_t offset2=999999999; +- size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offset2, mls, rowLog, searchMethod, ZSTD_extDict); +- int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offset2))); /* raw approx */ +- int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 7); ++ { size_t ofbCandidate = 999999999; ++ size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, ZSTD_extDict); ++ int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)ofbCandidate)); /* raw approx */ ++ int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 7); + if ((ml2 >= 4) && (gain2 > gain1)) { +- matchLength = ml2, offcode = offset2, start = ip; ++ matchLength = ml2, offBase = ofbCandidate, start = ip; + continue; + } } } + break; /* nothing found : store previous solution */ + } + + /* catch up */ +- if (STORED_IS_OFFSET(offcode)) { +- U32 const matchIndex = (U32)((size_t)(start-base) - STORED_OFFSET(offcode)); ++ if (OFFBASE_IS_OFFSET(offBase)) { ++ U32 const matchIndex = (U32)((size_t)(start-base) - OFFBASE_TO_OFFSET(offBase)); + const BYTE* match = (matchIndex < dictLimit) ? dictBase + matchIndex : base + matchIndex; + const BYTE* const mStart = (matchIndex < dictLimit) ? 
dictStart : prefixStart; + while ((start>anchor) && (match>mStart) && (start[-1] == match[-1])) { start--; match--; matchLength++; } /* catch up */ +- offset_2 = offset_1; offset_1 = (U32)STORED_OFFSET(offcode); ++ offset_2 = offset_1; offset_1 = (U32)OFFBASE_TO_OFFSET(offBase); + } + + /* store sequence */ + _storeSequence: + { size_t const litLength = (size_t)(start - anchor); +- ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offcode, matchLength); ++ ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offBase, matchLength); + anchor = ip = start + matchLength; + } ++ if (ms->lazySkipping) { ++ /* We've found a match, disable lazy skipping mode, and refill the hash cache. */ ++ if (searchMethod == search_rowHash) { ++ ZSTD_row_fillHashCache(ms, base, rowLog, mls, ms->nextToUpdate, ilimit); ++ } ++ ms->lazySkipping = 0; ++ } + + /* check immediate repcode */ + while (ip <= ilimit) { +@@ -2029,8 +2120,8 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( + /* repcode detected we should take it */ + const BYTE* const repEnd = repIndex < dictLimit ? dictEnd : iend; + matchLength = ZSTD_count_2segments(ip+4, repMatch+4, iend, repEnd, prefixStart) + 4; +- offcode = offset_2; offset_2 = offset_1; offset_1 = (U32)offcode; /* swap offset history */ +- ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, matchLength); ++ offBase = offset_2; offset_2 = offset_1; offset_1 = (U32)offBase; /* swap offset history */ ++ ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, matchLength); + ip += matchLength; + anchor = ip; + continue; /* faster when present ... (?) 
*/ +@@ -2045,8 +2136,9 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( + /* Return the last literals size */ + return (size_t)(iend - anchor); + } ++#endif /* build exclusions */ + +- ++#ifndef ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR + size_t ZSTD_compressBlock_greedy_extDict( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) +@@ -2054,49 +2146,55 @@ size_t ZSTD_compressBlock_greedy_extDict( + return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0); + } + +-size_t ZSTD_compressBlock_lazy_extDict( ++size_t ZSTD_compressBlock_greedy_extDict_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) +- + { +- return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1); ++ return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0); + } ++#endif + +-size_t ZSTD_compressBlock_lazy2_extDict( ++#ifndef ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR ++size_t ZSTD_compressBlock_lazy_extDict( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + + { +- return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2); ++ return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1); + } + +-size_t ZSTD_compressBlock_btlazy2_extDict( ++size_t ZSTD_compressBlock_lazy_extDict_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + + { +- return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2); ++ return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1); + } ++#endif + +-size_t ZSTD_compressBlock_greedy_extDict_row( ++#ifndef ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR ++size_t 
ZSTD_compressBlock_lazy2_extDict( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) ++ + { +- return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0); ++ return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2); + } + +-size_t ZSTD_compressBlock_lazy_extDict_row( ++size_t ZSTD_compressBlock_lazy2_extDict_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) +- + { +- return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1); ++ return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2); + } ++#endif + +-size_t ZSTD_compressBlock_lazy2_extDict_row( ++#ifndef ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR ++size_t ZSTD_compressBlock_btlazy2_extDict( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize) + + { +- return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2); ++ return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2); + } ++#endif +diff --git a/lib/zstd/compress/zstd_lazy.h b/lib/zstd/compress/zstd_lazy.h +index e5bdf4df8dde..22c9201f4e63 100644 +--- a/lib/zstd/compress/zstd_lazy.h ++++ b/lib/zstd/compress/zstd_lazy.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -22,98 +23,175 @@ + */ + #define ZSTD_LAZY_DDSS_BUCKET_LOG 2 + ++#define ZSTD_ROW_HASH_TAG_BITS 8 /* nb bits to use for the tag */ ++ ++#if !defined(ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) + U32 ZSTD_insertAndFindFirstIndex(ZSTD_matchState_t* ms, const BYTE* ip); + void ZSTD_row_update(ZSTD_matchState_t* const ms, const BYTE* ip); + + void ZSTD_dedicatedDictSearch_lazy_loadDictionary(ZSTD_matchState_t* ms, const BYTE* const ip); + + void ZSTD_preserveUnsortedMark (U32* const table, U32 const size, U32 const reducerValue); /*! used in ZSTD_reduceIndex(). preemptively increase value of ZSTD_DUBT_UNSORTED_MARK */ ++#endif + +-size_t ZSTD_compressBlock_btlazy2( ++#ifndef ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR ++size_t ZSTD_compressBlock_greedy( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy2( ++size_t ZSTD_compressBlock_greedy_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy( ++size_t ZSTD_compressBlock_greedy_dictMatchState( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_greedy( ++size_t ZSTD_compressBlock_greedy_dictMatchState_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy2_row( ++size_t ZSTD_compressBlock_greedy_dedicatedDictSearch( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy_row( ++size_t ZSTD_compressBlock_greedy_dedicatedDictSearch_row( + ZSTD_matchState_t* ms, 
seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_greedy_row( ++size_t ZSTD_compressBlock_greedy_extDict( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +- +-size_t ZSTD_compressBlock_btlazy2_dictMatchState( ++size_t ZSTD_compressBlock_greedy_extDict_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy2_dictMatchState( ++ ++#define ZSTD_COMPRESSBLOCK_GREEDY ZSTD_compressBlock_greedy ++#define ZSTD_COMPRESSBLOCK_GREEDY_ROW ZSTD_compressBlock_greedy_row ++#define ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE ZSTD_compressBlock_greedy_dictMatchState ++#define ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE_ROW ZSTD_compressBlock_greedy_dictMatchState_row ++#define ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH ZSTD_compressBlock_greedy_dedicatedDictSearch ++#define ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH_ROW ZSTD_compressBlock_greedy_dedicatedDictSearch_row ++#define ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT ZSTD_compressBlock_greedy_extDict ++#define ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT_ROW ZSTD_compressBlock_greedy_extDict_row ++#else ++#define ZSTD_COMPRESSBLOCK_GREEDY NULL ++#define ZSTD_COMPRESSBLOCK_GREEDY_ROW NULL ++#define ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE NULL ++#define ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE_ROW NULL ++#define ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH NULL ++#define ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH_ROW NULL ++#define ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT NULL ++#define ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT_ROW NULL ++#endif ++ ++#ifndef ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR ++size_t ZSTD_compressBlock_lazy( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy_dictMatchState( ++size_t ZSTD_compressBlock_lazy_row( + ZSTD_matchState_t* ms, seqStore_t* 
seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_greedy_dictMatchState( ++size_t ZSTD_compressBlock_lazy_dictMatchState( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy2_dictMatchState_row( ++size_t ZSTD_compressBlock_lazy_dictMatchState_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy_dictMatchState_row( ++size_t ZSTD_compressBlock_lazy_dedicatedDictSearch( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_greedy_dictMatchState_row( ++size_t ZSTD_compressBlock_lazy_dedicatedDictSearch_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +- +-size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch( ++size_t ZSTD_compressBlock_lazy_extDict( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy_dedicatedDictSearch( ++size_t ZSTD_compressBlock_lazy_extDict_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_greedy_dedicatedDictSearch( ++ ++#define ZSTD_COMPRESSBLOCK_LAZY ZSTD_compressBlock_lazy ++#define ZSTD_COMPRESSBLOCK_LAZY_ROW ZSTD_compressBlock_lazy_row ++#define ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE ZSTD_compressBlock_lazy_dictMatchState ++#define ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE_ROW ZSTD_compressBlock_lazy_dictMatchState_row ++#define ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH ZSTD_compressBlock_lazy_dedicatedDictSearch ++#define ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH_ROW ZSTD_compressBlock_lazy_dedicatedDictSearch_row ++#define ZSTD_COMPRESSBLOCK_LAZY_EXTDICT ZSTD_compressBlock_lazy_extDict ++#define 
ZSTD_COMPRESSBLOCK_LAZY_EXTDICT_ROW ZSTD_compressBlock_lazy_extDict_row ++#else ++#define ZSTD_COMPRESSBLOCK_LAZY NULL ++#define ZSTD_COMPRESSBLOCK_LAZY_ROW NULL ++#define ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE NULL ++#define ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE_ROW NULL ++#define ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH NULL ++#define ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH_ROW NULL ++#define ZSTD_COMPRESSBLOCK_LAZY_EXTDICT NULL ++#define ZSTD_COMPRESSBLOCK_LAZY_EXTDICT_ROW NULL ++#endif ++ ++#ifndef ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR ++size_t ZSTD_compressBlock_lazy2( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch_row( ++size_t ZSTD_compressBlock_lazy2_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy_dedicatedDictSearch_row( ++size_t ZSTD_compressBlock_lazy2_dictMatchState( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_greedy_dedicatedDictSearch_row( ++size_t ZSTD_compressBlock_lazy2_dictMatchState_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +- +-size_t ZSTD_compressBlock_greedy_extDict( ++size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy_extDict( ++size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch_row( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); + size_t ZSTD_compressBlock_lazy2_extDict( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_greedy_extDict_row( ++size_t ZSTD_compressBlock_lazy2_extDict_row( + 
ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy_extDict_row( ++ ++#define ZSTD_COMPRESSBLOCK_LAZY2 ZSTD_compressBlock_lazy2 ++#define ZSTD_COMPRESSBLOCK_LAZY2_ROW ZSTD_compressBlock_lazy2_row ++#define ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE ZSTD_compressBlock_lazy2_dictMatchState ++#define ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE_ROW ZSTD_compressBlock_lazy2_dictMatchState_row ++#define ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH ZSTD_compressBlock_lazy2_dedicatedDictSearch ++#define ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH_ROW ZSTD_compressBlock_lazy2_dedicatedDictSearch_row ++#define ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT ZSTD_compressBlock_lazy2_extDict ++#define ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT_ROW ZSTD_compressBlock_lazy2_extDict_row ++#else ++#define ZSTD_COMPRESSBLOCK_LAZY2 NULL ++#define ZSTD_COMPRESSBLOCK_LAZY2_ROW NULL ++#define ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE NULL ++#define ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE_ROW NULL ++#define ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH NULL ++#define ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH_ROW NULL ++#define ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT NULL ++#define ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT_ROW NULL ++#endif ++ ++#ifndef ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR ++size_t ZSTD_compressBlock_btlazy2( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_lazy2_extDict_row( ++size_t ZSTD_compressBlock_btlazy2_dictMatchState( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); + size_t ZSTD_compressBlock_btlazy2_extDict( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +- ++ ++#define ZSTD_COMPRESSBLOCK_BTLAZY2 ZSTD_compressBlock_btlazy2 ++#define ZSTD_COMPRESSBLOCK_BTLAZY2_DICTMATCHSTATE ZSTD_compressBlock_btlazy2_dictMatchState ++#define 
ZSTD_COMPRESSBLOCK_BTLAZY2_EXTDICT ZSTD_compressBlock_btlazy2_extDict ++#else ++#define ZSTD_COMPRESSBLOCK_BTLAZY2 NULL ++#define ZSTD_COMPRESSBLOCK_BTLAZY2_DICTMATCHSTATE NULL ++#define ZSTD_COMPRESSBLOCK_BTLAZY2_EXTDICT NULL ++#endif ++ + + + #endif /* ZSTD_LAZY_H */ +diff --git a/lib/zstd/compress/zstd_ldm.c b/lib/zstd/compress/zstd_ldm.c +index dd86fc83e7dd..07f3bc6437ce 100644 +--- a/lib/zstd/compress/zstd_ldm.c ++++ b/lib/zstd/compress/zstd_ldm.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -242,11 +243,15 @@ static size_t ZSTD_ldm_fillFastTables(ZSTD_matchState_t* ms, + switch(ms->cParams.strategy) + { + case ZSTD_fast: +- ZSTD_fillHashTable(ms, iend, ZSTD_dtlm_fast); ++ ZSTD_fillHashTable(ms, iend, ZSTD_dtlm_fast, ZSTD_tfp_forCCtx); + break; + + case ZSTD_dfast: +- ZSTD_fillDoubleHashTable(ms, iend, ZSTD_dtlm_fast); ++#ifndef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR ++ ZSTD_fillDoubleHashTable(ms, iend, ZSTD_dtlm_fast, ZSTD_tfp_forCCtx); ++#else ++ assert(0); /* shouldn't be called: cparams should've been adjusted. */ ++#endif + break; + + case ZSTD_greedy: +@@ -318,7 +323,9 @@ static void ZSTD_ldm_limitTableUpdate(ZSTD_matchState_t* ms, const BYTE* anchor) + } + } + +-static size_t ZSTD_ldm_generateSequences_internal( ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++size_t ZSTD_ldm_generateSequences_internal( + ldmState_t* ldmState, rawSeqStore_t* rawSeqStore, + ldmParams_t const* params, void const* src, size_t srcSize) + { +@@ -549,7 +556,7 @@ size_t ZSTD_ldm_generateSequences( + * the window through early invalidation. + * TODO: * Test the chunk size. + * * Try invalidation after the sequence generation and test the +- * the offset against maxDist directly. ++ * offset against maxDist directly. 
+ * + * NOTE: Because of dictionaries + sequence splitting we MUST make sure + * that any offset used is valid at the END of the sequence, since it may +@@ -689,7 +696,6 @@ size_t ZSTD_ldm_blockCompress(rawSeqStore_t* rawSeqStore, + /* maybeSplitSequence updates rawSeqStore->pos */ + rawSeq const sequence = maybeSplitSequence(rawSeqStore, + (U32)(iend - ip), minMatch); +- int i; + /* End signal */ + if (sequence.offset == 0) + break; +@@ -702,6 +708,7 @@ size_t ZSTD_ldm_blockCompress(rawSeqStore_t* rawSeqStore, + /* Run the block compressor */ + DEBUGLOG(5, "pos %u : calling block compressor on segment of size %u", (unsigned)(ip-istart), sequence.litLength); + { ++ int i; + size_t const newLitLength = + blockCompressor(ms, seqStore, rep, ip, sequence.litLength); + ip += sequence.litLength; +@@ -711,7 +718,7 @@ size_t ZSTD_ldm_blockCompress(rawSeqStore_t* rawSeqStore, + rep[0] = sequence.offset; + /* Store the sequence */ + ZSTD_storeSeq(seqStore, newLitLength, ip - newLitLength, iend, +- STORE_OFFSET(sequence.offset), ++ OFFSET_TO_OFFBASE(sequence.offset), + sequence.matchLength); + ip += sequence.matchLength; + } +diff --git a/lib/zstd/compress/zstd_ldm.h b/lib/zstd/compress/zstd_ldm.h +index fbc6a5e88fd7..c540731abde7 100644 +--- a/lib/zstd/compress/zstd_ldm.h ++++ b/lib/zstd/compress/zstd_ldm.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +diff --git a/lib/zstd/compress/zstd_ldm_geartab.h b/lib/zstd/compress/zstd_ldm_geartab.h +index 647f865be290..cfccfc46f6f7 100644 +--- a/lib/zstd/compress/zstd_ldm_geartab.h ++++ b/lib/zstd/compress/zstd_ldm_geartab.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. 
and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +diff --git a/lib/zstd/compress/zstd_opt.c b/lib/zstd/compress/zstd_opt.c +index fd82acfda62f..a87b66ac8d24 100644 +--- a/lib/zstd/compress/zstd_opt.c ++++ b/lib/zstd/compress/zstd_opt.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Przemyslaw Skibinski, Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -12,11 +13,14 @@ + #include "hist.h" + #include "zstd_opt.h" + ++#if !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR) + + #define ZSTD_LITFREQ_ADD 2 /* scaling factor for litFreq, so that frequencies adapt faster to new stats */ + #define ZSTD_MAX_PRICE (1<<30) + +-#define ZSTD_PREDEF_THRESHOLD 1024 /* if srcSize < ZSTD_PREDEF_THRESHOLD, symbols' cost is assumed static, directly determined by pre-defined distributions */ ++#define ZSTD_PREDEF_THRESHOLD 8 /* if srcSize < ZSTD_PREDEF_THRESHOLD, symbols' cost is assumed static, directly determined by pre-defined distributions */ + + + /*-************************************* +@@ -26,27 +30,35 @@ + #if 0 /* approximation at bit level (for tests) */ + # define BITCOST_ACCURACY 0 + # define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY) +-# define WEIGHT(stat, opt) ((void)opt, ZSTD_bitWeight(stat)) ++# define WEIGHT(stat, opt) ((void)(opt), ZSTD_bitWeight(stat)) + #elif 0 /* fractional bit accuracy (for tests) */ + # define BITCOST_ACCURACY 8 + # define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY) +-# define WEIGHT(stat,opt) ((void)opt, ZSTD_fracWeight(stat)) ++# define WEIGHT(stat,opt) ((void)(opt), ZSTD_fracWeight(stat)) + #else /* opt==approx, ultra==accurate */ + # define BITCOST_ACCURACY 8 + # define 
BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY) +-# define WEIGHT(stat,opt) (opt ? ZSTD_fracWeight(stat) : ZSTD_bitWeight(stat)) ++# define WEIGHT(stat,opt) ((opt) ? ZSTD_fracWeight(stat) : ZSTD_bitWeight(stat)) + #endif + ++/* ZSTD_bitWeight() : ++ * provide estimated "cost" of a stat in full bits only */ + MEM_STATIC U32 ZSTD_bitWeight(U32 stat) + { + return (ZSTD_highbit32(stat+1) * BITCOST_MULTIPLIER); + } + ++/* ZSTD_fracWeight() : ++ * provide fractional-bit "cost" of a stat, ++ * using linear interpolation approximation */ + MEM_STATIC U32 ZSTD_fracWeight(U32 rawStat) + { + U32 const stat = rawStat + 1; + U32 const hb = ZSTD_highbit32(stat); + U32 const BWeight = hb * BITCOST_MULTIPLIER; ++ /* Fweight was meant for "Fractional weight" ++ * but it's effectively a value between 1 and 2 ++ * using fixed point arithmetic */ + U32 const FWeight = (stat << BITCOST_ACCURACY) >> hb; + U32 const weight = BWeight + FWeight; + assert(hb + BITCOST_ACCURACY < 31); +@@ -57,7 +69,7 @@ MEM_STATIC U32 ZSTD_fracWeight(U32 rawStat) + /* debugging function, + * @return price in bytes as fractional value + * for debug messages only */ +-MEM_STATIC double ZSTD_fCost(U32 price) ++MEM_STATIC double ZSTD_fCost(int price) + { + return (double)price / (BITCOST_MULTIPLIER*8); + } +@@ -88,20 +100,26 @@ static U32 sum_u32(const unsigned table[], size_t nbElts) + return total; + } + +-static U32 ZSTD_downscaleStats(unsigned* table, U32 lastEltIndex, U32 shift) ++typedef enum { base_0possible=0, base_1guaranteed=1 } base_directive_e; ++ ++static U32 ++ZSTD_downscaleStats(unsigned* table, U32 lastEltIndex, U32 shift, base_directive_e base1) + { + U32 s, sum=0; +- DEBUGLOG(5, "ZSTD_downscaleStats (nbElts=%u, shift=%u)", (unsigned)lastEltIndex+1, (unsigned)shift); ++ DEBUGLOG(5, "ZSTD_downscaleStats (nbElts=%u, shift=%u)", ++ (unsigned)lastEltIndex+1, (unsigned)shift ); + assert(shift < 30); + for (s=0; s> shift); +- sum += table[s]; ++ unsigned const base = base1 ? 
1 : (table[s]>0); ++ unsigned const newStat = base + (table[s] >> shift); ++ sum += newStat; ++ table[s] = newStat; + } + return sum; + } + + /* ZSTD_scaleStats() : +- * reduce all elements in table is sum too large ++ * reduce all elt frequencies in table if sum too large + * return the resulting sum of elements */ + static U32 ZSTD_scaleStats(unsigned* table, U32 lastEltIndex, U32 logTarget) + { +@@ -110,7 +128,7 @@ static U32 ZSTD_scaleStats(unsigned* table, U32 lastEltIndex, U32 logTarget) + DEBUGLOG(5, "ZSTD_scaleStats (nbElts=%u, target=%u)", (unsigned)lastEltIndex+1, (unsigned)logTarget); + assert(logTarget < 30); + if (factor <= 1) return prevsum; +- return ZSTD_downscaleStats(table, lastEltIndex, ZSTD_highbit32(factor)); ++ return ZSTD_downscaleStats(table, lastEltIndex, ZSTD_highbit32(factor), base_1guaranteed); + } + + /* ZSTD_rescaleFreqs() : +@@ -129,18 +147,22 @@ ZSTD_rescaleFreqs(optState_t* const optPtr, + DEBUGLOG(5, "ZSTD_rescaleFreqs (srcSize=%u)", (unsigned)srcSize); + optPtr->priceType = zop_dynamic; + +- if (optPtr->litLengthSum == 0) { /* first block : init */ +- if (srcSize <= ZSTD_PREDEF_THRESHOLD) { /* heuristic */ +- DEBUGLOG(5, "(srcSize <= ZSTD_PREDEF_THRESHOLD) => zop_predef"); ++ if (optPtr->litLengthSum == 0) { /* no literals stats collected -> first block assumed -> init */ ++ ++ /* heuristic: use pre-defined stats for too small inputs */ ++ if (srcSize <= ZSTD_PREDEF_THRESHOLD) { ++ DEBUGLOG(5, "srcSize <= %i : use predefined stats", ZSTD_PREDEF_THRESHOLD); + optPtr->priceType = zop_predef; + } + + assert(optPtr->symbolCosts != NULL); + if (optPtr->symbolCosts->huf.repeatMode == HUF_repeat_valid) { +- /* huffman table presumed generated by dictionary */ ++ ++ /* huffman stats covering the full value set : table presumed generated by dictionary */ + optPtr->priceType = zop_dynamic; + + if (compressedLiterals) { ++ /* generate literals statistics from huffman table */ + unsigned lit; + assert(optPtr->litFreq != NULL); + 
optPtr->litSum = 0; +@@ -188,13 +210,14 @@ ZSTD_rescaleFreqs(optState_t* const optPtr, + optPtr->offCodeSum += optPtr->offCodeFreq[of]; + } } + +- } else { /* not a dictionary */ ++ } else { /* first block, no dictionary */ + + assert(optPtr->litFreq != NULL); + if (compressedLiterals) { ++ /* base initial cost of literals on direct frequency within src */ + unsigned lit = MaxLit; + HIST_count_simple(optPtr->litFreq, &lit, src, srcSize); /* use raw first block to init statistics */ +- optPtr->litSum = ZSTD_downscaleStats(optPtr->litFreq, MaxLit, 8); ++ optPtr->litSum = ZSTD_downscaleStats(optPtr->litFreq, MaxLit, 8, base_0possible); + } + + { unsigned const baseLLfreqs[MaxLL+1] = { +@@ -224,10 +247,9 @@ ZSTD_rescaleFreqs(optState_t* const optPtr, + optPtr->offCodeSum = sum_u32(baseOFCfreqs, MaxOff+1); + } + +- + } + +- } else { /* new block : re-use previous statistics, scaled down */ ++ } else { /* new block : scale down accumulated statistics */ + + if (compressedLiterals) + optPtr->litSum = ZSTD_scaleStats(optPtr->litFreq, MaxLit, 12); +@@ -246,6 +268,7 @@ static U32 ZSTD_rawLiteralsCost(const BYTE* const literals, U32 const litLength, + const optState_t* const optPtr, + int optLevel) + { ++ DEBUGLOG(8, "ZSTD_rawLiteralsCost (%u literals)", litLength); + if (litLength == 0) return 0; + + if (!ZSTD_compressedLiterals(optPtr)) +@@ -255,11 +278,14 @@ static U32 ZSTD_rawLiteralsCost(const BYTE* const literals, U32 const litLength, + return (litLength*6) * BITCOST_MULTIPLIER; /* 6 bit per literal - no statistic used */ + + /* dynamic statistics */ +- { U32 price = litLength * optPtr->litSumBasePrice; ++ { U32 price = optPtr->litSumBasePrice * litLength; ++ U32 const litPriceMax = optPtr->litSumBasePrice - BITCOST_MULTIPLIER; + U32 u; ++ assert(optPtr->litSumBasePrice >= BITCOST_MULTIPLIER); + for (u=0; u < litLength; u++) { +- assert(WEIGHT(optPtr->litFreq[literals[u]], optLevel) <= optPtr->litSumBasePrice); /* literal cost should never be negative */ +- price -= 
WEIGHT(optPtr->litFreq[literals[u]], optLevel); ++ U32 litPrice = WEIGHT(optPtr->litFreq[literals[u]], optLevel); ++ if (UNLIKELY(litPrice > litPriceMax)) litPrice = litPriceMax; ++ price -= litPrice; + } + return price; + } +@@ -272,10 +298,11 @@ static U32 ZSTD_litLengthPrice(U32 const litLength, const optState_t* const optP + assert(litLength <= ZSTD_BLOCKSIZE_MAX); + if (optPtr->priceType == zop_predef) + return WEIGHT(litLength, optLevel); +- /* We can't compute the litLength price for sizes >= ZSTD_BLOCKSIZE_MAX +- * because it isn't representable in the zstd format. So instead just +- * call it 1 bit more than ZSTD_BLOCKSIZE_MAX - 1. In this case the block +- * would be all literals. ++ ++ /* ZSTD_LLcode() can't compute litLength price for sizes >= ZSTD_BLOCKSIZE_MAX ++ * because it isn't representable in the zstd format. ++ * So instead just pretend it would cost 1 bit more than ZSTD_BLOCKSIZE_MAX - 1. ++ * In such a case, the block would be all literals. + */ + if (litLength == ZSTD_BLOCKSIZE_MAX) + return BITCOST_MULTIPLIER + ZSTD_litLengthPrice(ZSTD_BLOCKSIZE_MAX - 1, optPtr, optLevel); +@@ -289,24 +316,25 @@ static U32 ZSTD_litLengthPrice(U32 const litLength, const optState_t* const optP + } + + /* ZSTD_getMatchPrice() : +- * Provides the cost of the match part (offset + matchLength) of a sequence ++ * Provides the cost of the match part (offset + matchLength) of a sequence. + * Must be combined with ZSTD_fullLiteralsCost() to get the full cost of a sequence. 
+- * @offcode : expects a scale where 0,1,2 are repcodes 1-3, and 3+ are real_offsets+2 ++ * @offBase : sumtype, representing an offset or a repcode, and using numeric representation of ZSTD_storeSeq() + * @optLevel: when <2, favors small offset for decompression speed (improved cache efficiency) + */ + FORCE_INLINE_TEMPLATE U32 +-ZSTD_getMatchPrice(U32 const offcode, ++ZSTD_getMatchPrice(U32 const offBase, + U32 const matchLength, + const optState_t* const optPtr, + int const optLevel) + { + U32 price; +- U32 const offCode = ZSTD_highbit32(STORED_TO_OFFBASE(offcode)); ++ U32 const offCode = ZSTD_highbit32(offBase); + U32 const mlBase = matchLength - MINMATCH; + assert(matchLength >= MINMATCH); + +- if (optPtr->priceType == zop_predef) /* fixed scheme, do not use statistics */ +- return WEIGHT(mlBase, optLevel) + ((16 + offCode) * BITCOST_MULTIPLIER); ++ if (optPtr->priceType == zop_predef) /* fixed scheme, does not use statistics */ ++ return WEIGHT(mlBase, optLevel) ++ + ((16 + offCode) * BITCOST_MULTIPLIER); /* emulated offset cost */ + + /* dynamic statistics */ + price = (offCode * BITCOST_MULTIPLIER) + (optPtr->offCodeSumBasePrice - WEIGHT(optPtr->offCodeFreq[offCode], optLevel)); +@@ -325,10 +353,10 @@ ZSTD_getMatchPrice(U32 const offcode, + } + + /* ZSTD_updateStats() : +- * assumption : literals + litLengtn <= iend */ ++ * assumption : literals + litLength <= iend */ + static void ZSTD_updateStats(optState_t* const optPtr, + U32 litLength, const BYTE* literals, +- U32 offsetCode, U32 matchLength) ++ U32 offBase, U32 matchLength) + { + /* literals */ + if (ZSTD_compressedLiterals(optPtr)) { +@@ -344,8 +372,8 @@ static void ZSTD_updateStats(optState_t* const optPtr, + optPtr->litLengthSum++; + } + +- /* offset code : expected to follow storeSeq() numeric representation */ +- { U32 const offCode = ZSTD_highbit32(STORED_TO_OFFBASE(offsetCode)); ++ /* offset code : follows storeSeq() numeric representation */ ++ { U32 const offCode = ZSTD_highbit32(offBase); + 
assert(offCode <= MaxOff); + optPtr->offCodeFreq[offCode]++; + optPtr->offCodeSum++; +@@ -379,9 +407,11 @@ MEM_STATIC U32 ZSTD_readMINMATCH(const void* memPtr, U32 length) + + /* Update hashTable3 up to ip (excluded) + Assumption : always within prefix (i.e. not within extDict) */ +-static U32 ZSTD_insertAndFindFirstIndexHash3 (const ZSTD_matchState_t* ms, +- U32* nextToUpdate3, +- const BYTE* const ip) ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++U32 ZSTD_insertAndFindFirstIndexHash3 (const ZSTD_matchState_t* ms, ++ U32* nextToUpdate3, ++ const BYTE* const ip) + { + U32* const hashTable3 = ms->hashTable3; + U32 const hashLog3 = ms->hashLog3; +@@ -408,7 +438,9 @@ static U32 ZSTD_insertAndFindFirstIndexHash3 (const ZSTD_matchState_t* ms, + * @param ip assumed <= iend-8 . + * @param target The target of ZSTD_updateTree_internal() - we are filling to this position + * @return : nb of positions added */ +-static U32 ZSTD_insertBt1( ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++U32 ZSTD_insertBt1( + const ZSTD_matchState_t* ms, + const BYTE* const ip, const BYTE* const iend, + U32 const target, +@@ -527,6 +559,7 @@ static U32 ZSTD_insertBt1( + } + + FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR + void ZSTD_updateTree_internal( + ZSTD_matchState_t* ms, + const BYTE* const ip, const BYTE* const iend, +@@ -535,7 +568,7 @@ void ZSTD_updateTree_internal( + const BYTE* const base = ms->window.base; + U32 const target = (U32)(ip - base); + U32 idx = ms->nextToUpdate; +- DEBUGLOG(6, "ZSTD_updateTree_internal, from %u to %u (dictMode:%u)", ++ DEBUGLOG(7, "ZSTD_updateTree_internal, from %u to %u (dictMode:%u)", + idx, target, dictMode); + + while(idx < target) { +@@ -553,15 +586,18 @@ void ZSTD_updateTree(ZSTD_matchState_t* ms, const BYTE* ip, const BYTE* iend) { + } + + FORCE_INLINE_TEMPLATE +-U32 ZSTD_insertBtAndGetAllMatches ( +- ZSTD_match_t* matches, /* store result (found matches) in this table (presumed large enough) */ +- ZSTD_matchState_t* ms, +- U32* 
nextToUpdate3, +- const BYTE* const ip, const BYTE* const iLimit, const ZSTD_dictMode_e dictMode, +- const U32 rep[ZSTD_REP_NUM], +- U32 const ll0, /* tells if associated literal length is 0 or not. This value must be 0 or 1 */ +- const U32 lengthToBeat, +- U32 const mls /* template */) ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++U32 ++ZSTD_insertBtAndGetAllMatches ( ++ ZSTD_match_t* matches, /* store result (found matches) in this table (presumed large enough) */ ++ ZSTD_matchState_t* ms, ++ U32* nextToUpdate3, ++ const BYTE* const ip, const BYTE* const iLimit, ++ const ZSTD_dictMode_e dictMode, ++ const U32 rep[ZSTD_REP_NUM], ++ const U32 ll0, /* tells if associated literal length is 0 or not. This value must be 0 or 1 */ ++ const U32 lengthToBeat, ++ const U32 mls /* template */) + { + const ZSTD_compressionParameters* const cParams = &ms->cParams; + U32 const sufficient_len = MIN(cParams->targetLength, ZSTD_OPT_NUM -1); +@@ -644,7 +680,7 @@ U32 ZSTD_insertBtAndGetAllMatches ( + DEBUGLOG(8, "found repCode %u (ll0:%u, offset:%u) of length %u", + repCode, ll0, repOffset, repLen); + bestLength = repLen; +- matches[mnum].off = STORE_REPCODE(repCode - ll0 + 1); /* expect value between 1 and 3 */ ++ matches[mnum].off = REPCODE_TO_OFFBASE(repCode - ll0 + 1); /* expect value between 1 and 3 */ + matches[mnum].len = (U32)repLen; + mnum++; + if ( (repLen > sufficient_len) +@@ -673,7 +709,7 @@ U32 ZSTD_insertBtAndGetAllMatches ( + bestLength = mlen; + assert(curr > matchIndex3); + assert(mnum==0); /* no prior solution */ +- matches[0].off = STORE_OFFSET(curr - matchIndex3); ++ matches[0].off = OFFSET_TO_OFFBASE(curr - matchIndex3); + matches[0].len = (U32)mlen; + mnum = 1; + if ( (mlen > sufficient_len) | +@@ -706,13 +742,13 @@ U32 ZSTD_insertBtAndGetAllMatches ( + } + + if (matchLength > bestLength) { +- DEBUGLOG(8, "found match of length %u at distance %u (offCode=%u)", +- (U32)matchLength, curr - matchIndex, STORE_OFFSET(curr - matchIndex)); ++ DEBUGLOG(8, "found match of 
length %u at distance %u (offBase=%u)", ++ (U32)matchLength, curr - matchIndex, OFFSET_TO_OFFBASE(curr - matchIndex)); + assert(matchEndIdx > matchIndex); + if (matchLength > matchEndIdx - matchIndex) + matchEndIdx = matchIndex + (U32)matchLength; + bestLength = matchLength; +- matches[mnum].off = STORE_OFFSET(curr - matchIndex); ++ matches[mnum].off = OFFSET_TO_OFFBASE(curr - matchIndex); + matches[mnum].len = (U32)matchLength; + mnum++; + if ( (matchLength > ZSTD_OPT_NUM) +@@ -754,12 +790,12 @@ U32 ZSTD_insertBtAndGetAllMatches ( + + if (matchLength > bestLength) { + matchIndex = dictMatchIndex + dmsIndexDelta; +- DEBUGLOG(8, "found dms match of length %u at distance %u (offCode=%u)", +- (U32)matchLength, curr - matchIndex, STORE_OFFSET(curr - matchIndex)); ++ DEBUGLOG(8, "found dms match of length %u at distance %u (offBase=%u)", ++ (U32)matchLength, curr - matchIndex, OFFSET_TO_OFFBASE(curr - matchIndex)); + if (matchLength > matchEndIdx - matchIndex) + matchEndIdx = matchIndex + (U32)matchLength; + bestLength = matchLength; +- matches[mnum].off = STORE_OFFSET(curr - matchIndex); ++ matches[mnum].off = OFFSET_TO_OFFBASE(curr - matchIndex); + matches[mnum].len = (U32)matchLength; + mnum++; + if ( (matchLength > ZSTD_OPT_NUM) +@@ -792,7 +828,9 @@ typedef U32 (*ZSTD_getAllMatchesFn)( + U32 const ll0, + U32 const lengthToBeat); + +-FORCE_INLINE_TEMPLATE U32 ZSTD_btGetAllMatches_internal( ++FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++U32 ZSTD_btGetAllMatches_internal( + ZSTD_match_t* matches, + ZSTD_matchState_t* ms, + U32* nextToUpdate3, +@@ -960,7 +998,7 @@ static void ZSTD_optLdm_maybeAddMatch(ZSTD_match_t* matches, U32* nbMatches, + const ZSTD_optLdm_t* optLdm, U32 currPosInBlock) + { + U32 const posDiff = currPosInBlock - optLdm->startPosInBlock; +- /* Note: ZSTD_match_t actually contains offCode and matchLength (before subtracting MINMATCH) */ ++ /* Note: ZSTD_match_t actually contains offBase and matchLength (before subtracting MINMATCH) */ + 
U32 const candidateMatchLength = optLdm->endPosInBlock - optLdm->startPosInBlock - posDiff; + + /* Ensure that current block position is not outside of the match */ +@@ -971,11 +1009,11 @@ static void ZSTD_optLdm_maybeAddMatch(ZSTD_match_t* matches, U32* nbMatches, + } + + if (*nbMatches == 0 || ((candidateMatchLength > matches[*nbMatches-1].len) && *nbMatches < ZSTD_OPT_NUM)) { +- U32 const candidateOffCode = STORE_OFFSET(optLdm->offset); +- DEBUGLOG(6, "ZSTD_optLdm_maybeAddMatch(): Adding ldm candidate match (offCode: %u matchLength %u) at block position=%u", +- candidateOffCode, candidateMatchLength, currPosInBlock); ++ U32 const candidateOffBase = OFFSET_TO_OFFBASE(optLdm->offset); ++ DEBUGLOG(6, "ZSTD_optLdm_maybeAddMatch(): Adding ldm candidate match (offBase: %u matchLength %u) at block position=%u", ++ candidateOffBase, candidateMatchLength, currPosInBlock); + matches[*nbMatches].len = candidateMatchLength; +- matches[*nbMatches].off = candidateOffCode; ++ matches[*nbMatches].off = candidateOffBase; + (*nbMatches)++; + } + } +@@ -1011,11 +1049,6 @@ ZSTD_optLdm_processMatchCandidate(ZSTD_optLdm_t* optLdm, + * Optimal parser + *********************************/ + +-static U32 ZSTD_totalLen(ZSTD_optimal_t sol) +-{ +- return sol.litlen + sol.mlen; +-} +- + #if 0 /* debug */ + + static void +@@ -1033,7 +1066,13 @@ listStats(const U32* table, int lastEltID) + + #endif + +-FORCE_INLINE_TEMPLATE size_t ++#define LIT_PRICE(_p) (int)ZSTD_rawLiteralsCost(_p, 1, optStatePtr, optLevel) ++#define LL_PRICE(_l) (int)ZSTD_litLengthPrice(_l, optStatePtr, optLevel) ++#define LL_INCPRICE(_l) (LL_PRICE(_l) - LL_PRICE(_l-1)) ++ ++FORCE_INLINE_TEMPLATE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++size_t + ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, + seqStore_t* seqStore, + U32 rep[ZSTD_REP_NUM], +@@ -1059,9 +1098,11 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, + + ZSTD_optimal_t* const opt = optStatePtr->priceTable; + ZSTD_match_t* const matches = 
optStatePtr->matchTable; +- ZSTD_optimal_t lastSequence; ++ ZSTD_optimal_t lastStretch; + ZSTD_optLdm_t optLdm; + ++ ZSTD_memset(&lastStretch, 0, sizeof(ZSTD_optimal_t)); ++ + optLdm.seqStore = ms->ldmSeqStore ? *ms->ldmSeqStore : kNullRawSeqStore; + optLdm.endPosInBlock = optLdm.startPosInBlock = optLdm.offset = 0; + ZSTD_opt_getNextMatchAndUpdateSeqStore(&optLdm, (U32)(ip-istart), (U32)(iend-ip)); +@@ -1082,103 +1123,139 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, + U32 const ll0 = !litlen; + U32 nbMatches = getAllMatches(matches, ms, &nextToUpdate3, ip, iend, rep, ll0, minMatch); + ZSTD_optLdm_processMatchCandidate(&optLdm, matches, &nbMatches, +- (U32)(ip-istart), (U32)(iend - ip)); +- if (!nbMatches) { ip++; continue; } ++ (U32)(ip-istart), (U32)(iend-ip)); ++ if (!nbMatches) { ++ DEBUGLOG(8, "no match found at cPos %u", (unsigned)(ip-istart)); ++ ip++; ++ continue; ++ } ++ ++ /* Match found: let's store this solution, and eventually find more candidates. ++ * During this forward pass, @opt is used to store stretches, ++ * defined as "a match followed by N literals". ++ * Note how this is different from a Sequence, which is "N literals followed by a match". ++ * Storing stretches allows us to store different match predecessors ++ * for each literal position part of a literals run. 
*/ + + /* initialize opt[0] */ +- { U32 i ; for (i=0; i immediate encoding */ + { U32 const maxML = matches[nbMatches-1].len; +- U32 const maxOffcode = matches[nbMatches-1].off; +- DEBUGLOG(6, "found %u matches of maxLength=%u and maxOffCode=%u at cPos=%u => start new series", +- nbMatches, maxML, maxOffcode, (U32)(ip-prefixStart)); ++ U32 const maxOffBase = matches[nbMatches-1].off; ++ DEBUGLOG(6, "found %u matches of maxLength=%u and maxOffBase=%u at cPos=%u => start new series", ++ nbMatches, maxML, maxOffBase, (U32)(ip-prefixStart)); + + if (maxML > sufficient_len) { +- lastSequence.litlen = litlen; +- lastSequence.mlen = maxML; +- lastSequence.off = maxOffcode; +- DEBUGLOG(6, "large match (%u>%u), immediate encoding", ++ lastStretch.litlen = 0; ++ lastStretch.mlen = maxML; ++ lastStretch.off = maxOffBase; ++ DEBUGLOG(6, "large match (%u>%u) => immediate encoding", + maxML, sufficient_len); + cur = 0; +- last_pos = ZSTD_totalLen(lastSequence); ++ last_pos = maxML; + goto _shortestPath; + } } + + /* set prices for first matches starting position == 0 */ + assert(opt[0].price >= 0); +- { U32 const literalsPrice = (U32)opt[0].price + ZSTD_litLengthPrice(0, optStatePtr, optLevel); +- U32 pos; ++ { U32 pos; + U32 matchNb; + for (pos = 1; pos < minMatch; pos++) { +- opt[pos].price = ZSTD_MAX_PRICE; /* mlen, litlen and price will be fixed during forward scanning */ ++ opt[pos].price = ZSTD_MAX_PRICE; ++ opt[pos].mlen = 0; ++ opt[pos].litlen = litlen + pos; + } + for (matchNb = 0; matchNb < nbMatches; matchNb++) { +- U32 const offcode = matches[matchNb].off; ++ U32 const offBase = matches[matchNb].off; + U32 const end = matches[matchNb].len; + for ( ; pos <= end ; pos++ ) { +- U32 const matchPrice = ZSTD_getMatchPrice(offcode, pos, optStatePtr, optLevel); +- U32 const sequencePrice = literalsPrice + matchPrice; ++ int const matchPrice = (int)ZSTD_getMatchPrice(offBase, pos, optStatePtr, optLevel); ++ int const sequencePrice = opt[0].price + matchPrice; + DEBUGLOG(7, 
"rPos:%u => set initial price : %.2f", + pos, ZSTD_fCost(sequencePrice)); + opt[pos].mlen = pos; +- opt[pos].off = offcode; +- opt[pos].litlen = litlen; +- opt[pos].price = (int)sequencePrice; +- } } ++ opt[pos].off = offBase; ++ opt[pos].litlen = 0; /* end of match */ ++ opt[pos].price = sequencePrice + LL_PRICE(0); ++ } ++ } + last_pos = pos-1; ++ opt[pos].price = ZSTD_MAX_PRICE; + } + } + + /* check further positions */ + for (cur = 1; cur <= last_pos; cur++) { + const BYTE* const inr = ip + cur; +- assert(cur < ZSTD_OPT_NUM); +- DEBUGLOG(7, "cPos:%zi==rPos:%u", inr-istart, cur) ++ assert(cur <= ZSTD_OPT_NUM); ++ DEBUGLOG(7, "cPos:%zi==rPos:%u", inr-istart, cur); + + /* Fix current position with one literal if cheaper */ +- { U32 const litlen = (opt[cur-1].mlen == 0) ? opt[cur-1].litlen + 1 : 1; ++ { U32 const litlen = opt[cur-1].litlen + 1; + int const price = opt[cur-1].price +- + (int)ZSTD_rawLiteralsCost(ip+cur-1, 1, optStatePtr, optLevel) +- + (int)ZSTD_litLengthPrice(litlen, optStatePtr, optLevel) +- - (int)ZSTD_litLengthPrice(litlen-1, optStatePtr, optLevel); ++ + LIT_PRICE(ip+cur-1) ++ + LL_INCPRICE(litlen); + assert(price < 1000000000); /* overflow check */ + if (price <= opt[cur].price) { ++ ZSTD_optimal_t const prevMatch = opt[cur]; + DEBUGLOG(7, "cPos:%zi==rPos:%u : better price (%.2f<=%.2f) using literal (ll==%u) (hist:%u,%u,%u)", + inr-istart, cur, ZSTD_fCost(price), ZSTD_fCost(opt[cur].price), litlen, + opt[cur-1].rep[0], opt[cur-1].rep[1], opt[cur-1].rep[2]); +- opt[cur].mlen = 0; +- opt[cur].off = 0; ++ opt[cur] = opt[cur-1]; + opt[cur].litlen = litlen; + opt[cur].price = price; ++ if ( (optLevel >= 1) /* additional check only for higher modes */ ++ && (prevMatch.litlen == 0) /* replace a match */ ++ && (LL_INCPRICE(1) < 0) /* ll1 is cheaper than ll0 */ ++ && LIKELY(ip + cur < iend) ++ ) { ++ /* check next position, in case it would be cheaper */ ++ int with1literal = prevMatch.price + LIT_PRICE(ip+cur) + LL_INCPRICE(1); ++ int withMoreLiterals 
= price + LIT_PRICE(ip+cur) + LL_INCPRICE(litlen+1); ++ DEBUGLOG(7, "then at next rPos %u : match+1lit %.2f vs %ulits %.2f", ++ cur+1, ZSTD_fCost(with1literal), litlen+1, ZSTD_fCost(withMoreLiterals)); ++ if ( (with1literal < withMoreLiterals) ++ && (with1literal < opt[cur+1].price) ) { ++ /* update offset history - before it disappears */ ++ U32 const prev = cur - prevMatch.mlen; ++ repcodes_t const newReps = ZSTD_newRep(opt[prev].rep, prevMatch.off, opt[prev].litlen==0); ++ assert(cur >= prevMatch.mlen); ++ DEBUGLOG(7, "==> match+1lit is cheaper (%.2f < %.2f) (hist:%u,%u,%u) !", ++ ZSTD_fCost(with1literal), ZSTD_fCost(withMoreLiterals), ++ newReps.rep[0], newReps.rep[1], newReps.rep[2] ); ++ opt[cur+1] = prevMatch; /* mlen & offbase */ ++ ZSTD_memcpy(opt[cur+1].rep, &newReps, sizeof(repcodes_t)); ++ opt[cur+1].litlen = 1; ++ opt[cur+1].price = with1literal; ++ if (last_pos < cur+1) last_pos = cur+1; ++ } ++ } + } else { +- DEBUGLOG(7, "cPos:%zi==rPos:%u : literal would cost more (%.2f>%.2f) (hist:%u,%u,%u)", +- inr-istart, cur, ZSTD_fCost(price), ZSTD_fCost(opt[cur].price), +- opt[cur].rep[0], opt[cur].rep[1], opt[cur].rep[2]); ++ DEBUGLOG(7, "cPos:%zi==rPos:%u : literal would cost more (%.2f>%.2f)", ++ inr-istart, cur, ZSTD_fCost(price), ZSTD_fCost(opt[cur].price)); + } + } + +- /* Set the repcodes of the current position. We must do it here +- * because we rely on the repcodes of the 2nd to last sequence being +- * correct to set the next chunks repcodes during the backward +- * traversal. ++ /* Offset history is not updated during match comparison. ++ * Do it here, now that the match is selected and confirmed. 
+ */ + ZSTD_STATIC_ASSERT(sizeof(opt[cur].rep) == sizeof(repcodes_t)); + assert(cur >= opt[cur].mlen); +- if (opt[cur].mlen != 0) { ++ if (opt[cur].litlen == 0) { ++ /* just finished a match => alter offset history */ + U32 const prev = cur - opt[cur].mlen; +- repcodes_t const newReps = ZSTD_newRep(opt[prev].rep, opt[cur].off, opt[cur].litlen==0); ++ repcodes_t const newReps = ZSTD_newRep(opt[prev].rep, opt[cur].off, opt[prev].litlen==0); + ZSTD_memcpy(opt[cur].rep, &newReps, sizeof(repcodes_t)); +- } else { +- ZSTD_memcpy(opt[cur].rep, opt[cur - 1].rep, sizeof(repcodes_t)); + } + + /* last match must start at a minimum distance of 8 from oend */ +@@ -1188,15 +1265,14 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, + + if ( (optLevel==0) /*static_test*/ + && (opt[cur+1].price <= opt[cur].price + (BITCOST_MULTIPLIER/2)) ) { +- DEBUGLOG(7, "move to next rPos:%u : price is <=", cur+1); ++ DEBUGLOG(7, "skip current position : next rPos(%u) price is cheaper", cur+1); + continue; /* skip unpromising positions; about ~+6% speed, -0.01 ratio */ + } + + assert(opt[cur].price >= 0); +- { U32 const ll0 = (opt[cur].mlen != 0); +- U32 const litlen = (opt[cur].mlen == 0) ? 
opt[cur].litlen : 0; +- U32 const previousPrice = (U32)opt[cur].price; +- U32 const basePrice = previousPrice + ZSTD_litLengthPrice(0, optStatePtr, optLevel); ++ { U32 const ll0 = (opt[cur].litlen == 0); ++ int const previousPrice = opt[cur].price; ++ int const basePrice = previousPrice + LL_PRICE(0); + U32 nbMatches = getAllMatches(matches, ms, &nextToUpdate3, inr, iend, opt[cur].rep, ll0, minMatch); + U32 matchNb; + +@@ -1208,18 +1284,17 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, + continue; + } + +- { U32 const maxML = matches[nbMatches-1].len; +- DEBUGLOG(7, "cPos:%zi==rPos:%u, found %u matches, of maxLength=%u", +- inr-istart, cur, nbMatches, maxML); +- +- if ( (maxML > sufficient_len) +- || (cur + maxML >= ZSTD_OPT_NUM) ) { +- lastSequence.mlen = maxML; +- lastSequence.off = matches[nbMatches-1].off; +- lastSequence.litlen = litlen; +- cur -= (opt[cur].mlen==0) ? opt[cur].litlen : 0; /* last sequence is actually only literals, fix cur to last match - note : may underflow, in which case, it's first sequence, and it's okay */ +- last_pos = cur + ZSTD_totalLen(lastSequence); +- if (cur > ZSTD_OPT_NUM) cur = 0; /* underflow => first match */ ++ { U32 const longestML = matches[nbMatches-1].len; ++ DEBUGLOG(7, "cPos:%zi==rPos:%u, found %u matches, of longest ML=%u", ++ inr-istart, cur, nbMatches, longestML); ++ ++ if ( (longestML > sufficient_len) ++ || (cur + longestML >= ZSTD_OPT_NUM) ++ || (ip + cur + longestML >= iend) ) { ++ lastStretch.mlen = longestML; ++ lastStretch.off = matches[nbMatches-1].off; ++ lastStretch.litlen = 0; ++ last_pos = cur + longestML; + goto _shortestPath; + } } + +@@ -1230,20 +1305,25 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, + U32 const startML = (matchNb>0) ? 
matches[matchNb-1].len+1 : minMatch; + U32 mlen; + +- DEBUGLOG(7, "testing match %u => offCode=%4u, mlen=%2u, llen=%2u", +- matchNb, matches[matchNb].off, lastML, litlen); ++ DEBUGLOG(7, "testing match %u => offBase=%4u, mlen=%2u, llen=%2u", ++ matchNb, matches[matchNb].off, lastML, opt[cur].litlen); + + for (mlen = lastML; mlen >= startML; mlen--) { /* scan downward */ + U32 const pos = cur + mlen; +- int const price = (int)basePrice + (int)ZSTD_getMatchPrice(offset, mlen, optStatePtr, optLevel); ++ int const price = basePrice + (int)ZSTD_getMatchPrice(offset, mlen, optStatePtr, optLevel); + + if ((pos > last_pos) || (price < opt[pos].price)) { + DEBUGLOG(7, "rPos:%u (ml=%2u) => new better price (%.2f<%.2f)", + pos, mlen, ZSTD_fCost(price), ZSTD_fCost(opt[pos].price)); +- while (last_pos < pos) { opt[last_pos+1].price = ZSTD_MAX_PRICE; last_pos++; } /* fill empty positions */ ++ while (last_pos < pos) { ++ /* fill empty positions, for future comparisons */ ++ last_pos++; ++ opt[last_pos].price = ZSTD_MAX_PRICE; ++ opt[last_pos].litlen = !0; /* just needs to be != 0, to mean "not an end of match" */ ++ } + opt[pos].mlen = mlen; + opt[pos].off = offset; +- opt[pos].litlen = litlen; ++ opt[pos].litlen = 0; + opt[pos].price = price; + } else { + DEBUGLOG(7, "rPos:%u (ml=%2u) => new price is worse (%.2f>=%.2f)", +@@ -1251,52 +1331,86 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, + if (optLevel==0) break; /* early update abort; gets ~+10% speed for about -0.01 ratio loss */ + } + } } } ++ opt[last_pos+1].price = ZSTD_MAX_PRICE; + } /* for (cur = 1; cur <= last_pos; cur++) */ + +- lastSequence = opt[last_pos]; +- cur = last_pos > ZSTD_totalLen(lastSequence) ? 
last_pos - ZSTD_totalLen(lastSequence) : 0; /* single sequence, and it starts before `ip` */ +- assert(cur < ZSTD_OPT_NUM); /* control overflow*/ ++ lastStretch = opt[last_pos]; ++ assert(cur >= lastStretch.mlen); ++ cur = last_pos - lastStretch.mlen; + + _shortestPath: /* cur, last_pos, best_mlen, best_off have to be set */ + assert(opt[0].mlen == 0); ++ assert(last_pos >= lastStretch.mlen); ++ assert(cur == last_pos - lastStretch.mlen); + +- /* Set the next chunk's repcodes based on the repcodes of the beginning +- * of the last match, and the last sequence. This avoids us having to +- * update them while traversing the sequences. +- */ +- if (lastSequence.mlen != 0) { +- repcodes_t const reps = ZSTD_newRep(opt[cur].rep, lastSequence.off, lastSequence.litlen==0); +- ZSTD_memcpy(rep, &reps, sizeof(reps)); ++ if (lastStretch.mlen==0) { ++ /* no solution : all matches have been converted into literals */ ++ assert(lastStretch.litlen == (ip - anchor) + last_pos); ++ ip += last_pos; ++ continue; ++ } ++ assert(lastStretch.off > 0); ++ ++ /* Update offset history */ ++ if (lastStretch.litlen == 0) { ++ /* finishing on a match : update offset history */ ++ repcodes_t const reps = ZSTD_newRep(opt[cur].rep, lastStretch.off, opt[cur].litlen==0); ++ ZSTD_memcpy(rep, &reps, sizeof(repcodes_t)); + } else { +- ZSTD_memcpy(rep, opt[cur].rep, sizeof(repcodes_t)); ++ ZSTD_memcpy(rep, lastStretch.rep, sizeof(repcodes_t)); ++ assert(cur >= lastStretch.litlen); ++ cur -= lastStretch.litlen; + } + +- { U32 const storeEnd = cur + 1; ++ /* Let's write the shortest path solution. ++ * It is stored in @opt in reverse order, ++ * starting from @storeEnd (==cur+2), ++ * effectively partially @opt overwriting. 
++ * Content is changed too: ++ * - So far, @opt stored stretches, aka a match followed by literals ++ * - Now, it will store sequences, aka literals followed by a match ++ */ ++ { U32 const storeEnd = cur + 2; + U32 storeStart = storeEnd; +- U32 seqPos = cur; ++ U32 stretchPos = cur; + + DEBUGLOG(6, "start reverse traversal (last_pos:%u, cur:%u)", + last_pos, cur); (void)last_pos; +- assert(storeEnd < ZSTD_OPT_NUM); +- DEBUGLOG(6, "last sequence copied into pos=%u (llen=%u,mlen=%u,ofc=%u)", +- storeEnd, lastSequence.litlen, lastSequence.mlen, lastSequence.off); +- opt[storeEnd] = lastSequence; +- while (seqPos > 0) { +- U32 const backDist = ZSTD_totalLen(opt[seqPos]); ++ assert(storeEnd < ZSTD_OPT_SIZE); ++ DEBUGLOG(6, "last stretch copied into pos=%u (llen=%u,mlen=%u,ofc=%u)", ++ storeEnd, lastStretch.litlen, lastStretch.mlen, lastStretch.off); ++ if (lastStretch.litlen > 0) { ++ /* last "sequence" is unfinished: just a bunch of literals */ ++ opt[storeEnd].litlen = lastStretch.litlen; ++ opt[storeEnd].mlen = 0; ++ storeStart = storeEnd-1; ++ opt[storeStart] = lastStretch; ++ } { ++ opt[storeEnd] = lastStretch; /* note: litlen will be fixed */ ++ storeStart = storeEnd; ++ } ++ while (1) { ++ ZSTD_optimal_t nextStretch = opt[stretchPos]; ++ opt[storeStart].litlen = nextStretch.litlen; ++ DEBUGLOG(6, "selected sequence (llen=%u,mlen=%u,ofc=%u)", ++ opt[storeStart].litlen, opt[storeStart].mlen, opt[storeStart].off); ++ if (nextStretch.mlen == 0) { ++ /* reaching beginning of segment */ ++ break; ++ } + storeStart--; +- DEBUGLOG(6, "sequence from rPos=%u copied into pos=%u (llen=%u,mlen=%u,ofc=%u)", +- seqPos, storeStart, opt[seqPos].litlen, opt[seqPos].mlen, opt[seqPos].off); +- opt[storeStart] = opt[seqPos]; +- seqPos = (seqPos > backDist) ? 
seqPos - backDist : 0; ++ opt[storeStart] = nextStretch; /* note: litlen will be fixed */ ++ assert(nextStretch.litlen + nextStretch.mlen <= stretchPos); ++ stretchPos -= nextStretch.litlen + nextStretch.mlen; + } + + /* save sequences */ +- DEBUGLOG(6, "sending selected sequences into seqStore") ++ DEBUGLOG(6, "sending selected sequences into seqStore"); + { U32 storePos; + for (storePos=storeStart; storePos <= storeEnd; storePos++) { + U32 const llen = opt[storePos].litlen; + U32 const mlen = opt[storePos].mlen; +- U32 const offCode = opt[storePos].off; ++ U32 const offBase = opt[storePos].off; + U32 const advance = llen + mlen; + DEBUGLOG(6, "considering seq starting at %zi, llen=%u, mlen=%u", + anchor - istart, (unsigned)llen, (unsigned)mlen); +@@ -1308,11 +1422,14 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, + } + + assert(anchor + llen <= iend); +- ZSTD_updateStats(optStatePtr, llen, anchor, offCode, mlen); +- ZSTD_storeSeq(seqStore, llen, anchor, iend, offCode, mlen); ++ ZSTD_updateStats(optStatePtr, llen, anchor, offBase, mlen); ++ ZSTD_storeSeq(seqStore, llen, anchor, iend, offBase, mlen); + anchor += advance; + ip = anchor; + } } ++ DEBUGLOG(7, "new offset history : %u, %u, %u", rep[0], rep[1], rep[2]); ++ ++ /* update all costs */ + ZSTD_setBasePrices(optStatePtr, optLevel); + } + } /* while (ip < ilimit) */ +@@ -1320,21 +1437,27 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, + /* Return the last literals size */ + return (size_t)(iend - anchor); + } ++#endif /* build exclusions */ + ++#ifndef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR + static size_t ZSTD_compressBlock_opt0( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + const void* src, size_t srcSize, const ZSTD_dictMode_e dictMode) + { + return ZSTD_compressBlock_opt_generic(ms, seqStore, rep, src, srcSize, 0 /* optLevel */, dictMode); + } ++#endif + ++#ifndef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR + static size_t ZSTD_compressBlock_opt2( + ZSTD_matchState_t* 
ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + const void* src, size_t srcSize, const ZSTD_dictMode_e dictMode) + { + return ZSTD_compressBlock_opt_generic(ms, seqStore, rep, src, srcSize, 2 /* optLevel */, dictMode); + } ++#endif + ++#ifndef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR + size_t ZSTD_compressBlock_btopt( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + const void* src, size_t srcSize) +@@ -1342,20 +1465,23 @@ size_t ZSTD_compressBlock_btopt( + DEBUGLOG(5, "ZSTD_compressBlock_btopt"); + return ZSTD_compressBlock_opt0(ms, seqStore, rep, src, srcSize, ZSTD_noDict); + } ++#endif + + + + ++#ifndef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR + /* ZSTD_initStats_ultra(): + * make a first compression pass, just to seed stats with more accurate starting values. + * only works on first block, with no dictionary and no ldm. +- * this function cannot error, hence its contract must be respected. ++ * this function cannot error out, its narrow contract must be respected. + */ +-static void +-ZSTD_initStats_ultra(ZSTD_matchState_t* ms, +- seqStore_t* seqStore, +- U32 rep[ZSTD_REP_NUM], +- const void* src, size_t srcSize) ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++void ZSTD_initStats_ultra(ZSTD_matchState_t* ms, ++ seqStore_t* seqStore, ++ U32 rep[ZSTD_REP_NUM], ++ const void* src, size_t srcSize) + { + U32 tmpRep[ZSTD_REP_NUM]; /* updated rep codes will sink here */ + ZSTD_memcpy(tmpRep, rep, sizeof(tmpRep)); +@@ -1368,7 +1494,7 @@ ZSTD_initStats_ultra(ZSTD_matchState_t* ms, + + ZSTD_compressBlock_opt2(ms, seqStore, tmpRep, src, srcSize, ZSTD_noDict); /* generate stats into ms->opt*/ + +- /* invalidate first scan from history */ ++ /* invalidate first scan from history, only keep entropy stats */ + ZSTD_resetSeqStore(seqStore); + ms->window.base -= srcSize; + ms->window.dictLimit += (U32)srcSize; +@@ -1392,10 +1518,10 @@ size_t ZSTD_compressBlock_btultra2( + U32 const curr = (U32)((const BYTE*)src - ms->window.base); + DEBUGLOG(5, 
"ZSTD_compressBlock_btultra2 (srcSize=%zu)", srcSize); + +- /* 2-pass strategy: ++ /* 2-passes strategy: + * this strategy makes a first pass over first block to collect statistics +- * and seed next round's statistics with it. +- * After 1st pass, function forgets everything, and starts a new block. ++ * in order to seed next round's statistics with it. ++ * After 1st pass, function forgets history, and starts a new block. + * Consequently, this can only work if no data has been previously loaded in tables, + * aka, no dictionary, no prefix, no ldm preprocessing. + * The compression ratio gain is generally small (~0.5% on first block), +@@ -1404,15 +1530,17 @@ size_t ZSTD_compressBlock_btultra2( + if ( (ms->opt.litLengthSum==0) /* first block */ + && (seqStore->sequences == seqStore->sequencesStart) /* no ldm */ + && (ms->window.dictLimit == ms->window.lowLimit) /* no dictionary */ +- && (curr == ms->window.dictLimit) /* start of frame, nothing already loaded nor skipped */ +- && (srcSize > ZSTD_PREDEF_THRESHOLD) ++ && (curr == ms->window.dictLimit) /* start of frame, nothing already loaded nor skipped */ ++ && (srcSize > ZSTD_PREDEF_THRESHOLD) /* input large enough to not employ default stats */ + ) { + ZSTD_initStats_ultra(ms, seqStore, rep, src, srcSize); + } + + return ZSTD_compressBlock_opt2(ms, seqStore, rep, src, srcSize, ZSTD_noDict); + } ++#endif + ++#ifndef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR + size_t ZSTD_compressBlock_btopt_dictMatchState( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + const void* src, size_t srcSize) +@@ -1420,18 +1548,20 @@ size_t ZSTD_compressBlock_btopt_dictMatchState( + return ZSTD_compressBlock_opt0(ms, seqStore, rep, src, srcSize, ZSTD_dictMatchState); + } + +-size_t ZSTD_compressBlock_btultra_dictMatchState( ++size_t ZSTD_compressBlock_btopt_extDict( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + const void* src, size_t srcSize) + { +- return ZSTD_compressBlock_opt2(ms, 
seqStore, rep, src, srcSize, ZSTD_dictMatchState); ++ return ZSTD_compressBlock_opt0(ms, seqStore, rep, src, srcSize, ZSTD_extDict); + } ++#endif + +-size_t ZSTD_compressBlock_btopt_extDict( ++#ifndef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR ++size_t ZSTD_compressBlock_btultra_dictMatchState( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + const void* src, size_t srcSize) + { +- return ZSTD_compressBlock_opt0(ms, seqStore, rep, src, srcSize, ZSTD_extDict); ++ return ZSTD_compressBlock_opt2(ms, seqStore, rep, src, srcSize, ZSTD_dictMatchState); + } + + size_t ZSTD_compressBlock_btultra_extDict( +@@ -1440,6 +1570,7 @@ size_t ZSTD_compressBlock_btultra_extDict( + { + return ZSTD_compressBlock_opt2(ms, seqStore, rep, src, srcSize, ZSTD_extDict); + } ++#endif + + /* note : no btultra2 variant for extDict nor dictMatchState, + * because btultra2 is not meant to work with dictionaries +diff --git a/lib/zstd/compress/zstd_opt.h b/lib/zstd/compress/zstd_opt.h +index 22b862858ba7..ac1b743d27cd 100644 +--- a/lib/zstd/compress/zstd_opt.h ++++ b/lib/zstd/compress/zstd_opt.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -14,30 +15,40 @@ + + #include "zstd_compress_internal.h" + ++#if !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR) \ ++ || !defined(ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR) + /* used in ZSTD_loadDictionaryContent() */ + void ZSTD_updateTree(ZSTD_matchState_t* ms, const BYTE* ip, const BYTE* iend); ++#endif + ++#ifndef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR + size_t ZSTD_compressBlock_btopt( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_btultra( ++size_t ZSTD_compressBlock_btopt_dictMatchState( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +-size_t ZSTD_compressBlock_btultra2( ++size_t ZSTD_compressBlock_btopt_extDict( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); + ++#define ZSTD_COMPRESSBLOCK_BTOPT ZSTD_compressBlock_btopt ++#define ZSTD_COMPRESSBLOCK_BTOPT_DICTMATCHSTATE ZSTD_compressBlock_btopt_dictMatchState ++#define ZSTD_COMPRESSBLOCK_BTOPT_EXTDICT ZSTD_compressBlock_btopt_extDict ++#else ++#define ZSTD_COMPRESSBLOCK_BTOPT NULL ++#define ZSTD_COMPRESSBLOCK_BTOPT_DICTMATCHSTATE NULL ++#define ZSTD_COMPRESSBLOCK_BTOPT_EXTDICT NULL ++#endif + +-size_t ZSTD_compressBlock_btopt_dictMatchState( ++#ifndef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR ++size_t ZSTD_compressBlock_btultra( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); + size_t ZSTD_compressBlock_btultra_dictMatchState( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +- +-size_t ZSTD_compressBlock_btopt_extDict( +- ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +- void const* src, size_t srcSize); + size_t 
ZSTD_compressBlock_btultra_extDict( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +@@ -45,6 +56,20 @@ size_t ZSTD_compressBlock_btultra_extDict( + /* note : no btultra2 variant for extDict nor dictMatchState, + * because btultra2 is not meant to work with dictionaries + * and is only specific for the first block (no prefix) */ ++size_t ZSTD_compressBlock_btultra2( ++ ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], ++ void const* src, size_t srcSize); ++ ++#define ZSTD_COMPRESSBLOCK_BTULTRA ZSTD_compressBlock_btultra ++#define ZSTD_COMPRESSBLOCK_BTULTRA_DICTMATCHSTATE ZSTD_compressBlock_btultra_dictMatchState ++#define ZSTD_COMPRESSBLOCK_BTULTRA_EXTDICT ZSTD_compressBlock_btultra_extDict ++#define ZSTD_COMPRESSBLOCK_BTULTRA2 ZSTD_compressBlock_btultra2 ++#else ++#define ZSTD_COMPRESSBLOCK_BTULTRA NULL ++#define ZSTD_COMPRESSBLOCK_BTULTRA_DICTMATCHSTATE NULL ++#define ZSTD_COMPRESSBLOCK_BTULTRA_EXTDICT NULL ++#define ZSTD_COMPRESSBLOCK_BTULTRA2 NULL ++#endif + + + #endif /* ZSTD_OPT_H */ +diff --git a/lib/zstd/decompress/huf_decompress.c b/lib/zstd/decompress/huf_decompress.c +index 60958afebc41..ac8b87f48f84 100644 +--- a/lib/zstd/decompress/huf_decompress.c ++++ b/lib/zstd/decompress/huf_decompress.c +@@ -1,7 +1,8 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* ****************************************************************** + * huff0 huffman decoder, + * part of Finite State Entropy library +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * + * You can contact the author at : + * - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy +@@ -19,10 +20,10 @@ + #include "../common/compiler.h" + #include "../common/bitstream.h" /* BIT_* */ + #include "../common/fse.h" /* to compress headers */ +-#define HUF_STATIC_LINKING_ONLY + #include "../common/huf.h" + #include "../common/error_private.h" + #include "../common/zstd_internal.h" ++#include "../common/bits.h" /* ZSTD_highbit32, ZSTD_countTrailingZeros64 */ + + /* ************************************************************** + * Constants +@@ -34,6 +35,12 @@ + * Macros + ****************************************************************/ + ++#ifdef HUF_DISABLE_FAST_DECODE ++# define HUF_ENABLE_FAST_DECODE 0 ++#else ++# define HUF_ENABLE_FAST_DECODE 1 ++#endif ++ + /* These two optional macros force the use one way or another of the two + * Huffman decompression implementations. You can't force in both directions + * at the same time. +@@ -43,27 +50,25 @@ + #error "Cannot force the use of the X1 and X2 decoders at the same time!" + #endif + +-#if ZSTD_ENABLE_ASM_X86_64_BMI2 && DYNAMIC_BMI2 +-# define HUF_ASM_X86_64_BMI2_ATTRS BMI2_TARGET_ATTRIBUTE ++/* When DYNAMIC_BMI2 is enabled, fast decoders are only called when bmi2 is ++ * supported at runtime, so we can add the BMI2 target attribute. ++ * When it is disabled, we will still get BMI2 if it is enabled statically. 
++ */ ++#if DYNAMIC_BMI2 ++# define HUF_FAST_BMI2_ATTRS BMI2_TARGET_ATTRIBUTE + #else +-# define HUF_ASM_X86_64_BMI2_ATTRS ++# define HUF_FAST_BMI2_ATTRS + #endif + + #define HUF_EXTERN_C + #define HUF_ASM_DECL HUF_EXTERN_C + +-#if DYNAMIC_BMI2 || (ZSTD_ENABLE_ASM_X86_64_BMI2 && defined(__BMI2__)) ++#if DYNAMIC_BMI2 + # define HUF_NEED_BMI2_FUNCTION 1 + #else + # define HUF_NEED_BMI2_FUNCTION 0 + #endif + +-#if !(ZSTD_ENABLE_ASM_X86_64_BMI2 && defined(__BMI2__)) +-# define HUF_NEED_DEFAULT_FUNCTION 1 +-#else +-# define HUF_NEED_DEFAULT_FUNCTION 0 +-#endif +- + /* ************************************************************** + * Error Management + ****************************************************************/ +@@ -80,6 +85,11 @@ + /* ************************************************************** + * BMI2 Variant Wrappers + ****************************************************************/ ++typedef size_t (*HUF_DecompressUsingDTableFn)(void *dst, size_t dstSize, ++ const void *cSrc, ++ size_t cSrcSize, ++ const HUF_DTable *DTable); ++ + #if DYNAMIC_BMI2 + + #define HUF_DGEN(fn) \ +@@ -101,9 +111,9 @@ + } \ + \ + static size_t fn(void* dst, size_t dstSize, void const* cSrc, \ +- size_t cSrcSize, HUF_DTable const* DTable, int bmi2) \ ++ size_t cSrcSize, HUF_DTable const* DTable, int flags) \ + { \ +- if (bmi2) { \ ++ if (flags & HUF_flags_bmi2) { \ + return fn##_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); \ + } \ + return fn##_default(dst, dstSize, cSrc, cSrcSize, DTable); \ +@@ -113,9 +123,9 @@ + + #define HUF_DGEN(fn) \ + static size_t fn(void* dst, size_t dstSize, void const* cSrc, \ +- size_t cSrcSize, HUF_DTable const* DTable, int bmi2) \ ++ size_t cSrcSize, HUF_DTable const* DTable, int flags) \ + { \ +- (void)bmi2; \ ++ (void)flags; \ + return fn##_body(dst, dstSize, cSrc, cSrcSize, DTable); \ + } + +@@ -134,43 +144,66 @@ static DTableDesc HUF_getDTableDesc(const HUF_DTable* table) + return dtd; + } + +-#if ZSTD_ENABLE_ASM_X86_64_BMI2 +- +-static size_t 
HUF_initDStream(BYTE const* ip) { ++static size_t HUF_initFastDStream(BYTE const* ip) { + BYTE const lastByte = ip[7]; +- size_t const bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; ++ size_t const bitsConsumed = lastByte ? 8 - ZSTD_highbit32(lastByte) : 0; + size_t const value = MEM_readLEST(ip) | 1; + assert(bitsConsumed <= 8); ++ assert(sizeof(size_t) == 8); + return value << bitsConsumed; + } ++ ++ ++/* ++ * The input/output arguments to the Huffman fast decoding loop: ++ * ++ * ip [in/out] - The input pointers, must be updated to reflect what is consumed. ++ * op [in/out] - The output pointers, must be updated to reflect what is written. ++ * bits [in/out] - The bitstream containers, must be updated to reflect the current state. ++ * dt [in] - The decoding table. ++ * ilowest [in] - The beginning of the valid range of the input. Decoders may read ++ * down to this pointer. It may be below iend[0]. ++ * oend [in] - The end of the output stream. op[3] must not cross oend. ++ * iend [in] - The end of each input stream. ip[i] may cross iend[i], ++ * as long as it is above ilowest, but that indicates corruption. ++ */ + typedef struct { + BYTE const* ip[4]; + BYTE* op[4]; + U64 bits[4]; + void const* dt; +- BYTE const* ilimit; ++ BYTE const* ilowest; + BYTE* oend; + BYTE const* iend[4]; +-} HUF_DecompressAsmArgs; ++} HUF_DecompressFastArgs; ++ ++typedef void (*HUF_DecompressFastLoopFn)(HUF_DecompressFastArgs*); + + /* +- * Initializes args for the asm decoding loop. +- * @returns 0 on success +- * 1 if the fallback implementation should be used. ++ * Initializes args for the fast decoding loop. ++ * @returns 1 on success ++ * 0 if the fallback implementation should be used. + * Or an error code on failure. 
+ */ +-static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, size_t dstSize, void const* src, size_t srcSize, const HUF_DTable* DTable) ++static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* dst, size_t dstSize, void const* src, size_t srcSize, const HUF_DTable* DTable) + { + void const* dt = DTable + 1; + U32 const dtLog = HUF_getDTableDesc(DTable).tableLog; + +- const BYTE* const ilimit = (const BYTE*)src + 6 + 8; ++ const BYTE* const istart = (const BYTE*)src; + +- BYTE* const oend = (BYTE*)dst + dstSize; ++ BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize); + +- /* The following condition is false on x32 platform, +- * but HUF_asm is not compatible with this ABI */ +- if (!(MEM_isLittleEndian() && !MEM_32bits())) return 1; ++ /* The fast decoding loop assumes 64-bit little-endian. ++ * This condition is false on x32. ++ */ ++ if (!MEM_isLittleEndian() || MEM_32bits()) ++ return 0; ++ ++ /* Avoid nullptr addition */ ++ if (dstSize == 0) ++ return 0; ++ assert(dst != NULL); + + /* strict minimum : jump table + 1 byte per stream */ + if (srcSize < 10) +@@ -181,11 +214,10 @@ static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, + * On small inputs we don't have enough data to trigger the fast loop, so use the old decoder. + */ + if (dtLog != HUF_DECODER_FAST_TABLELOG) +- return 1; ++ return 0; + + /* Read the jump table. 
*/ + { +- const BYTE* const istart = (const BYTE*)src; + size_t const length1 = MEM_readLE16(istart); + size_t const length2 = MEM_readLE16(istart+2); + size_t const length3 = MEM_readLE16(istart+4); +@@ -195,13 +227,11 @@ static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, + args->iend[2] = args->iend[1] + length2; + args->iend[3] = args->iend[2] + length3; + +- /* HUF_initDStream() requires this, and this small of an input ++ /* HUF_initFastDStream() requires this, and this small of an input + * won't benefit from the ASM loop anyways. +- * length1 must be >= 16 so that ip[0] >= ilimit before the loop +- * starts. + */ +- if (length1 < 16 || length2 < 8 || length3 < 8 || length4 < 8) +- return 1; ++ if (length1 < 8 || length2 < 8 || length3 < 8 || length4 < 8) ++ return 0; + if (length4 > srcSize) return ERROR(corruption_detected); /* overflow */ + } + /* ip[] contains the position that is currently loaded into bits[]. */ +@@ -218,7 +248,7 @@ static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, + + /* No point to call the ASM loop for tiny outputs. */ + if (args->op[3] >= oend) +- return 1; ++ return 0; + + /* bits[] is the bit container. + * It is read from the MSB down to the LSB. +@@ -227,24 +257,25 @@ static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, + * set, so that CountTrailingZeros(bits[]) can be used + * to count how many bits we've consumed. + */ +- args->bits[0] = HUF_initDStream(args->ip[0]); +- args->bits[1] = HUF_initDStream(args->ip[1]); +- args->bits[2] = HUF_initDStream(args->ip[2]); +- args->bits[3] = HUF_initDStream(args->ip[3]); +- +- /* If ip[] >= ilimit, it is guaranteed to be safe to +- * reload bits[]. It may be beyond its section, but is +- * guaranteed to be valid (>= istart). 
+- */ +- args->ilimit = ilimit; ++ args->bits[0] = HUF_initFastDStream(args->ip[0]); ++ args->bits[1] = HUF_initFastDStream(args->ip[1]); ++ args->bits[2] = HUF_initFastDStream(args->ip[2]); ++ args->bits[3] = HUF_initFastDStream(args->ip[3]); ++ ++ /* The decoders must be sure to never read beyond ilowest. ++ * This is lower than iend[0], but allowing decoders to read ++ * down to ilowest can allow an extra iteration or two in the ++ * fast loop. ++ */ ++ args->ilowest = istart; + + args->oend = oend; + args->dt = dt; + +- return 0; ++ return 1; + } + +-static size_t HUF_initRemainingDStream(BIT_DStream_t* bit, HUF_DecompressAsmArgs const* args, int stream, BYTE* segmentEnd) ++static size_t HUF_initRemainingDStream(BIT_DStream_t* bit, HUF_DecompressFastArgs const* args, int stream, BYTE* segmentEnd) + { + /* Validate that we haven't overwritten. */ + if (args->op[stream] > segmentEnd) +@@ -258,15 +289,33 @@ static size_t HUF_initRemainingDStream(BIT_DStream_t* bit, HUF_DecompressAsmArgs + return ERROR(corruption_detected); + + /* Construct the BIT_DStream_t. */ +- bit->bitContainer = MEM_readLE64(args->ip[stream]); +- bit->bitsConsumed = ZSTD_countTrailingZeros((size_t)args->bits[stream]); +- bit->start = (const char*)args->iend[0]; ++ assert(sizeof(size_t) == 8); ++ bit->bitContainer = MEM_readLEST(args->ip[stream]); ++ bit->bitsConsumed = ZSTD_countTrailingZeros64(args->bits[stream]); ++ bit->start = (const char*)args->ilowest; + bit->limitPtr = bit->start + sizeof(size_t); + bit->ptr = (const char*)args->ip[stream]; + + return 0; + } +-#endif ++ ++/* Calls X(N) for each stream 0, 1, 2, 3. */ ++#define HUF_4X_FOR_EACH_STREAM(X) \ ++ do { \ ++ X(0); \ ++ X(1); \ ++ X(2); \ ++ X(3); \ ++ } while (0) ++ ++/* Calls X(N, var) for each stream 0, 1, 2, 3. 
*/ ++#define HUF_4X_FOR_EACH_STREAM_WITH_VAR(X, var) \ ++ do { \ ++ X(0, (var)); \ ++ X(1, (var)); \ ++ X(2, (var)); \ ++ X(3, (var)); \ ++ } while (0) + + + #ifndef HUF_FORCE_DECOMPRESS_X2 +@@ -283,10 +332,11 @@ typedef struct { BYTE nbBits; BYTE byte; } HUF_DEltX1; /* single-symbol decodi + static U64 HUF_DEltX1_set4(BYTE symbol, BYTE nbBits) { + U64 D4; + if (MEM_isLittleEndian()) { +- D4 = (symbol << 8) + nbBits; ++ D4 = (U64)((symbol << 8) + nbBits); + } else { +- D4 = symbol + (nbBits << 8); ++ D4 = (U64)(symbol + (nbBits << 8)); + } ++ assert(D4 < (1U << 16)); + D4 *= 0x0001000100010001ULL; + return D4; + } +@@ -329,13 +379,7 @@ typedef struct { + BYTE huffWeight[HUF_SYMBOLVALUE_MAX + 1]; + } HUF_ReadDTableX1_Workspace; + +- +-size_t HUF_readDTableX1_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize) +-{ +- return HUF_readDTableX1_wksp_bmi2(DTable, src, srcSize, workSpace, wkspSize, /* bmi2 */ 0); +-} +- +-size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int bmi2) ++size_t HUF_readDTableX1_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int flags) + { + U32 tableLog = 0; + U32 nbSymbols = 0; +@@ -350,7 +394,7 @@ size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t sr + DEBUG_STATIC_ASSERT(sizeof(DTableDesc) == sizeof(HUF_DTable)); + /* ZSTD_memset(huffWeight, 0, sizeof(huffWeight)); */ /* is not necessary, even though some analyzer complain ... 
*/ + +- iSize = HUF_readStats_wksp(wksp->huffWeight, HUF_SYMBOLVALUE_MAX + 1, wksp->rankVal, &nbSymbols, &tableLog, src, srcSize, wksp->statsWksp, sizeof(wksp->statsWksp), bmi2); ++ iSize = HUF_readStats_wksp(wksp->huffWeight, HUF_SYMBOLVALUE_MAX + 1, wksp->rankVal, &nbSymbols, &tableLog, src, srcSize, wksp->statsWksp, sizeof(wksp->statsWksp), flags); + if (HUF_isError(iSize)) return iSize; + + +@@ -377,9 +421,8 @@ size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t sr + * rankStart[0] is not filled because there are no entries in the table for + * weight 0. + */ +- { +- int n; +- int nextRankStart = 0; ++ { int n; ++ U32 nextRankStart = 0; + int const unroll = 4; + int const nLimit = (int)nbSymbols - unroll + 1; + for (n=0; n<(int)tableLog+1; n++) { +@@ -406,10 +449,9 @@ size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t sr + * We can switch based on the length to a different inner loop which is + * optimized for that particular case. 
+ */ +- { +- U32 w; +- int symbol=wksp->rankVal[0]; +- int rankStart=0; ++ { U32 w; ++ int symbol = wksp->rankVal[0]; ++ int rankStart = 0; + for (w=1; wrankVal[w]; + int const length = (1 << w) >> 1; +@@ -483,15 +525,19 @@ HUF_decodeSymbolX1(BIT_DStream_t* Dstream, const HUF_DEltX1* dt, const U32 dtLog + } + + #define HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr) \ +- *ptr++ = HUF_decodeSymbolX1(DStreamPtr, dt, dtLog) ++ do { *ptr++ = HUF_decodeSymbolX1(DStreamPtr, dt, dtLog); } while (0) + +-#define HUF_DECODE_SYMBOLX1_1(ptr, DStreamPtr) \ +- if (MEM_64bits() || (HUF_TABLELOG_MAX<=12)) \ +- HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr) ++#define HUF_DECODE_SYMBOLX1_1(ptr, DStreamPtr) \ ++ do { \ ++ if (MEM_64bits() || (HUF_TABLELOG_MAX<=12)) \ ++ HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr); \ ++ } while (0) + +-#define HUF_DECODE_SYMBOLX1_2(ptr, DStreamPtr) \ +- if (MEM_64bits()) \ +- HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr) ++#define HUF_DECODE_SYMBOLX1_2(ptr, DStreamPtr) \ ++ do { \ ++ if (MEM_64bits()) \ ++ HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr); \ ++ } while (0) + + HINT_INLINE size_t + HUF_decodeStreamX1(BYTE* p, BIT_DStream_t* const bitDPtr, BYTE* const pEnd, const HUF_DEltX1* const dt, const U32 dtLog) +@@ -519,7 +565,7 @@ HUF_decodeStreamX1(BYTE* p, BIT_DStream_t* const bitDPtr, BYTE* const pEnd, cons + while (p < pEnd) + HUF_DECODE_SYMBOLX1_0(p, bitDPtr); + +- return pEnd-pStart; ++ return (size_t)(pEnd-pStart); + } + + FORCE_INLINE_TEMPLATE size_t +@@ -529,7 +575,7 @@ HUF_decompress1X1_usingDTable_internal_body( + const HUF_DTable* DTable) + { + BYTE* op = (BYTE*)dst; +- BYTE* const oend = op + dstSize; ++ BYTE* const oend = ZSTD_maybeNullPtrAdd(op, dstSize); + const void* dtPtr = DTable + 1; + const HUF_DEltX1* const dt = (const HUF_DEltX1*)dtPtr; + BIT_DStream_t bitD; +@@ -545,6 +591,10 @@ HUF_decompress1X1_usingDTable_internal_body( + return dstSize; + } + ++/* HUF_decompress4X1_usingDTable_internal_body(): ++ * Conditions : ++ * @dstSize >= 6 ++ */ + 
FORCE_INLINE_TEMPLATE size_t + HUF_decompress4X1_usingDTable_internal_body( + void* dst, size_t dstSize, +@@ -553,6 +603,7 @@ HUF_decompress4X1_usingDTable_internal_body( + { + /* Check */ + if (cSrcSize < 10) return ERROR(corruption_detected); /* strict minimum : jump table + 1 byte per stream */ ++ if (dstSize < 6) return ERROR(corruption_detected); /* stream 4-split doesn't work */ + + { const BYTE* const istart = (const BYTE*) cSrc; + BYTE* const ostart = (BYTE*) dst; +@@ -588,6 +639,7 @@ HUF_decompress4X1_usingDTable_internal_body( + + if (length4 > cSrcSize) return ERROR(corruption_detected); /* overflow */ + if (opStart4 > oend) return ERROR(corruption_detected); /* overflow */ ++ assert(dstSize >= 6); /* validated above */ + CHECK_F( BIT_initDStream(&bitD1, istart1, length1) ); + CHECK_F( BIT_initDStream(&bitD2, istart2, length2) ); + CHECK_F( BIT_initDStream(&bitD3, istart3, length3) ); +@@ -650,52 +702,173 @@ size_t HUF_decompress4X1_usingDTable_internal_bmi2(void* dst, size_t dstSize, vo + } + #endif + +-#if HUF_NEED_DEFAULT_FUNCTION + static + size_t HUF_decompress4X1_usingDTable_internal_default(void* dst, size_t dstSize, void const* cSrc, + size_t cSrcSize, HUF_DTable const* DTable) { + return HUF_decompress4X1_usingDTable_internal_body(dst, dstSize, cSrc, cSrcSize, DTable); + } +-#endif + + #if ZSTD_ENABLE_ASM_X86_64_BMI2 + +-HUF_ASM_DECL void HUF_decompress4X1_usingDTable_internal_bmi2_asm_loop(HUF_DecompressAsmArgs* args) ZSTDLIB_HIDDEN; ++HUF_ASM_DECL void HUF_decompress4X1_usingDTable_internal_fast_asm_loop(HUF_DecompressFastArgs* args) ZSTDLIB_HIDDEN; ++ ++#endif ++ ++static HUF_FAST_BMI2_ATTRS ++void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* args) ++{ ++ U64 bits[4]; ++ BYTE const* ip[4]; ++ BYTE* op[4]; ++ U16 const* const dtable = (U16 const*)args->dt; ++ BYTE* const oend = args->oend; ++ BYTE const* const ilowest = args->ilowest; ++ ++ /* Copy the arguments to local variables */ ++ ZSTD_memcpy(&bits, 
&args->bits, sizeof(bits)); ++ ZSTD_memcpy((void*)(&ip), &args->ip, sizeof(ip)); ++ ZSTD_memcpy(&op, &args->op, sizeof(op)); ++ ++ assert(MEM_isLittleEndian()); ++ assert(!MEM_32bits()); ++ ++ for (;;) { ++ BYTE* olimit; ++ int stream; ++ ++ /* Assert loop preconditions */ ++#ifndef NDEBUG ++ for (stream = 0; stream < 4; ++stream) { ++ assert(op[stream] <= (stream == 3 ? oend : op[stream + 1])); ++ assert(ip[stream] >= ilowest); ++ } ++#endif ++ /* Compute olimit */ ++ { ++ /* Each iteration produces 5 output symbols per stream */ ++ size_t const oiters = (size_t)(oend - op[3]) / 5; ++ /* Each iteration consumes up to 11 bits * 5 = 55 bits < 7 bytes ++ * per stream. ++ */ ++ size_t const iiters = (size_t)(ip[0] - ilowest) / 7; ++ /* We can safely run iters iterations before running bounds checks */ ++ size_t const iters = MIN(oiters, iiters); ++ size_t const symbols = iters * 5; ++ ++ /* We can simply check that op[3] < olimit, instead of checking all ++ * of our bounds, since we can't hit the other bounds until we've run ++ * iters iterations, which only happens when op[3] == olimit. ++ */ ++ olimit = op[3] + symbols; ++ ++ /* Exit fast decoding loop once we reach the end. */ ++ if (op[3] == olimit) ++ break; ++ ++ /* Exit the decoding loop if any input pointer has crossed the ++ * previous one. This indicates corruption, and a precondition ++ * to our loop is that ip[i] >= ip[0]. 
++ */ ++ for (stream = 1; stream < 4; ++stream) { ++ if (ip[stream] < ip[stream - 1]) ++ goto _out; ++ } ++ } ++ ++#ifndef NDEBUG ++ for (stream = 1; stream < 4; ++stream) { ++ assert(ip[stream] >= ip[stream - 1]); ++ } ++#endif ++ ++#define HUF_4X1_DECODE_SYMBOL(_stream, _symbol) \ ++ do { \ ++ int const index = (int)(bits[(_stream)] >> 53); \ ++ int const entry = (int)dtable[index]; \ ++ bits[(_stream)] <<= (entry & 0x3F); \ ++ op[(_stream)][(_symbol)] = (BYTE)((entry >> 8) & 0xFF); \ ++ } while (0) ++ ++#define HUF_4X1_RELOAD_STREAM(_stream) \ ++ do { \ ++ int const ctz = ZSTD_countTrailingZeros64(bits[(_stream)]); \ ++ int const nbBits = ctz & 7; \ ++ int const nbBytes = ctz >> 3; \ ++ op[(_stream)] += 5; \ ++ ip[(_stream)] -= nbBytes; \ ++ bits[(_stream)] = MEM_read64(ip[(_stream)]) | 1; \ ++ bits[(_stream)] <<= nbBits; \ ++ } while (0) ++ ++ /* Manually unroll the loop because compilers don't consistently ++ * unroll the inner loops, which destroys performance. ++ */ ++ do { ++ /* Decode 5 symbols in each of the 4 streams */ ++ HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 0); ++ HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 1); ++ HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 2); ++ HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 3); ++ HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 4); ++ ++ /* Reload each of the 4 the bitstreams */ ++ HUF_4X_FOR_EACH_STREAM(HUF_4X1_RELOAD_STREAM); ++ } while (op[3] < olimit); ++ ++#undef HUF_4X1_DECODE_SYMBOL ++#undef HUF_4X1_RELOAD_STREAM ++ } + +-static HUF_ASM_X86_64_BMI2_ATTRS ++_out: ++ ++ /* Save the final values of each of the state variables back to args. 
*/ ++ ZSTD_memcpy(&args->bits, &bits, sizeof(bits)); ++ ZSTD_memcpy((void*)(&args->ip), &ip, sizeof(ip)); ++ ZSTD_memcpy(&args->op, &op, sizeof(op)); ++} ++ ++/* ++ * @returns @p dstSize on success (>= 6) ++ * 0 if the fallback implementation should be used ++ * An error if an error occurred ++ */ ++static HUF_FAST_BMI2_ATTRS + size_t +-HUF_decompress4X1_usingDTable_internal_bmi2_asm( ++HUF_decompress4X1_usingDTable_internal_fast( + void* dst, size_t dstSize, + const void* cSrc, size_t cSrcSize, +- const HUF_DTable* DTable) ++ const HUF_DTable* DTable, ++ HUF_DecompressFastLoopFn loopFn) + { + void const* dt = DTable + 1; +- const BYTE* const iend = (const BYTE*)cSrc + 6; +- BYTE* const oend = (BYTE*)dst + dstSize; +- HUF_DecompressAsmArgs args; +- { +- size_t const ret = HUF_DecompressAsmArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable); +- FORWARD_IF_ERROR(ret, "Failed to init asm args"); +- if (ret != 0) +- return HUF_decompress4X1_usingDTable_internal_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); ++ BYTE const* const ilowest = (BYTE const*)cSrc; ++ BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize); ++ HUF_DecompressFastArgs args; ++ { size_t const ret = HUF_DecompressFastArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable); ++ FORWARD_IF_ERROR(ret, "Failed to init fast loop args"); ++ if (ret == 0) ++ return 0; + } + +- assert(args.ip[0] >= args.ilimit); +- HUF_decompress4X1_usingDTable_internal_bmi2_asm_loop(&args); ++ assert(args.ip[0] >= args.ilowest); ++ loopFn(&args); + +- /* Our loop guarantees that ip[] >= ilimit and that we haven't ++ /* Our loop guarantees that ip[] >= ilowest and that we haven't + * overwritten any op[]. 
+ */ +- assert(args.ip[0] >= iend); +- assert(args.ip[1] >= iend); +- assert(args.ip[2] >= iend); +- assert(args.ip[3] >= iend); ++ assert(args.ip[0] >= ilowest); ++ assert(args.ip[0] >= ilowest); ++ assert(args.ip[1] >= ilowest); ++ assert(args.ip[2] >= ilowest); ++ assert(args.ip[3] >= ilowest); + assert(args.op[3] <= oend); +- (void)iend; ++ ++ assert(ilowest == args.ilowest); ++ assert(ilowest + 6 == args.iend[0]); ++ (void)ilowest; + + /* finish bit streams one by one. */ +- { +- size_t const segmentSize = (dstSize+3) / 4; ++ { size_t const segmentSize = (dstSize+3) / 4; + BYTE* segmentEnd = (BYTE*)dst; + int i; + for (i = 0; i < 4; ++i) { +@@ -712,97 +885,59 @@ HUF_decompress4X1_usingDTable_internal_bmi2_asm( + } + + /* decoded size */ ++ assert(dstSize != 0); + return dstSize; + } +-#endif /* ZSTD_ENABLE_ASM_X86_64_BMI2 */ +- +-typedef size_t (*HUF_decompress_usingDTable_t)(void *dst, size_t dstSize, +- const void *cSrc, +- size_t cSrcSize, +- const HUF_DTable *DTable); + + HUF_DGEN(HUF_decompress1X1_usingDTable_internal) + + static size_t HUF_decompress4X1_usingDTable_internal(void* dst, size_t dstSize, void const* cSrc, +- size_t cSrcSize, HUF_DTable const* DTable, int bmi2) ++ size_t cSrcSize, HUF_DTable const* DTable, int flags) + { ++ HUF_DecompressUsingDTableFn fallbackFn = HUF_decompress4X1_usingDTable_internal_default; ++ HUF_DecompressFastLoopFn loopFn = HUF_decompress4X1_usingDTable_internal_fast_c_loop; ++ + #if DYNAMIC_BMI2 +- if (bmi2) { ++ if (flags & HUF_flags_bmi2) { ++ fallbackFn = HUF_decompress4X1_usingDTable_internal_bmi2; + # if ZSTD_ENABLE_ASM_X86_64_BMI2 +- return HUF_decompress4X1_usingDTable_internal_bmi2_asm(dst, dstSize, cSrc, cSrcSize, DTable); +-# else +- return HUF_decompress4X1_usingDTable_internal_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); ++ if (!(flags & HUF_flags_disableAsm)) { ++ loopFn = HUF_decompress4X1_usingDTable_internal_fast_asm_loop; ++ } + # endif ++ } else { ++ return fallbackFn(dst, dstSize, cSrc, cSrcSize, 
DTable); + } +-#else +- (void)bmi2; + #endif + + #if ZSTD_ENABLE_ASM_X86_64_BMI2 && defined(__BMI2__) +- return HUF_decompress4X1_usingDTable_internal_bmi2_asm(dst, dstSize, cSrc, cSrcSize, DTable); +-#else +- return HUF_decompress4X1_usingDTable_internal_default(dst, dstSize, cSrc, cSrcSize, DTable); ++ if (!(flags & HUF_flags_disableAsm)) { ++ loopFn = HUF_decompress4X1_usingDTable_internal_fast_asm_loop; ++ } + #endif +-} +- +- +-size_t HUF_decompress1X1_usingDTable( +- void* dst, size_t dstSize, +- const void* cSrc, size_t cSrcSize, +- const HUF_DTable* DTable) +-{ +- DTableDesc dtd = HUF_getDTableDesc(DTable); +- if (dtd.tableType != 0) return ERROR(GENERIC); +- return HUF_decompress1X1_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); +-} + +-size_t HUF_decompress1X1_DCtx_wksp(HUF_DTable* DCtx, void* dst, size_t dstSize, +- const void* cSrc, size_t cSrcSize, +- void* workSpace, size_t wkspSize) +-{ +- const BYTE* ip = (const BYTE*) cSrc; +- +- size_t const hSize = HUF_readDTableX1_wksp(DCtx, cSrc, cSrcSize, workSpace, wkspSize); +- if (HUF_isError(hSize)) return hSize; +- if (hSize >= cSrcSize) return ERROR(srcSize_wrong); +- ip += hSize; cSrcSize -= hSize; +- +- return HUF_decompress1X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, DCtx, /* bmi2 */ 0); +-} +- +- +-size_t HUF_decompress4X1_usingDTable( +- void* dst, size_t dstSize, +- const void* cSrc, size_t cSrcSize, +- const HUF_DTable* DTable) +-{ +- DTableDesc dtd = HUF_getDTableDesc(DTable); +- if (dtd.tableType != 0) return ERROR(GENERIC); +- return HUF_decompress4X1_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); ++ if (HUF_ENABLE_FAST_DECODE && !(flags & HUF_flags_disableFast)) { ++ size_t const ret = HUF_decompress4X1_usingDTable_internal_fast(dst, dstSize, cSrc, cSrcSize, DTable, loopFn); ++ if (ret != 0) ++ return ret; ++ } ++ return fallbackFn(dst, dstSize, cSrc, cSrcSize, DTable); + } + +-static size_t 
HUF_decompress4X1_DCtx_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, ++static size_t HUF_decompress4X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, + const void* cSrc, size_t cSrcSize, +- void* workSpace, size_t wkspSize, int bmi2) ++ void* workSpace, size_t wkspSize, int flags) + { + const BYTE* ip = (const BYTE*) cSrc; + +- size_t const hSize = HUF_readDTableX1_wksp_bmi2(dctx, cSrc, cSrcSize, workSpace, wkspSize, bmi2); ++ size_t const hSize = HUF_readDTableX1_wksp(dctx, cSrc, cSrcSize, workSpace, wkspSize, flags); + if (HUF_isError(hSize)) return hSize; + if (hSize >= cSrcSize) return ERROR(srcSize_wrong); + ip += hSize; cSrcSize -= hSize; + +- return HUF_decompress4X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, bmi2); +-} +- +-size_t HUF_decompress4X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, +- const void* cSrc, size_t cSrcSize, +- void* workSpace, size_t wkspSize) +-{ +- return HUF_decompress4X1_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, 0); ++ return HUF_decompress4X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, flags); + } + +- + #endif /* HUF_FORCE_DECOMPRESS_X2 */ + + +@@ -985,7 +1120,7 @@ static void HUF_fillDTableX2Level2(HUF_DEltX2* DTable, U32 targetLog, const U32 + + static void HUF_fillDTableX2(HUF_DEltX2* DTable, const U32 targetLog, + const sortedSymbol_t* sortedList, +- const U32* rankStart, rankValCol_t *rankValOrigin, const U32 maxWeight, ++ const U32* rankStart, rankValCol_t* rankValOrigin, const U32 maxWeight, + const U32 nbBitsBaseline) + { + U32* const rankVal = rankValOrigin[0]; +@@ -1040,14 +1175,7 @@ typedef struct { + + size_t HUF_readDTableX2_wksp(HUF_DTable* DTable, + const void* src, size_t srcSize, +- void* workSpace, size_t wkspSize) +-{ +- return HUF_readDTableX2_wksp_bmi2(DTable, src, srcSize, workSpace, wkspSize, /* bmi2 */ 0); +-} +- +-size_t HUF_readDTableX2_wksp_bmi2(HUF_DTable* DTable, +- const void* src, size_t srcSize, +- void* workSpace, 
size_t wkspSize, int bmi2) ++ void* workSpace, size_t wkspSize, int flags) + { + U32 tableLog, maxW, nbSymbols; + DTableDesc dtd = HUF_getDTableDesc(DTable); +@@ -1069,7 +1197,7 @@ size_t HUF_readDTableX2_wksp_bmi2(HUF_DTable* DTable, + if (maxTableLog > HUF_TABLELOG_MAX) return ERROR(tableLog_tooLarge); + /* ZSTD_memset(weightList, 0, sizeof(weightList)); */ /* is not necessary, even though some analyzer complain ... */ + +- iSize = HUF_readStats_wksp(wksp->weightList, HUF_SYMBOLVALUE_MAX + 1, wksp->rankStats, &nbSymbols, &tableLog, src, srcSize, wksp->calleeWksp, sizeof(wksp->calleeWksp), bmi2); ++ iSize = HUF_readStats_wksp(wksp->weightList, HUF_SYMBOLVALUE_MAX + 1, wksp->rankStats, &nbSymbols, &tableLog, src, srcSize, wksp->calleeWksp, sizeof(wksp->calleeWksp), flags); + if (HUF_isError(iSize)) return iSize; + + /* check result */ +@@ -1159,15 +1287,19 @@ HUF_decodeLastSymbolX2(void* op, BIT_DStream_t* DStream, const HUF_DEltX2* dt, c + } + + #define HUF_DECODE_SYMBOLX2_0(ptr, DStreamPtr) \ +- ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog) ++ do { ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog); } while (0) + +-#define HUF_DECODE_SYMBOLX2_1(ptr, DStreamPtr) \ +- if (MEM_64bits() || (HUF_TABLELOG_MAX<=12)) \ +- ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog) ++#define HUF_DECODE_SYMBOLX2_1(ptr, DStreamPtr) \ ++ do { \ ++ if (MEM_64bits() || (HUF_TABLELOG_MAX<=12)) \ ++ ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog); \ ++ } while (0) + +-#define HUF_DECODE_SYMBOLX2_2(ptr, DStreamPtr) \ +- if (MEM_64bits()) \ +- ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog) ++#define HUF_DECODE_SYMBOLX2_2(ptr, DStreamPtr) \ ++ do { \ ++ if (MEM_64bits()) \ ++ ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog); \ ++ } while (0) + + HINT_INLINE size_t + HUF_decodeStreamX2(BYTE* p, BIT_DStream_t* bitDPtr, BYTE* const pEnd, +@@ -1227,7 +1359,7 @@ HUF_decompress1X2_usingDTable_internal_body( + + /* decode */ + { BYTE* const ostart = (BYTE*) dst; +- 
BYTE* const oend = ostart + dstSize; ++ BYTE* const oend = ZSTD_maybeNullPtrAdd(ostart, dstSize); + const void* const dtPtr = DTable+1; /* force compiler to not use strict-aliasing */ + const HUF_DEltX2* const dt = (const HUF_DEltX2*)dtPtr; + DTableDesc const dtd = HUF_getDTableDesc(DTable); +@@ -1240,6 +1372,11 @@ HUF_decompress1X2_usingDTable_internal_body( + /* decoded size */ + return dstSize; + } ++ ++/* HUF_decompress4X2_usingDTable_internal_body(): ++ * Conditions: ++ * @dstSize >= 6 ++ */ + FORCE_INLINE_TEMPLATE size_t + HUF_decompress4X2_usingDTable_internal_body( + void* dst, size_t dstSize, +@@ -1247,6 +1384,7 @@ HUF_decompress4X2_usingDTable_internal_body( + const HUF_DTable* DTable) + { + if (cSrcSize < 10) return ERROR(corruption_detected); /* strict minimum : jump table + 1 byte per stream */ ++ if (dstSize < 6) return ERROR(corruption_detected); /* stream 4-split doesn't work */ + + { const BYTE* const istart = (const BYTE*) cSrc; + BYTE* const ostart = (BYTE*) dst; +@@ -1280,8 +1418,9 @@ HUF_decompress4X2_usingDTable_internal_body( + DTableDesc const dtd = HUF_getDTableDesc(DTable); + U32 const dtLog = dtd.tableLog; + +- if (length4 > cSrcSize) return ERROR(corruption_detected); /* overflow */ +- if (opStart4 > oend) return ERROR(corruption_detected); /* overflow */ ++ if (length4 > cSrcSize) return ERROR(corruption_detected); /* overflow */ ++ if (opStart4 > oend) return ERROR(corruption_detected); /* overflow */ ++ assert(dstSize >= 6 /* validated above */); + CHECK_F( BIT_initDStream(&bitD1, istart1, length1) ); + CHECK_F( BIT_initDStream(&bitD2, istart2, length2) ); + CHECK_F( BIT_initDStream(&bitD3, istart3, length3) ); +@@ -1366,44 +1505,191 @@ size_t HUF_decompress4X2_usingDTable_internal_bmi2(void* dst, size_t dstSize, vo + } + #endif + +-#if HUF_NEED_DEFAULT_FUNCTION + static + size_t HUF_decompress4X2_usingDTable_internal_default(void* dst, size_t dstSize, void const* cSrc, + size_t cSrcSize, HUF_DTable const* DTable) { + return 
HUF_decompress4X2_usingDTable_internal_body(dst, dstSize, cSrc, cSrcSize, DTable); + } +-#endif + + #if ZSTD_ENABLE_ASM_X86_64_BMI2 + +-HUF_ASM_DECL void HUF_decompress4X2_usingDTable_internal_bmi2_asm_loop(HUF_DecompressAsmArgs* args) ZSTDLIB_HIDDEN; ++HUF_ASM_DECL void HUF_decompress4X2_usingDTable_internal_fast_asm_loop(HUF_DecompressFastArgs* args) ZSTDLIB_HIDDEN; ++ ++#endif ++ ++static HUF_FAST_BMI2_ATTRS ++void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* args) ++{ ++ U64 bits[4]; ++ BYTE const* ip[4]; ++ BYTE* op[4]; ++ BYTE* oend[4]; ++ HUF_DEltX2 const* const dtable = (HUF_DEltX2 const*)args->dt; ++ BYTE const* const ilowest = args->ilowest; ++ ++ /* Copy the arguments to local registers. */ ++ ZSTD_memcpy(&bits, &args->bits, sizeof(bits)); ++ ZSTD_memcpy((void*)(&ip), &args->ip, sizeof(ip)); ++ ZSTD_memcpy(&op, &args->op, sizeof(op)); ++ ++ oend[0] = op[1]; ++ oend[1] = op[2]; ++ oend[2] = op[3]; ++ oend[3] = args->oend; ++ ++ assert(MEM_isLittleEndian()); ++ assert(!MEM_32bits()); ++ ++ for (;;) { ++ BYTE* olimit; ++ int stream; ++ ++ /* Assert loop preconditions */ ++#ifndef NDEBUG ++ for (stream = 0; stream < 4; ++stream) { ++ assert(op[stream] <= oend[stream]); ++ assert(ip[stream] >= ilowest); ++ } ++#endif ++ /* Compute olimit */ ++ { ++ /* Each loop does 5 table lookups for each of the 4 streams. ++ * Each table lookup consumes up to 11 bits of input, and produces ++ * up to 2 bytes of output. ++ */ ++ /* We can consume up to 7 bytes of input per iteration per stream. ++ * We also know that each input pointer is >= ip[0]. So we can run ++ * iters loops before running out of input. ++ */ ++ size_t iters = (size_t)(ip[0] - ilowest) / 7; ++ /* Each iteration can produce up to 10 bytes of output per stream. ++ * Each output stream my advance at different rates. So take the ++ * minimum number of safe iterations among all the output streams. 
++ */ ++ for (stream = 0; stream < 4; ++stream) { ++ size_t const oiters = (size_t)(oend[stream] - op[stream]) / 10; ++ iters = MIN(iters, oiters); ++ } ++ ++ /* Each iteration produces at least 5 output symbols. So until ++ * op[3] crosses olimit, we know we haven't executed iters ++ * iterations yet. This saves us maintaining an iters counter, ++ * at the expense of computing the remaining # of iterations ++ * more frequently. ++ */ ++ olimit = op[3] + (iters * 5); ++ ++ /* Exit the fast decoding loop once we reach the end. */ ++ if (op[3] == olimit) ++ break; ++ ++ /* Exit the decoding loop if any input pointer has crossed the ++ * previous one. This indicates corruption, and a precondition ++ * to our loop is that ip[i] >= ip[0]. ++ */ ++ for (stream = 1; stream < 4; ++stream) { ++ if (ip[stream] < ip[stream - 1]) ++ goto _out; ++ } ++ } ++ ++#ifndef NDEBUG ++ for (stream = 1; stream < 4; ++stream) { ++ assert(ip[stream] >= ip[stream - 1]); ++ } ++#endif + +-static HUF_ASM_X86_64_BMI2_ATTRS size_t +-HUF_decompress4X2_usingDTable_internal_bmi2_asm( ++#define HUF_4X2_DECODE_SYMBOL(_stream, _decode3) \ ++ do { \ ++ if ((_decode3) || (_stream) != 3) { \ ++ int const index = (int)(bits[(_stream)] >> 53); \ ++ HUF_DEltX2 const entry = dtable[index]; \ ++ MEM_write16(op[(_stream)], entry.sequence); \ ++ bits[(_stream)] <<= (entry.nbBits) & 0x3F; \ ++ op[(_stream)] += (entry.length); \ ++ } \ ++ } while (0) ++ ++#define HUF_4X2_RELOAD_STREAM(_stream) \ ++ do { \ ++ HUF_4X2_DECODE_SYMBOL(3, 1); \ ++ { \ ++ int const ctz = ZSTD_countTrailingZeros64(bits[(_stream)]); \ ++ int const nbBits = ctz & 7; \ ++ int const nbBytes = ctz >> 3; \ ++ ip[(_stream)] -= nbBytes; \ ++ bits[(_stream)] = MEM_read64(ip[(_stream)]) | 1; \ ++ bits[(_stream)] <<= nbBits; \ ++ } \ ++ } while (0) ++ ++ /* Manually unroll the loop because compilers don't consistently ++ * unroll the inner loops, which destroys performance. ++ */ ++ do { ++ /* Decode 5 symbols from each of the first 3 streams. 
++ * The final stream will be decoded during the reload phase ++ * to reduce register pressure. ++ */ ++ HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); ++ HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); ++ HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); ++ HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); ++ HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); ++ ++ /* Decode one symbol from the final stream */ ++ HUF_4X2_DECODE_SYMBOL(3, 1); ++ ++ /* Decode 4 symbols from the final stream & reload bitstreams. ++ * The final stream is reloaded last, meaning that all 5 symbols ++ * are decoded from the final stream before it is reloaded. ++ */ ++ HUF_4X_FOR_EACH_STREAM(HUF_4X2_RELOAD_STREAM); ++ } while (op[3] < olimit); ++ } ++ ++#undef HUF_4X2_DECODE_SYMBOL ++#undef HUF_4X2_RELOAD_STREAM ++ ++_out: ++ ++ /* Save the final values of each of the state variables back to args. */ ++ ZSTD_memcpy(&args->bits, &bits, sizeof(bits)); ++ ZSTD_memcpy((void*)(&args->ip), &ip, sizeof(ip)); ++ ZSTD_memcpy(&args->op, &op, sizeof(op)); ++} ++ ++ ++static HUF_FAST_BMI2_ATTRS size_t ++HUF_decompress4X2_usingDTable_internal_fast( + void* dst, size_t dstSize, + const void* cSrc, size_t cSrcSize, +- const HUF_DTable* DTable) { ++ const HUF_DTable* DTable, ++ HUF_DecompressFastLoopFn loopFn) { + void const* dt = DTable + 1; +- const BYTE* const iend = (const BYTE*)cSrc + 6; +- BYTE* const oend = (BYTE*)dst + dstSize; +- HUF_DecompressAsmArgs args; ++ const BYTE* const ilowest = (const BYTE*)cSrc; ++ BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize); ++ HUF_DecompressFastArgs args; + { +- size_t const ret = HUF_DecompressAsmArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable); ++ size_t const ret = HUF_DecompressFastArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable); + FORWARD_IF_ERROR(ret, "Failed to init asm args"); +- if (ret != 0) +- return HUF_decompress4X2_usingDTable_internal_bmi2(dst, dstSize, cSrc, cSrcSize, 
DTable); ++ if (ret == 0) ++ return 0; + } + +- assert(args.ip[0] >= args.ilimit); +- HUF_decompress4X2_usingDTable_internal_bmi2_asm_loop(&args); ++ assert(args.ip[0] >= args.ilowest); ++ loopFn(&args); + + /* note : op4 already verified within main loop */ +- assert(args.ip[0] >= iend); +- assert(args.ip[1] >= iend); +- assert(args.ip[2] >= iend); +- assert(args.ip[3] >= iend); ++ assert(args.ip[0] >= ilowest); ++ assert(args.ip[1] >= ilowest); ++ assert(args.ip[2] >= ilowest); ++ assert(args.ip[3] >= ilowest); + assert(args.op[3] <= oend); +- (void)iend; ++ ++ assert(ilowest == args.ilowest); ++ assert(ilowest + 6 == args.iend[0]); ++ (void)ilowest; + + /* finish bitStreams one by one */ + { +@@ -1426,91 +1712,72 @@ HUF_decompress4X2_usingDTable_internal_bmi2_asm( + /* decoded size */ + return dstSize; + } +-#endif /* ZSTD_ENABLE_ASM_X86_64_BMI2 */ + + static size_t HUF_decompress4X2_usingDTable_internal(void* dst, size_t dstSize, void const* cSrc, +- size_t cSrcSize, HUF_DTable const* DTable, int bmi2) ++ size_t cSrcSize, HUF_DTable const* DTable, int flags) + { ++ HUF_DecompressUsingDTableFn fallbackFn = HUF_decompress4X2_usingDTable_internal_default; ++ HUF_DecompressFastLoopFn loopFn = HUF_decompress4X2_usingDTable_internal_fast_c_loop; ++ + #if DYNAMIC_BMI2 +- if (bmi2) { ++ if (flags & HUF_flags_bmi2) { ++ fallbackFn = HUF_decompress4X2_usingDTable_internal_bmi2; + # if ZSTD_ENABLE_ASM_X86_64_BMI2 +- return HUF_decompress4X2_usingDTable_internal_bmi2_asm(dst, dstSize, cSrc, cSrcSize, DTable); +-# else +- return HUF_decompress4X2_usingDTable_internal_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); ++ if (!(flags & HUF_flags_disableAsm)) { ++ loopFn = HUF_decompress4X2_usingDTable_internal_fast_asm_loop; ++ } + # endif ++ } else { ++ return fallbackFn(dst, dstSize, cSrc, cSrcSize, DTable); + } +-#else +- (void)bmi2; + #endif + + #if ZSTD_ENABLE_ASM_X86_64_BMI2 && defined(__BMI2__) +- return HUF_decompress4X2_usingDTable_internal_bmi2_asm(dst, dstSize, cSrc, 
cSrcSize, DTable); +-#else +- return HUF_decompress4X2_usingDTable_internal_default(dst, dstSize, cSrc, cSrcSize, DTable); ++ if (!(flags & HUF_flags_disableAsm)) { ++ loopFn = HUF_decompress4X2_usingDTable_internal_fast_asm_loop; ++ } + #endif ++ ++ if (HUF_ENABLE_FAST_DECODE && !(flags & HUF_flags_disableFast)) { ++ size_t const ret = HUF_decompress4X2_usingDTable_internal_fast(dst, dstSize, cSrc, cSrcSize, DTable, loopFn); ++ if (ret != 0) ++ return ret; ++ } ++ return fallbackFn(dst, dstSize, cSrc, cSrcSize, DTable); + } + + HUF_DGEN(HUF_decompress1X2_usingDTable_internal) + +-size_t HUF_decompress1X2_usingDTable( +- void* dst, size_t dstSize, +- const void* cSrc, size_t cSrcSize, +- const HUF_DTable* DTable) +-{ +- DTableDesc dtd = HUF_getDTableDesc(DTable); +- if (dtd.tableType != 1) return ERROR(GENERIC); +- return HUF_decompress1X2_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); +-} +- + size_t HUF_decompress1X2_DCtx_wksp(HUF_DTable* DCtx, void* dst, size_t dstSize, + const void* cSrc, size_t cSrcSize, +- void* workSpace, size_t wkspSize) ++ void* workSpace, size_t wkspSize, int flags) + { + const BYTE* ip = (const BYTE*) cSrc; + + size_t const hSize = HUF_readDTableX2_wksp(DCtx, cSrc, cSrcSize, +- workSpace, wkspSize); ++ workSpace, wkspSize, flags); + if (HUF_isError(hSize)) return hSize; + if (hSize >= cSrcSize) return ERROR(srcSize_wrong); + ip += hSize; cSrcSize -= hSize; + +- return HUF_decompress1X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, DCtx, /* bmi2 */ 0); ++ return HUF_decompress1X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, DCtx, flags); + } + +- +-size_t HUF_decompress4X2_usingDTable( +- void* dst, size_t dstSize, +- const void* cSrc, size_t cSrcSize, +- const HUF_DTable* DTable) +-{ +- DTableDesc dtd = HUF_getDTableDesc(DTable); +- if (dtd.tableType != 1) return ERROR(GENERIC); +- return HUF_decompress4X2_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); +-} +- +-static size_t 
HUF_decompress4X2_DCtx_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, ++static size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, + const void* cSrc, size_t cSrcSize, +- void* workSpace, size_t wkspSize, int bmi2) ++ void* workSpace, size_t wkspSize, int flags) + { + const BYTE* ip = (const BYTE*) cSrc; + + size_t hSize = HUF_readDTableX2_wksp(dctx, cSrc, cSrcSize, +- workSpace, wkspSize); ++ workSpace, wkspSize, flags); + if (HUF_isError(hSize)) return hSize; + if (hSize >= cSrcSize) return ERROR(srcSize_wrong); + ip += hSize; cSrcSize -= hSize; + +- return HUF_decompress4X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, bmi2); ++ return HUF_decompress4X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, flags); + } + +-size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, +- const void* cSrc, size_t cSrcSize, +- void* workSpace, size_t wkspSize) +-{ +- return HUF_decompress4X2_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, /* bmi2 */ 0); +-} +- +- + #endif /* HUF_FORCE_DECOMPRESS_X1 */ + + +@@ -1518,44 +1785,6 @@ size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, + /* Universal decompression selectors */ + /* ***********************************/ + +-size_t HUF_decompress1X_usingDTable(void* dst, size_t maxDstSize, +- const void* cSrc, size_t cSrcSize, +- const HUF_DTable* DTable) +-{ +- DTableDesc const dtd = HUF_getDTableDesc(DTable); +-#if defined(HUF_FORCE_DECOMPRESS_X1) +- (void)dtd; +- assert(dtd.tableType == 0); +- return HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); +-#elif defined(HUF_FORCE_DECOMPRESS_X2) +- (void)dtd; +- assert(dtd.tableType == 1); +- return HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); +-#else +- return dtd.tableType ? 
HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0) : +- HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); +-#endif +-} +- +-size_t HUF_decompress4X_usingDTable(void* dst, size_t maxDstSize, +- const void* cSrc, size_t cSrcSize, +- const HUF_DTable* DTable) +-{ +- DTableDesc const dtd = HUF_getDTableDesc(DTable); +-#if defined(HUF_FORCE_DECOMPRESS_X1) +- (void)dtd; +- assert(dtd.tableType == 0); +- return HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); +-#elif defined(HUF_FORCE_DECOMPRESS_X2) +- (void)dtd; +- assert(dtd.tableType == 1); +- return HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); +-#else +- return dtd.tableType ? HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0) : +- HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); +-#endif +-} +- + + #if !defined(HUF_FORCE_DECOMPRESS_X1) && !defined(HUF_FORCE_DECOMPRESS_X2) + typedef struct { U32 tableTime; U32 decode256Time; } algo_time_t; +@@ -1610,36 +1839,9 @@ U32 HUF_selectDecoder (size_t dstSize, size_t cSrcSize) + #endif + } + +- +-size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable* dctx, void* dst, +- size_t dstSize, const void* cSrc, +- size_t cSrcSize, void* workSpace, +- size_t wkspSize) +-{ +- /* validation checks */ +- if (dstSize == 0) return ERROR(dstSize_tooSmall); +- if (cSrcSize == 0) return ERROR(corruption_detected); +- +- { U32 const algoNb = HUF_selectDecoder(dstSize, cSrcSize); +-#if defined(HUF_FORCE_DECOMPRESS_X1) +- (void)algoNb; +- assert(algoNb == 0); +- return HUF_decompress4X1_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize); +-#elif defined(HUF_FORCE_DECOMPRESS_X2) +- (void)algoNb; +- assert(algoNb == 1); +- return HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize); 
+-#else +- return algoNb ? HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, +- cSrcSize, workSpace, wkspSize): +- HUF_decompress4X1_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize); +-#endif +- } +-} +- + size_t HUF_decompress1X_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, + const void* cSrc, size_t cSrcSize, +- void* workSpace, size_t wkspSize) ++ void* workSpace, size_t wkspSize, int flags) + { + /* validation checks */ + if (dstSize == 0) return ERROR(dstSize_tooSmall); +@@ -1652,71 +1854,71 @@ size_t HUF_decompress1X_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, + (void)algoNb; + assert(algoNb == 0); + return HUF_decompress1X1_DCtx_wksp(dctx, dst, dstSize, cSrc, +- cSrcSize, workSpace, wkspSize); ++ cSrcSize, workSpace, wkspSize, flags); + #elif defined(HUF_FORCE_DECOMPRESS_X2) + (void)algoNb; + assert(algoNb == 1); + return HUF_decompress1X2_DCtx_wksp(dctx, dst, dstSize, cSrc, +- cSrcSize, workSpace, wkspSize); ++ cSrcSize, workSpace, wkspSize, flags); + #else + return algoNb ? 
HUF_decompress1X2_DCtx_wksp(dctx, dst, dstSize, cSrc, +- cSrcSize, workSpace, wkspSize): ++ cSrcSize, workSpace, wkspSize, flags): + HUF_decompress1X1_DCtx_wksp(dctx, dst, dstSize, cSrc, +- cSrcSize, workSpace, wkspSize); ++ cSrcSize, workSpace, wkspSize, flags); + #endif + } + } + + +-size_t HUF_decompress1X_usingDTable_bmi2(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int bmi2) ++size_t HUF_decompress1X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags) + { + DTableDesc const dtd = HUF_getDTableDesc(DTable); + #if defined(HUF_FORCE_DECOMPRESS_X1) + (void)dtd; + assert(dtd.tableType == 0); +- return HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); ++ return HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); + #elif defined(HUF_FORCE_DECOMPRESS_X2) + (void)dtd; + assert(dtd.tableType == 1); +- return HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); ++ return HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); + #else +- return dtd.tableType ? HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2) : +- HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); ++ return dtd.tableType ? 
HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags) : ++ HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); + #endif + } + + #ifndef HUF_FORCE_DECOMPRESS_X2 +-size_t HUF_decompress1X1_DCtx_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int bmi2) ++size_t HUF_decompress1X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags) + { + const BYTE* ip = (const BYTE*) cSrc; + +- size_t const hSize = HUF_readDTableX1_wksp_bmi2(dctx, cSrc, cSrcSize, workSpace, wkspSize, bmi2); ++ size_t const hSize = HUF_readDTableX1_wksp(dctx, cSrc, cSrcSize, workSpace, wkspSize, flags); + if (HUF_isError(hSize)) return hSize; + if (hSize >= cSrcSize) return ERROR(srcSize_wrong); + ip += hSize; cSrcSize -= hSize; + +- return HUF_decompress1X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, bmi2); ++ return HUF_decompress1X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, flags); + } + #endif + +-size_t HUF_decompress4X_usingDTable_bmi2(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int bmi2) ++size_t HUF_decompress4X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags) + { + DTableDesc const dtd = HUF_getDTableDesc(DTable); + #if defined(HUF_FORCE_DECOMPRESS_X1) + (void)dtd; + assert(dtd.tableType == 0); +- return HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); ++ return HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); + #elif defined(HUF_FORCE_DECOMPRESS_X2) + (void)dtd; + assert(dtd.tableType == 1); +- return HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); ++ return HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, 
cSrcSize, DTable, flags); + #else +- return dtd.tableType ? HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2) : +- HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); ++ return dtd.tableType ? HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags) : ++ HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); + #endif + } + +-size_t HUF_decompress4X_hufOnly_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int bmi2) ++size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags) + { + /* validation checks */ + if (dstSize == 0) return ERROR(dstSize_tooSmall); +@@ -1726,15 +1928,14 @@ size_t HUF_decompress4X_hufOnly_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t ds + #if defined(HUF_FORCE_DECOMPRESS_X1) + (void)algoNb; + assert(algoNb == 0); +- return HUF_decompress4X1_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, bmi2); ++ return HUF_decompress4X1_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, flags); + #elif defined(HUF_FORCE_DECOMPRESS_X2) + (void)algoNb; + assert(algoNb == 1); +- return HUF_decompress4X2_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, bmi2); ++ return HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, flags); + #else +- return algoNb ? HUF_decompress4X2_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, bmi2) : +- HUF_decompress4X1_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, bmi2); ++ return algoNb ? 
HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, flags) : ++ HUF_decompress4X1_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, flags); + #endif + } + } +- +diff --git a/lib/zstd/decompress/zstd_ddict.c b/lib/zstd/decompress/zstd_ddict.c +index dbbc7919de53..30ef65e1ab5c 100644 +--- a/lib/zstd/decompress/zstd_ddict.c ++++ b/lib/zstd/decompress/zstd_ddict.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -14,12 +15,12 @@ + /*-******************************************************* + * Dependencies + *********************************************************/ ++#include "../common/allocations.h" /* ZSTD_customMalloc, ZSTD_customFree */ + #include "../common/zstd_deps.h" /* ZSTD_memcpy, ZSTD_memmove, ZSTD_memset */ + #include "../common/cpu.h" /* bmi2 */ + #include "../common/mem.h" /* low level memory routines */ + #define FSE_STATIC_LINKING_ONLY + #include "../common/fse.h" +-#define HUF_STATIC_LINKING_ONLY + #include "../common/huf.h" + #include "zstd_decompress_internal.h" + #include "zstd_ddict.h" +@@ -131,7 +132,7 @@ static size_t ZSTD_initDDict_internal(ZSTD_DDict* ddict, + ZSTD_memcpy(internalBuffer, dict, dictSize); + } + ddict->dictSize = dictSize; +- ddict->entropy.hufTable[0] = (HUF_DTable)((HufLog)*0x1000001); /* cover both little and big endian */ ++ ddict->entropy.hufTable[0] = (HUF_DTable)((ZSTD_HUFFDTABLE_CAPACITY_LOG)*0x1000001); /* cover both little and big endian */ + + /* parse dictionary content */ + FORWARD_IF_ERROR( ZSTD_loadEntropy_intoDDict(ddict, dictContentType) , ""); +@@ -237,5 +238,5 @@ size_t ZSTD_sizeof_DDict(const ZSTD_DDict* ddict) + unsigned ZSTD_getDictID_fromDDict(const ZSTD_DDict* ddict) + { + if (ddict==NULL) return 0; +- return 
ZSTD_getDictID_fromDict(ddict->dictContent, ddict->dictSize); ++ return ddict->dictID; + } +diff --git a/lib/zstd/decompress/zstd_ddict.h b/lib/zstd/decompress/zstd_ddict.h +index 8c1a79d666f8..de459a0dacd1 100644 +--- a/lib/zstd/decompress/zstd_ddict.h ++++ b/lib/zstd/decompress/zstd_ddict.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +diff --git a/lib/zstd/decompress/zstd_decompress.c b/lib/zstd/decompress/zstd_decompress.c +index 6b3177c94711..c9cbc45f6ed9 100644 +--- a/lib/zstd/decompress/zstd_decompress.c ++++ b/lib/zstd/decompress/zstd_decompress.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -53,13 +54,15 @@ + * Dependencies + *********************************************************/ + #include "../common/zstd_deps.h" /* ZSTD_memcpy, ZSTD_memmove, ZSTD_memset */ ++#include "../common/allocations.h" /* ZSTD_customMalloc, ZSTD_customCalloc, ZSTD_customFree */ ++#include "../common/error_private.h" ++#include "../common/zstd_internal.h" /* blockProperties_t */ + #include "../common/mem.h" /* low level memory routines */ ++#include "../common/bits.h" /* ZSTD_highbit32 */ + #define FSE_STATIC_LINKING_ONLY + #include "../common/fse.h" +-#define HUF_STATIC_LINKING_ONLY + #include "../common/huf.h" + #include /* xxh64_reset, xxh64_update, xxh64_digest, XXH64 */ +-#include "../common/zstd_internal.h" /* blockProperties_t */ + #include "zstd_decompress_internal.h" /* ZSTD_DCtx */ + #include "zstd_ddict.h" /* ZSTD_DDictDictContent */ + #include "zstd_decompress_block.h" /* ZSTD_decompressBlock_internal */ +@@ -72,11 +75,11 @@ + *************************************/ + + #define DDICT_HASHSET_MAX_LOAD_FACTOR_COUNT_MULT 4 +-#define DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT 3 /* These two constants represent SIZE_MULT/COUNT_MULT load factor without using a float. +- * Currently, that means a 0.75 load factor. +- * So, if count * COUNT_MULT / size * SIZE_MULT != 0, then we've exceeded +- * the load factor of the ddict hash set. +- */ ++#define DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT 3 /* These two constants represent SIZE_MULT/COUNT_MULT load factor without using a float. ++ * Currently, that means a 0.75 load factor. ++ * So, if count * COUNT_MULT / size * SIZE_MULT != 0, then we've exceeded ++ * the load factor of the ddict hash set. 
++ */ + + #define DDICT_HASHSET_TABLE_BASE_SIZE 64 + #define DDICT_HASHSET_RESIZE_FACTOR 2 +@@ -237,6 +240,8 @@ static void ZSTD_DCtx_resetParameters(ZSTD_DCtx* dctx) + dctx->outBufferMode = ZSTD_bm_buffered; + dctx->forceIgnoreChecksum = ZSTD_d_validateChecksum; + dctx->refMultipleDDicts = ZSTD_rmd_refSingleDDict; ++ dctx->disableHufAsm = 0; ++ dctx->maxBlockSizeParam = 0; + } + + static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx) +@@ -253,6 +258,7 @@ static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx) + dctx->streamStage = zdss_init; + dctx->noForwardProgress = 0; + dctx->oversizedDuration = 0; ++ dctx->isFrameDecompression = 1; + #if DYNAMIC_BMI2 + dctx->bmi2 = ZSTD_cpuSupportsBmi2(); + #endif +@@ -421,16 +427,40 @@ size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize) + * note : only works for formats ZSTD_f_zstd1 and ZSTD_f_zstd1_magicless + * @return : 0, `zfhPtr` is correctly filled, + * >0, `srcSize` is too small, value is wanted `srcSize` amount, +- * or an error code, which can be tested using ZSTD_isError() */ ++** or an error code, which can be tested using ZSTD_isError() */ + size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format) + { + const BYTE* ip = (const BYTE*)src; + size_t const minInputSize = ZSTD_startingInputLength(format); + +- ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr)); /* not strictly necessary, but static analyzer do not understand that zfhPtr is only going to be read only if return value is zero, since they are 2 different signals */ +- if (srcSize < minInputSize) return minInputSize; +- RETURN_ERROR_IF(src==NULL, GENERIC, "invalid parameter"); ++ DEBUGLOG(5, "ZSTD_getFrameHeader_advanced: minInputSize = %zu, srcSize = %zu", minInputSize, srcSize); ++ ++ if (srcSize > 0) { ++ /* note : technically could be considered an assert(), since it's an invalid entry */ ++ RETURN_ERROR_IF(src==NULL, GENERIC, "invalid parameter : src==NULL, but srcSize>0"); ++ } ++ if (srcSize < 
minInputSize) { ++ if (srcSize > 0 && format != ZSTD_f_zstd1_magicless) { ++ /* when receiving less than @minInputSize bytes, ++ * control these bytes at least correspond to a supported magic number ++ * in order to error out early if they don't. ++ **/ ++ size_t const toCopy = MIN(4, srcSize); ++ unsigned char hbuf[4]; MEM_writeLE32(hbuf, ZSTD_MAGICNUMBER); ++ assert(src != NULL); ++ ZSTD_memcpy(hbuf, src, toCopy); ++ if ( MEM_readLE32(hbuf) != ZSTD_MAGICNUMBER ) { ++ /* not a zstd frame : let's check if it's a skippable frame */ ++ MEM_writeLE32(hbuf, ZSTD_MAGIC_SKIPPABLE_START); ++ ZSTD_memcpy(hbuf, src, toCopy); ++ if ((MEM_readLE32(hbuf) & ZSTD_MAGIC_SKIPPABLE_MASK) != ZSTD_MAGIC_SKIPPABLE_START) { ++ RETURN_ERROR(prefix_unknown, ++ "first bytes don't correspond to any supported magic number"); ++ } } } ++ return minInputSize; ++ } + ++ ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr)); /* not strictly necessary, but static analyzers may not understand that zfhPtr will be read only if return value is zero, since they are 2 different signals */ + if ( (format != ZSTD_f_zstd1_magicless) + && (MEM_readLE32(src) != ZSTD_MAGICNUMBER) ) { + if ((MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { +@@ -540,61 +570,62 @@ static size_t readSkippableFrameSize(void const* src, size_t srcSize) + sizeU32 = MEM_readLE32((BYTE const*)src + ZSTD_FRAMEIDSIZE); + RETURN_ERROR_IF((U32)(sizeU32 + ZSTD_SKIPPABLEHEADERSIZE) < sizeU32, + frameParameter_unsupported, ""); +- { +- size_t const skippableSize = skippableHeaderSize + sizeU32; ++ { size_t const skippableSize = skippableHeaderSize + sizeU32; + RETURN_ERROR_IF(skippableSize > srcSize, srcSize_wrong, ""); + return skippableSize; + } + } + + /*! ZSTD_readSkippableFrame() : +- * Retrieves a zstd skippable frame containing data given by src, and writes it to dst buffer. ++ * Retrieves content of a skippable frame, and writes it to dst buffer. 
+ * + * The parameter magicVariant will receive the magicVariant that was supplied when the frame was written, + * i.e. magicNumber - ZSTD_MAGIC_SKIPPABLE_START. This can be NULL if the caller is not interested + * in the magicVariant. + * +- * Returns an error if destination buffer is not large enough, or if the frame is not skippable. ++ * Returns an error if destination buffer is not large enough, or if this is not a valid skippable frame. + * + * @return : number of bytes written or a ZSTD error. + */ +-ZSTDLIB_API size_t ZSTD_readSkippableFrame(void* dst, size_t dstCapacity, unsigned* magicVariant, +- const void* src, size_t srcSize) ++size_t ZSTD_readSkippableFrame(void* dst, size_t dstCapacity, ++ unsigned* magicVariant, /* optional, can be NULL */ ++ const void* src, size_t srcSize) + { +- U32 const magicNumber = MEM_readLE32(src); +- size_t skippableFrameSize = readSkippableFrameSize(src, srcSize); +- size_t skippableContentSize = skippableFrameSize - ZSTD_SKIPPABLEHEADERSIZE; +- +- /* check input validity */ +- RETURN_ERROR_IF(!ZSTD_isSkippableFrame(src, srcSize), frameParameter_unsupported, ""); +- RETURN_ERROR_IF(skippableFrameSize < ZSTD_SKIPPABLEHEADERSIZE || skippableFrameSize > srcSize, srcSize_wrong, ""); +- RETURN_ERROR_IF(skippableContentSize > dstCapacity, dstSize_tooSmall, ""); ++ RETURN_ERROR_IF(srcSize < ZSTD_SKIPPABLEHEADERSIZE, srcSize_wrong, ""); + +- /* deliver payload */ +- if (skippableContentSize > 0 && dst != NULL) +- ZSTD_memcpy(dst, (const BYTE *)src + ZSTD_SKIPPABLEHEADERSIZE, skippableContentSize); +- if (magicVariant != NULL) +- *magicVariant = magicNumber - ZSTD_MAGIC_SKIPPABLE_START; +- return skippableContentSize; ++ { U32 const magicNumber = MEM_readLE32(src); ++ size_t skippableFrameSize = readSkippableFrameSize(src, srcSize); ++ size_t skippableContentSize = skippableFrameSize - ZSTD_SKIPPABLEHEADERSIZE; ++ ++ /* check input validity */ ++ RETURN_ERROR_IF(!ZSTD_isSkippableFrame(src, srcSize), frameParameter_unsupported, 
""); ++ RETURN_ERROR_IF(skippableFrameSize < ZSTD_SKIPPABLEHEADERSIZE || skippableFrameSize > srcSize, srcSize_wrong, ""); ++ RETURN_ERROR_IF(skippableContentSize > dstCapacity, dstSize_tooSmall, ""); ++ ++ /* deliver payload */ ++ if (skippableContentSize > 0 && dst != NULL) ++ ZSTD_memcpy(dst, (const BYTE *)src + ZSTD_SKIPPABLEHEADERSIZE, skippableContentSize); ++ if (magicVariant != NULL) ++ *magicVariant = magicNumber - ZSTD_MAGIC_SKIPPABLE_START; ++ return skippableContentSize; ++ } + } + + /* ZSTD_findDecompressedSize() : +- * compatible with legacy mode + * `srcSize` must be the exact length of some number of ZSTD compressed and/or + * skippable frames +- * @return : decompressed size of the frames contained */ ++ * note: compatible with legacy mode ++ * @return : decompressed size of the frames contained */ + unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize) + { +- unsigned long long totalDstSize = 0; ++ U64 totalDstSize = 0; + + while (srcSize >= ZSTD_startingInputLength(ZSTD_f_zstd1)) { + U32 const magicNumber = MEM_readLE32(src); + + if ((magicNumber & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { + size_t const skippableSize = readSkippableFrameSize(src, srcSize); +- if (ZSTD_isError(skippableSize)) { +- return ZSTD_CONTENTSIZE_ERROR; +- } ++ if (ZSTD_isError(skippableSize)) return ZSTD_CONTENTSIZE_ERROR; + assert(skippableSize <= srcSize); + + src = (const BYTE *)src + skippableSize; +@@ -602,17 +633,17 @@ unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize) + continue; + } + +- { unsigned long long const ret = ZSTD_getFrameContentSize(src, srcSize); +- if (ret >= ZSTD_CONTENTSIZE_ERROR) return ret; ++ { unsigned long long const fcs = ZSTD_getFrameContentSize(src, srcSize); ++ if (fcs >= ZSTD_CONTENTSIZE_ERROR) return fcs; + +- /* check for overflow */ +- if (totalDstSize + ret < totalDstSize) return ZSTD_CONTENTSIZE_ERROR; +- totalDstSize += ret; ++ if (U64_MAX - totalDstSize < fcs) ++ 
return ZSTD_CONTENTSIZE_ERROR; /* check for overflow */ ++ totalDstSize += fcs; + } ++ /* skip to next frame */ + { size_t const frameSrcSize = ZSTD_findFrameCompressedSize(src, srcSize); +- if (ZSTD_isError(frameSrcSize)) { +- return ZSTD_CONTENTSIZE_ERROR; +- } ++ if (ZSTD_isError(frameSrcSize)) return ZSTD_CONTENTSIZE_ERROR; ++ assert(frameSrcSize <= srcSize); + + src = (const BYTE *)src + frameSrcSize; + srcSize -= frameSrcSize; +@@ -676,13 +707,13 @@ static ZSTD_frameSizeInfo ZSTD_errorFrameSizeInfo(size_t ret) + return frameSizeInfo; + } + +-static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize) ++static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize, ZSTD_format_e format) + { + ZSTD_frameSizeInfo frameSizeInfo; + ZSTD_memset(&frameSizeInfo, 0, sizeof(ZSTD_frameSizeInfo)); + + +- if ((srcSize >= ZSTD_SKIPPABLEHEADERSIZE) ++ if (format == ZSTD_f_zstd1 && (srcSize >= ZSTD_SKIPPABLEHEADERSIZE) + && (MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { + frameSizeInfo.compressedSize = readSkippableFrameSize(src, srcSize); + assert(ZSTD_isError(frameSizeInfo.compressedSize) || +@@ -696,7 +727,7 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize + ZSTD_frameHeader zfh; + + /* Extract Frame Header */ +- { size_t const ret = ZSTD_getFrameHeader(&zfh, src, srcSize); ++ { size_t const ret = ZSTD_getFrameHeader_advanced(&zfh, src, srcSize, format); + if (ZSTD_isError(ret)) + return ZSTD_errorFrameSizeInfo(ret); + if (ret > 0) +@@ -730,23 +761,26 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize + ip += 4; + } + ++ frameSizeInfo.nbBlocks = nbBlocks; + frameSizeInfo.compressedSize = (size_t)(ip - ipstart); + frameSizeInfo.decompressedBound = (zfh.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN) + ? 
zfh.frameContentSize +- : nbBlocks * zfh.blockSizeMax; ++ : (unsigned long long)nbBlocks * zfh.blockSizeMax; + return frameSizeInfo; + } + } + ++static size_t ZSTD_findFrameCompressedSize_advanced(const void *src, size_t srcSize, ZSTD_format_e format) { ++ ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, format); ++ return frameSizeInfo.compressedSize; ++} ++ + /* ZSTD_findFrameCompressedSize() : +- * compatible with legacy mode +- * `src` must point to the start of a ZSTD frame, ZSTD legacy frame, or skippable frame +- * `srcSize` must be at least as large as the frame contained +- * @return : the compressed size of the frame starting at `src` */ ++ * See docs in zstd.h ++ * Note: compatible with legacy mode */ + size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize) + { +- ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); +- return frameSizeInfo.compressedSize; ++ return ZSTD_findFrameCompressedSize_advanced(src, srcSize, ZSTD_f_zstd1); + } + + /* ZSTD_decompressBound() : +@@ -760,7 +794,7 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize) + unsigned long long bound = 0; + /* Iterate over each frame */ + while (srcSize > 0) { +- ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); ++ ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1); + size_t const compressedSize = frameSizeInfo.compressedSize; + unsigned long long const decompressedBound = frameSizeInfo.decompressedBound; + if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR) +@@ -773,6 +807,48 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize) + return bound; + } + ++size_t ZSTD_decompressionMargin(void const* src, size_t srcSize) ++{ ++ size_t margin = 0; ++ unsigned maxBlockSize = 0; ++ ++ /* Iterate over each frame */ ++ while (srcSize > 0) { ++ ZSTD_frameSizeInfo const frameSizeInfo = 
ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1); ++ size_t const compressedSize = frameSizeInfo.compressedSize; ++ unsigned long long const decompressedBound = frameSizeInfo.decompressedBound; ++ ZSTD_frameHeader zfh; ++ ++ FORWARD_IF_ERROR(ZSTD_getFrameHeader(&zfh, src, srcSize), ""); ++ if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR) ++ return ERROR(corruption_detected); ++ ++ if (zfh.frameType == ZSTD_frame) { ++ /* Add the frame header to our margin */ ++ margin += zfh.headerSize; ++ /* Add the checksum to our margin */ ++ margin += zfh.checksumFlag ? 4 : 0; ++ /* Add 3 bytes per block */ ++ margin += 3 * frameSizeInfo.nbBlocks; ++ ++ /* Compute the max block size */ ++ maxBlockSize = MAX(maxBlockSize, zfh.blockSizeMax); ++ } else { ++ assert(zfh.frameType == ZSTD_skippableFrame); ++ /* Add the entire skippable frame size to our margin. */ ++ margin += compressedSize; ++ } ++ ++ assert(srcSize >= compressedSize); ++ src = (const BYTE*)src + compressedSize; ++ srcSize -= compressedSize; ++ } ++ ++ /* Add the max block size back to the margin. 
*/ ++ margin += maxBlockSize; ++ ++ return margin; ++} + + /*-************************************************************* + * Frame decoding +@@ -856,6 +932,10 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, + ip += frameHeaderSize; remainingSrcSize -= frameHeaderSize; + } + ++ /* Shrink the blockSizeMax if enabled */ ++ if (dctx->maxBlockSizeParam != 0) ++ dctx->fParams.blockSizeMax = MIN(dctx->fParams.blockSizeMax, (unsigned)dctx->maxBlockSizeParam); ++ + /* Loop on each block */ + while (1) { + BYTE* oBlockEnd = oend; +@@ -888,7 +968,8 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, + switch(blockProperties.blockType) + { + case bt_compressed: +- decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oBlockEnd-op), ip, cBlockSize, /* frame */ 1, not_streaming); ++ assert(dctx->isFrameDecompression == 1); ++ decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oBlockEnd-op), ip, cBlockSize, not_streaming); + break; + case bt_raw : + /* Use oend instead of oBlockEnd because this function is safe to overlap. It uses memmove. 
*/ +@@ -901,12 +982,14 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, + default: + RETURN_ERROR(corruption_detected, "invalid block type"); + } +- +- if (ZSTD_isError(decodedSize)) return decodedSize; +- if (dctx->validateChecksum) ++ FORWARD_IF_ERROR(decodedSize, "Block decompression failure"); ++ DEBUGLOG(5, "Decompressed block of dSize = %u", (unsigned)decodedSize); ++ if (dctx->validateChecksum) { + xxh64_update(&dctx->xxhState, op, decodedSize); +- if (decodedSize != 0) ++ } ++ if (decodedSize) /* support dst = NULL,0 */ { + op += decodedSize; ++ } + assert(ip != NULL); + ip += cBlockSize; + remainingSrcSize -= cBlockSize; +@@ -930,12 +1013,15 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, + } + ZSTD_DCtx_trace_end(dctx, (U64)(op-ostart), (U64)(ip-istart), /* streaming */ 0); + /* Allow caller to get size read */ ++ DEBUGLOG(4, "ZSTD_decompressFrame: decompressed frame of size %zi, consuming %zi bytes of input", op-ostart, ip - (const BYTE*)*srcPtr); + *srcPtr = ip; + *srcSizePtr = remainingSrcSize; + return (size_t)(op-ostart); + } + +-static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx, ++static ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR ++size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize, + const void* dict, size_t dictSize, +@@ -955,17 +1041,18 @@ static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx, + while (srcSize >= ZSTD_startingInputLength(dctx->format)) { + + +- { U32 const magicNumber = MEM_readLE32(src); +- DEBUGLOG(4, "reading magic number %08X (expecting %08X)", +- (unsigned)magicNumber, ZSTD_MAGICNUMBER); ++ if (dctx->format == ZSTD_f_zstd1 && srcSize >= 4) { ++ U32 const magicNumber = MEM_readLE32(src); ++ DEBUGLOG(5, "reading magic number %08X", (unsigned)magicNumber); + if ((magicNumber & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { ++ /* skippable frame detected : skip it */ + size_t const skippableSize = readSkippableFrameSize(src, srcSize); +- 
FORWARD_IF_ERROR(skippableSize, "readSkippableFrameSize failed"); ++ FORWARD_IF_ERROR(skippableSize, "invalid skippable frame"); + assert(skippableSize <= srcSize); + + src = (const BYTE *)src + skippableSize; + srcSize -= skippableSize; +- continue; ++ continue; /* check next frame */ + } } + + if (ddict) { +@@ -1061,8 +1148,8 @@ size_t ZSTD_decompress(void* dst, size_t dstCapacity, const void* src, size_t sr + size_t ZSTD_nextSrcSizeToDecompress(ZSTD_DCtx* dctx) { return dctx->expected; } + + /* +- * Similar to ZSTD_nextSrcSizeToDecompress(), but when a block input can be streamed, +- * we allow taking a partial block as the input. Currently only raw uncompressed blocks can ++ * Similar to ZSTD_nextSrcSizeToDecompress(), but when a block input can be streamed, we ++ * allow taking a partial block as the input. Currently only raw uncompressed blocks can + * be streamed. + * + * For blocks that can be streamed, this allows us to reduce the latency until we produce +@@ -1181,7 +1268,8 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c + { + case bt_compressed: + DEBUGLOG(5, "ZSTD_decompressContinue: case bt_compressed"); +- rSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, /* frame */ 1, is_streaming); ++ assert(dctx->isFrameDecompression == 1); ++ rSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, is_streaming); + dctx->expected = 0; /* Streaming not supported */ + break; + case bt_raw : +@@ -1250,6 +1338,7 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c + case ZSTDds_decodeSkippableHeader: + assert(src != NULL); + assert(srcSize <= ZSTD_SKIPPABLEHEADERSIZE); ++ assert(dctx->format != ZSTD_f_zstd1_magicless); + ZSTD_memcpy(dctx->headerBuffer + (ZSTD_SKIPPABLEHEADERSIZE - srcSize), src, srcSize); /* complete skippable header */ + dctx->expected = MEM_readLE32(dctx->headerBuffer + ZSTD_FRAMEIDSIZE); /* note : dctx->expected can grow seriously large, 
beyond local buffer size */ + dctx->stage = ZSTDds_skipFrame; +@@ -1262,7 +1351,7 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c + + default: + assert(0); /* impossible */ +- RETURN_ERROR(GENERIC, "impossible to reach"); /* some compiler require default to do something */ ++ RETURN_ERROR(GENERIC, "impossible to reach"); /* some compilers require default to do something */ + } + } + +@@ -1303,11 +1392,11 @@ ZSTD_loadDEntropy(ZSTD_entropyDTables_t* entropy, + /* in minimal huffman, we always use X1 variants */ + size_t const hSize = HUF_readDTableX1_wksp(entropy->hufTable, + dictPtr, dictEnd - dictPtr, +- workspace, workspaceSize); ++ workspace, workspaceSize, /* flags */ 0); + #else + size_t const hSize = HUF_readDTableX2_wksp(entropy->hufTable, + dictPtr, (size_t)(dictEnd - dictPtr), +- workspace, workspaceSize); ++ workspace, workspaceSize, /* flags */ 0); + #endif + RETURN_ERROR_IF(HUF_isError(hSize), dictionary_corrupted, ""); + dictPtr += hSize; +@@ -1403,10 +1492,11 @@ size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx) + dctx->prefixStart = NULL; + dctx->virtualStart = NULL; + dctx->dictEnd = NULL; +- dctx->entropy.hufTable[0] = (HUF_DTable)((HufLog)*0x1000001); /* cover both little and big endian */ ++ dctx->entropy.hufTable[0] = (HUF_DTable)((ZSTD_HUFFDTABLE_CAPACITY_LOG)*0x1000001); /* cover both little and big endian */ + dctx->litEntropy = dctx->fseEntropy = 0; + dctx->dictID = 0; + dctx->bType = bt_reserved; ++ dctx->isFrameDecompression = 1; + ZSTD_STATIC_ASSERT(sizeof(dctx->entropy.rep) == sizeof(repStartValue)); + ZSTD_memcpy(dctx->entropy.rep, repStartValue, sizeof(repStartValue)); /* initial repcodes */ + dctx->LLTptr = dctx->entropy.LLTable; +@@ -1465,7 +1555,7 @@ unsigned ZSTD_getDictID_fromDict(const void* dict, size_t dictSize) + * This could for one of the following reasons : + * - The frame does not require a dictionary (most common case). + * - The frame was built with dictID intentionally removed. 
+- * Needed dictionary is a hidden information. ++ * Needed dictionary is a hidden piece of information. + * Note : this use case also happens when using a non-conformant dictionary. + * - `srcSize` is too small, and as a result, frame header could not be decoded. + * Note : possible if `srcSize < ZSTD_FRAMEHEADERSIZE_MAX`. +@@ -1474,7 +1564,7 @@ unsigned ZSTD_getDictID_fromDict(const void* dict, size_t dictSize) + * ZSTD_getFrameHeader(), which will provide a more precise error code. */ + unsigned ZSTD_getDictID_fromFrame(const void* src, size_t srcSize) + { +- ZSTD_frameHeader zfp = { 0, 0, 0, ZSTD_frame, 0, 0, 0 }; ++ ZSTD_frameHeader zfp = { 0, 0, 0, ZSTD_frame, 0, 0, 0, 0, 0 }; + size_t const hError = ZSTD_getFrameHeader(&zfp, src, srcSize); + if (ZSTD_isError(hError)) return 0; + return zfp.dictID; +@@ -1581,7 +1671,9 @@ size_t ZSTD_initDStream_usingDict(ZSTD_DStream* zds, const void* dict, size_t di + size_t ZSTD_initDStream(ZSTD_DStream* zds) + { + DEBUGLOG(4, "ZSTD_initDStream"); +- return ZSTD_initDStream_usingDDict(zds, NULL); ++ FORWARD_IF_ERROR(ZSTD_DCtx_reset(zds, ZSTD_reset_session_only), ""); ++ FORWARD_IF_ERROR(ZSTD_DCtx_refDDict(zds, NULL), ""); ++ return ZSTD_startingInputLength(zds->format); + } + + /* ZSTD_initDStream_usingDDict() : +@@ -1589,6 +1681,7 @@ size_t ZSTD_initDStream(ZSTD_DStream* zds) + * this function cannot fail */ + size_t ZSTD_initDStream_usingDDict(ZSTD_DStream* dctx, const ZSTD_DDict* ddict) + { ++ DEBUGLOG(4, "ZSTD_initDStream_usingDDict"); + FORWARD_IF_ERROR( ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only) , ""); + FORWARD_IF_ERROR( ZSTD_DCtx_refDDict(dctx, ddict) , ""); + return ZSTD_startingInputLength(dctx->format); +@@ -1599,6 +1692,7 @@ size_t ZSTD_initDStream_usingDDict(ZSTD_DStream* dctx, const ZSTD_DDict* ddict) + * this function cannot fail */ + size_t ZSTD_resetDStream(ZSTD_DStream* dctx) + { ++ DEBUGLOG(4, "ZSTD_resetDStream"); + FORWARD_IF_ERROR(ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only), ""); + return 
ZSTD_startingInputLength(dctx->format); + } +@@ -1670,6 +1764,15 @@ ZSTD_bounds ZSTD_dParam_getBounds(ZSTD_dParameter dParam) + bounds.lowerBound = (int)ZSTD_rmd_refSingleDDict; + bounds.upperBound = (int)ZSTD_rmd_refMultipleDDicts; + return bounds; ++ case ZSTD_d_disableHuffmanAssembly: ++ bounds.lowerBound = 0; ++ bounds.upperBound = 1; ++ return bounds; ++ case ZSTD_d_maxBlockSize: ++ bounds.lowerBound = ZSTD_BLOCKSIZE_MAX_MIN; ++ bounds.upperBound = ZSTD_BLOCKSIZE_MAX; ++ return bounds; ++ + default:; + } + bounds.error = ERROR(parameter_unsupported); +@@ -1710,6 +1813,12 @@ size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParameter param, int* value + case ZSTD_d_refMultipleDDicts: + *value = (int)dctx->refMultipleDDicts; + return 0; ++ case ZSTD_d_disableHuffmanAssembly: ++ *value = (int)dctx->disableHufAsm; ++ return 0; ++ case ZSTD_d_maxBlockSize: ++ *value = dctx->maxBlockSizeParam; ++ return 0; + default:; + } + RETURN_ERROR(parameter_unsupported, ""); +@@ -1743,6 +1852,14 @@ size_t ZSTD_DCtx_setParameter(ZSTD_DCtx* dctx, ZSTD_dParameter dParam, int value + } + dctx->refMultipleDDicts = (ZSTD_refMultipleDDicts_e)value; + return 0; ++ case ZSTD_d_disableHuffmanAssembly: ++ CHECK_DBOUNDS(ZSTD_d_disableHuffmanAssembly, value); ++ dctx->disableHufAsm = value != 0; ++ return 0; ++ case ZSTD_d_maxBlockSize: ++ if (value != 0) CHECK_DBOUNDS(ZSTD_d_maxBlockSize, value); ++ dctx->maxBlockSizeParam = value; ++ return 0; + default:; + } + RETURN_ERROR(parameter_unsupported, ""); +@@ -1754,6 +1871,7 @@ size_t ZSTD_DCtx_reset(ZSTD_DCtx* dctx, ZSTD_ResetDirective reset) + || (reset == ZSTD_reset_session_and_parameters) ) { + dctx->streamStage = zdss_init; + dctx->noForwardProgress = 0; ++ dctx->isFrameDecompression = 1; + } + if ( (reset == ZSTD_reset_parameters) + || (reset == ZSTD_reset_session_and_parameters) ) { +@@ -1770,11 +1888,17 @@ size_t ZSTD_sizeof_DStream(const ZSTD_DStream* dctx) + return ZSTD_sizeof_DCtx(dctx); + } + +-size_t 
ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long long frameContentSize) ++static size_t ZSTD_decodingBufferSize_internal(unsigned long long windowSize, unsigned long long frameContentSize, size_t blockSizeMax) + { +- size_t const blockSize = (size_t) MIN(windowSize, ZSTD_BLOCKSIZE_MAX); +- /* space is needed to store the litbuffer after the output of a given block without stomping the extDict of a previous run, as well as to cover both windows against wildcopy*/ +- unsigned long long const neededRBSize = windowSize + blockSize + ZSTD_BLOCKSIZE_MAX + (WILDCOPY_OVERLENGTH * 2); ++ size_t const blockSize = MIN((size_t)MIN(windowSize, ZSTD_BLOCKSIZE_MAX), blockSizeMax); ++ /* We need blockSize + WILDCOPY_OVERLENGTH worth of buffer so that if a block ++ * ends at windowSize + WILDCOPY_OVERLENGTH + 1 bytes, we can start writing ++ * the block at the beginning of the output buffer, and maintain a full window. ++ * ++ * We need another blockSize worth of buffer so that we can store split ++ * literals at the end of the block without overwriting the extDict window. 
++ */ ++ unsigned long long const neededRBSize = windowSize + (blockSize * 2) + (WILDCOPY_OVERLENGTH * 2); + unsigned long long const neededSize = MIN(frameContentSize, neededRBSize); + size_t const minRBSize = (size_t) neededSize; + RETURN_ERROR_IF((unsigned long long)minRBSize != neededSize, +@@ -1782,6 +1906,11 @@ size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long + return minRBSize; + } + ++size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long long frameContentSize) ++{ ++ return ZSTD_decodingBufferSize_internal(windowSize, frameContentSize, ZSTD_BLOCKSIZE_MAX); ++} ++ + size_t ZSTD_estimateDStreamSize(size_t windowSize) + { + size_t const blockSize = MIN(windowSize, ZSTD_BLOCKSIZE_MAX); +@@ -1918,7 +2047,6 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB + if (zds->refMultipleDDicts && zds->ddictSet) { + ZSTD_DCtx_selectFrameDDict(zds); + } +- DEBUGLOG(5, "header size : %u", (U32)hSize); + if (ZSTD_isError(hSize)) { + return hSize; /* error */ + } +@@ -1932,6 +2060,11 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB + zds->lhSize += remainingInput; + } + input->pos = input->size; ++ /* check first few bytes */ ++ FORWARD_IF_ERROR( ++ ZSTD_getFrameHeader_advanced(&zds->fParams, zds->headerBuffer, zds->lhSize, zds->format), ++ "First few bytes detected incorrect" ); ++ /* return hint input size */ + return (MAX((size_t)ZSTD_FRAMEHEADERSIZE_MIN(zds->format), hSize) - zds->lhSize) + ZSTD_blockHeaderSize; /* remaining header bytes + next block header */ + } + assert(ip != NULL); +@@ -1943,14 +2076,15 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB + if (zds->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN + && zds->fParams.frameType != ZSTD_skippableFrame + && (U64)(size_t)(oend-op) >= zds->fParams.frameContentSize) { +- size_t const cSize = ZSTD_findFrameCompressedSize(istart, (size_t)(iend-istart)); 
++ size_t const cSize = ZSTD_findFrameCompressedSize_advanced(istart, (size_t)(iend-istart), zds->format); + if (cSize <= (size_t)(iend-istart)) { + /* shortcut : using single-pass mode */ + size_t const decompressedSize = ZSTD_decompress_usingDDict(zds, op, (size_t)(oend-op), istart, cSize, ZSTD_getDDict(zds)); + if (ZSTD_isError(decompressedSize)) return decompressedSize; +- DEBUGLOG(4, "shortcut to single-pass ZSTD_decompress_usingDDict()") ++ DEBUGLOG(4, "shortcut to single-pass ZSTD_decompress_usingDDict()"); ++ assert(istart != NULL); + ip = istart + cSize; +- op += decompressedSize; ++ op = op ? op + decompressedSize : op; /* can occur if frameContentSize = 0 (empty frame) */ + zds->expected = 0; + zds->streamStage = zdss_init; + someMoreWork = 0; +@@ -1969,7 +2103,8 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB + DEBUGLOG(4, "Consume header"); + FORWARD_IF_ERROR(ZSTD_decompressBegin_usingDDict(zds, ZSTD_getDDict(zds)), ""); + +- if ((MEM_readLE32(zds->headerBuffer) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { /* skippable frame */ ++ if (zds->format == ZSTD_f_zstd1 ++ && (MEM_readLE32(zds->headerBuffer) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { /* skippable frame */ + zds->expected = MEM_readLE32(zds->headerBuffer + ZSTD_FRAMEIDSIZE); + zds->stage = ZSTDds_skipFrame; + } else { +@@ -1985,11 +2120,13 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB + zds->fParams.windowSize = MAX(zds->fParams.windowSize, 1U << ZSTD_WINDOWLOG_ABSOLUTEMIN); + RETURN_ERROR_IF(zds->fParams.windowSize > zds->maxWindowSize, + frameParameter_windowTooLarge, ""); ++ if (zds->maxBlockSizeParam != 0) ++ zds->fParams.blockSizeMax = MIN(zds->fParams.blockSizeMax, (unsigned)zds->maxBlockSizeParam); + + /* Adapt buffer sizes to frame header instructions */ + { size_t const neededInBuffSize = MAX(zds->fParams.blockSizeMax, 4 /* frame checksum */); + size_t const neededOutBuffSize = 
zds->outBufferMode == ZSTD_bm_buffered +- ? ZSTD_decodingBufferSize_min(zds->fParams.windowSize, zds->fParams.frameContentSize) ++ ? ZSTD_decodingBufferSize_internal(zds->fParams.windowSize, zds->fParams.frameContentSize, zds->fParams.blockSizeMax) + : 0; + + ZSTD_DCtx_updateOversizedDuration(zds, neededInBuffSize, neededOutBuffSize); +@@ -2034,6 +2171,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB + } + if ((size_t)(iend-ip) >= neededInSize) { /* decode directly from src */ + FORWARD_IF_ERROR(ZSTD_decompressContinueStream(zds, &op, oend, ip, neededInSize), ""); ++ assert(ip != NULL); + ip += neededInSize; + /* Function modifies the stage so we must break */ + break; +@@ -2048,7 +2186,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB + int const isSkipFrame = ZSTD_isSkipFrame(zds); + size_t loadedSize; + /* At this point we shouldn't be decompressing a block that we can stream. */ +- assert(neededInSize == ZSTD_nextSrcSizeToDecompressWithInputSize(zds, iend - ip)); ++ assert(neededInSize == ZSTD_nextSrcSizeToDecompressWithInputSize(zds, (size_t)(iend - ip))); + if (isSkipFrame) { + loadedSize = MIN(toLoad, (size_t)(iend-ip)); + } else { +@@ -2057,8 +2195,11 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB + "should never happen"); + loadedSize = ZSTD_limitCopy(zds->inBuff + zds->inPos, toLoad, ip, (size_t)(iend-ip)); + } +- ip += loadedSize; +- zds->inPos += loadedSize; ++ if (loadedSize != 0) { ++ /* ip may be NULL */ ++ ip += loadedSize; ++ zds->inPos += loadedSize; ++ } + if (loadedSize < toLoad) { someMoreWork = 0; break; } /* not enough input, wait for more */ + + /* decode loaded input */ +@@ -2068,14 +2209,17 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB + break; + } + case zdss_flush: +- { size_t const toFlushSize = zds->outEnd - zds->outStart; ++ { ++ size_t const toFlushSize = zds->outEnd - zds->outStart; + 
size_t const flushedSize = ZSTD_limitCopy(op, (size_t)(oend-op), zds->outBuff + zds->outStart, toFlushSize); +- op += flushedSize; ++ ++ op = op ? op + flushedSize : op; ++ + zds->outStart += flushedSize; + if (flushedSize == toFlushSize) { /* flush completed */ + zds->streamStage = zdss_read; + if ( (zds->outBuffSize < zds->fParams.frameContentSize) +- && (zds->outStart + zds->fParams.blockSizeMax > zds->outBuffSize) ) { ++ && (zds->outStart + zds->fParams.blockSizeMax > zds->outBuffSize) ) { + DEBUGLOG(5, "restart filling outBuff from beginning (left:%i, needed:%u)", + (int)(zds->outBuffSize - zds->outStart), + (U32)zds->fParams.blockSizeMax); +@@ -2089,7 +2233,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB + + default: + assert(0); /* impossible */ +- RETURN_ERROR(GENERIC, "impossible to reach"); /* some compiler require default to do something */ ++ RETURN_ERROR(GENERIC, "impossible to reach"); /* some compilers require default to do something */ + } } + + /* result */ +@@ -2102,8 +2246,8 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB + if ((ip==istart) && (op==ostart)) { /* no forward progress */ + zds->noForwardProgress ++; + if (zds->noForwardProgress >= ZSTD_NO_FORWARD_PROGRESS_MAX) { +- RETURN_ERROR_IF(op==oend, dstSize_tooSmall, ""); +- RETURN_ERROR_IF(ip==iend, srcSize_wrong, ""); ++ RETURN_ERROR_IF(op==oend, noForwardProgress_destFull, ""); ++ RETURN_ERROR_IF(ip==iend, noForwardProgress_inputEmpty, ""); + assert(0); + } + } else { +@@ -2140,11 +2284,17 @@ size_t ZSTD_decompressStream_simpleArgs ( + void* dst, size_t dstCapacity, size_t* dstPos, + const void* src, size_t srcSize, size_t* srcPos) + { +- ZSTD_outBuffer output = { dst, dstCapacity, *dstPos }; +- ZSTD_inBuffer input = { src, srcSize, *srcPos }; +- /* ZSTD_compress_generic() will check validity of dstPos and srcPos */ +- size_t const cErr = ZSTD_decompressStream(dctx, &output, &input); +- *dstPos = output.pos; +- 
*srcPos = input.pos; +- return cErr; ++ ZSTD_outBuffer output; ++ ZSTD_inBuffer input; ++ output.dst = dst; ++ output.size = dstCapacity; ++ output.pos = *dstPos; ++ input.src = src; ++ input.size = srcSize; ++ input.pos = *srcPos; ++ { size_t const cErr = ZSTD_decompressStream(dctx, &output, &input); ++ *dstPos = output.pos; ++ *srcPos = input.pos; ++ return cErr; ++ } + } +diff --git a/lib/zstd/decompress/zstd_decompress_block.c b/lib/zstd/decompress/zstd_decompress_block.c +index c1913b8e7c89..9fe9a12c8a2c 100644 +--- a/lib/zstd/decompress/zstd_decompress_block.c ++++ b/lib/zstd/decompress/zstd_decompress_block.c +@@ -1,5 +1,6 @@ ++// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -20,12 +21,12 @@ + #include "../common/mem.h" /* low level memory routines */ + #define FSE_STATIC_LINKING_ONLY + #include "../common/fse.h" +-#define HUF_STATIC_LINKING_ONLY + #include "../common/huf.h" + #include "../common/zstd_internal.h" + #include "zstd_decompress_internal.h" /* ZSTD_DCtx */ + #include "zstd_ddict.h" /* ZSTD_DDictDictContent */ + #include "zstd_decompress_block.h" ++#include "../common/bits.h" /* ZSTD_highbit32 */ + + /*_******************************************************* + * Macros +@@ -51,6 +52,13 @@ static void ZSTD_copy4(void* dst, const void* src) { ZSTD_memcpy(dst, src, 4); } + * Block decoding + ***************************************************************/ + ++static size_t ZSTD_blockSizeMax(ZSTD_DCtx const* dctx) ++{ ++ size_t const blockSizeMax = dctx->isFrameDecompression ? dctx->fParams.blockSizeMax : ZSTD_BLOCKSIZE_MAX; ++ assert(blockSizeMax <= ZSTD_BLOCKSIZE_MAX); ++ return blockSizeMax; ++} ++ + /*! 
ZSTD_getcBlockSize() : + * Provides the size of compressed block from block header `src` */ + size_t ZSTD_getcBlockSize(const void* src, size_t srcSize, +@@ -73,41 +81,49 @@ size_t ZSTD_getcBlockSize(const void* src, size_t srcSize, + static void ZSTD_allocateLiteralsBuffer(ZSTD_DCtx* dctx, void* const dst, const size_t dstCapacity, const size_t litSize, + const streaming_operation streaming, const size_t expectedWriteSize, const unsigned splitImmediately) + { +- if (streaming == not_streaming && dstCapacity > ZSTD_BLOCKSIZE_MAX + WILDCOPY_OVERLENGTH + litSize + WILDCOPY_OVERLENGTH) +- { +- /* room for litbuffer to fit without read faulting */ +- dctx->litBuffer = (BYTE*)dst + ZSTD_BLOCKSIZE_MAX + WILDCOPY_OVERLENGTH; ++ size_t const blockSizeMax = ZSTD_blockSizeMax(dctx); ++ assert(litSize <= blockSizeMax); ++ assert(dctx->isFrameDecompression || streaming == not_streaming); ++ assert(expectedWriteSize <= blockSizeMax); ++ if (streaming == not_streaming && dstCapacity > blockSizeMax + WILDCOPY_OVERLENGTH + litSize + WILDCOPY_OVERLENGTH) { ++ /* If we aren't streaming, we can just put the literals after the output ++ * of the current block. We don't need to worry about overwriting the ++ * extDict of our window, because it doesn't exist. ++ * So if we have space after the end of the block, just put it there. ++ */ ++ dctx->litBuffer = (BYTE*)dst + blockSizeMax + WILDCOPY_OVERLENGTH; + dctx->litBufferEnd = dctx->litBuffer + litSize; + dctx->litBufferLocation = ZSTD_in_dst; +- } +- else if (litSize > ZSTD_LITBUFFEREXTRASIZE) +- { +- /* won't fit in litExtraBuffer, so it will be split between end of dst and extra buffer */ ++ } else if (litSize <= ZSTD_LITBUFFEREXTRASIZE) { ++ /* Literals fit entirely within the extra buffer, put them there to avoid ++ * having to split the literals. 
++ */ ++ dctx->litBuffer = dctx->litExtraBuffer; ++ dctx->litBufferEnd = dctx->litBuffer + litSize; ++ dctx->litBufferLocation = ZSTD_not_in_dst; ++ } else { ++ assert(blockSizeMax > ZSTD_LITBUFFEREXTRASIZE); ++ /* Literals must be split between the output block and the extra lit ++ * buffer. We fill the extra lit buffer with the tail of the literals, ++ * and put the rest of the literals at the end of the block, with ++ * WILDCOPY_OVERLENGTH of buffer room to allow for overreads. ++ * This MUST not write more than our maxBlockSize beyond dst, because in ++ * streaming mode, that could overwrite part of our extDict window. ++ */ + if (splitImmediately) { + /* won't fit in litExtraBuffer, so it will be split between end of dst and extra buffer */ + dctx->litBuffer = (BYTE*)dst + expectedWriteSize - litSize + ZSTD_LITBUFFEREXTRASIZE - WILDCOPY_OVERLENGTH; + dctx->litBufferEnd = dctx->litBuffer + litSize - ZSTD_LITBUFFEREXTRASIZE; +- } +- else { +- /* initially this will be stored entirely in dst during huffman decoding, it will partially shifted to litExtraBuffer after */ ++ } else { ++ /* initially this will be stored entirely in dst during huffman decoding, it will partially be shifted to litExtraBuffer after */ + dctx->litBuffer = (BYTE*)dst + expectedWriteSize - litSize; + dctx->litBufferEnd = (BYTE*)dst + expectedWriteSize; + } + dctx->litBufferLocation = ZSTD_split; +- } +- else +- { +- /* fits entirely within litExtraBuffer, so no split is necessary */ +- dctx->litBuffer = dctx->litExtraBuffer; +- dctx->litBufferEnd = dctx->litBuffer + litSize; +- dctx->litBufferLocation = ZSTD_not_in_dst; ++ assert(dctx->litBufferEnd <= (BYTE*)dst + expectedWriteSize); + } + } + +-/* Hidden declaration for fullbench */ +-size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, +- const void* src, size_t srcSize, +- void* dst, size_t dstCapacity, const streaming_operation streaming); + /*! 
ZSTD_decodeLiteralsBlock() : + * Where it is possible to do so without being stomped by the output during decompression, the literals block will be stored + * in the dstBuffer. If there is room to do so, it will be stored in full in the excess dst space after where the current +@@ -116,7 +132,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, + * + * @return : nb of bytes read from src (< srcSize ) + * note : symbol not declared but exposed for fullbench */ +-size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, ++static size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, + const void* src, size_t srcSize, /* note : srcSize < BLOCKSIZE */ + void* dst, size_t dstCapacity, const streaming_operation streaming) + { +@@ -125,6 +141,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, + + { const BYTE* const istart = (const BYTE*) src; + symbolEncodingType_e const litEncType = (symbolEncodingType_e)(istart[0] & 3); ++ size_t const blockSizeMax = ZSTD_blockSizeMax(dctx); + + switch(litEncType) + { +@@ -134,13 +151,16 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, + ZSTD_FALLTHROUGH; + + case set_compressed: +- RETURN_ERROR_IF(srcSize < 5, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 3; here we need up to 5 for case 3"); ++ RETURN_ERROR_IF(srcSize < 5, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 2; here we need up to 5 for case 3"); + { size_t lhSize, litSize, litCSize; + U32 singleStream=0; + U32 const lhlCode = (istart[0] >> 2) & 3; + U32 const lhc = MEM_readLE32(istart); + size_t hufSuccess; +- size_t expectedWriteSize = MIN(ZSTD_BLOCKSIZE_MAX, dstCapacity); ++ size_t expectedWriteSize = MIN(blockSizeMax, dstCapacity); ++ int const flags = 0 ++ | (ZSTD_DCtx_get_bmi2(dctx) ? HUF_flags_bmi2 : 0) ++ | (dctx->disableHufAsm ? 
HUF_flags_disableAsm : 0); + switch(lhlCode) + { + case 0: case 1: default: /* note : default is impossible, since lhlCode into [0..3] */ +@@ -164,7 +184,11 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, + break; + } + RETURN_ERROR_IF(litSize > 0 && dst == NULL, dstSize_tooSmall, "NULL not handled"); +- RETURN_ERROR_IF(litSize > ZSTD_BLOCKSIZE_MAX, corruption_detected, ""); ++ RETURN_ERROR_IF(litSize > blockSizeMax, corruption_detected, ""); ++ if (!singleStream) ++ RETURN_ERROR_IF(litSize < MIN_LITERALS_FOR_4_STREAMS, literals_headerWrong, ++ "Not enough literals (%zu) for the 4-streams mode (min %u)", ++ litSize, MIN_LITERALS_FOR_4_STREAMS); + RETURN_ERROR_IF(litCSize + lhSize > srcSize, corruption_detected, ""); + RETURN_ERROR_IF(expectedWriteSize < litSize , dstSize_tooSmall, ""); + ZSTD_allocateLiteralsBuffer(dctx, dst, dstCapacity, litSize, streaming, expectedWriteSize, 0); +@@ -176,13 +200,14 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, + + if (litEncType==set_repeat) { + if (singleStream) { +- hufSuccess = HUF_decompress1X_usingDTable_bmi2( ++ hufSuccess = HUF_decompress1X_usingDTable( + dctx->litBuffer, litSize, istart+lhSize, litCSize, +- dctx->HUFptr, ZSTD_DCtx_get_bmi2(dctx)); ++ dctx->HUFptr, flags); + } else { +- hufSuccess = HUF_decompress4X_usingDTable_bmi2( ++ assert(litSize >= MIN_LITERALS_FOR_4_STREAMS); ++ hufSuccess = HUF_decompress4X_usingDTable( + dctx->litBuffer, litSize, istart+lhSize, litCSize, +- dctx->HUFptr, ZSTD_DCtx_get_bmi2(dctx)); ++ dctx->HUFptr, flags); + } + } else { + if (singleStream) { +@@ -190,26 +215,28 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, + hufSuccess = HUF_decompress1X_DCtx_wksp( + dctx->entropy.hufTable, dctx->litBuffer, litSize, + istart+lhSize, litCSize, dctx->workspace, +- sizeof(dctx->workspace)); ++ sizeof(dctx->workspace), flags); + #else +- hufSuccess = HUF_decompress1X1_DCtx_wksp_bmi2( ++ hufSuccess = HUF_decompress1X1_DCtx_wksp( + dctx->entropy.hufTable, dctx->litBuffer, litSize, + 
istart+lhSize, litCSize, dctx->workspace, +- sizeof(dctx->workspace), ZSTD_DCtx_get_bmi2(dctx)); ++ sizeof(dctx->workspace), flags); + #endif + } else { +- hufSuccess = HUF_decompress4X_hufOnly_wksp_bmi2( ++ hufSuccess = HUF_decompress4X_hufOnly_wksp( + dctx->entropy.hufTable, dctx->litBuffer, litSize, + istart+lhSize, litCSize, dctx->workspace, +- sizeof(dctx->workspace), ZSTD_DCtx_get_bmi2(dctx)); ++ sizeof(dctx->workspace), flags); + } + } + if (dctx->litBufferLocation == ZSTD_split) + { ++ assert(litSize > ZSTD_LITBUFFEREXTRASIZE); + ZSTD_memcpy(dctx->litExtraBuffer, dctx->litBufferEnd - ZSTD_LITBUFFEREXTRASIZE, ZSTD_LITBUFFEREXTRASIZE); + ZSTD_memmove(dctx->litBuffer + ZSTD_LITBUFFEREXTRASIZE - WILDCOPY_OVERLENGTH, dctx->litBuffer, litSize - ZSTD_LITBUFFEREXTRASIZE); + dctx->litBuffer += ZSTD_LITBUFFEREXTRASIZE - WILDCOPY_OVERLENGTH; + dctx->litBufferEnd -= WILDCOPY_OVERLENGTH; ++ assert(dctx->litBufferEnd <= (BYTE*)dst + blockSizeMax); + } + + RETURN_ERROR_IF(HUF_isError(hufSuccess), corruption_detected, ""); +@@ -224,7 +251,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, + case set_basic: + { size_t litSize, lhSize; + U32 const lhlCode = ((istart[0]) >> 2) & 3; +- size_t expectedWriteSize = MIN(ZSTD_BLOCKSIZE_MAX, dstCapacity); ++ size_t expectedWriteSize = MIN(blockSizeMax, dstCapacity); + switch(lhlCode) + { + case 0: case 2: default: /* note : default is impossible, since lhlCode into [0..3] */ +@@ -237,11 +264,13 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, + break; + case 3: + lhSize = 3; ++ RETURN_ERROR_IF(srcSize<3, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 2; here we need lhSize = 3"); + litSize = MEM_readLE24(istart) >> 4; + break; + } + + RETURN_ERROR_IF(litSize > 0 && dst == NULL, dstSize_tooSmall, "NULL not handled"); ++ RETURN_ERROR_IF(litSize > blockSizeMax, corruption_detected, ""); + RETURN_ERROR_IF(expectedWriteSize < litSize, dstSize_tooSmall, ""); + ZSTD_allocateLiteralsBuffer(dctx, dst, dstCapacity, litSize, 
streaming, expectedWriteSize, 1); + if (lhSize+litSize+WILDCOPY_OVERLENGTH > srcSize) { /* risk reading beyond src buffer with wildcopy */ +@@ -270,7 +299,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, + case set_rle: + { U32 const lhlCode = ((istart[0]) >> 2) & 3; + size_t litSize, lhSize; +- size_t expectedWriteSize = MIN(ZSTD_BLOCKSIZE_MAX, dstCapacity); ++ size_t expectedWriteSize = MIN(blockSizeMax, dstCapacity); + switch(lhlCode) + { + case 0: case 2: default: /* note : default is impossible, since lhlCode into [0..3] */ +@@ -279,16 +308,17 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, + break; + case 1: + lhSize = 2; ++ RETURN_ERROR_IF(srcSize<3, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 2; here we need lhSize+1 = 3"); + litSize = MEM_readLE16(istart) >> 4; + break; + case 3: + lhSize = 3; ++ RETURN_ERROR_IF(srcSize<4, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 2; here we need lhSize+1 = 4"); + litSize = MEM_readLE24(istart) >> 4; +- RETURN_ERROR_IF(srcSize<4, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 3; here we need lhSize+1 = 4"); + break; + } + RETURN_ERROR_IF(litSize > 0 && dst == NULL, dstSize_tooSmall, "NULL not handled"); +- RETURN_ERROR_IF(litSize > ZSTD_BLOCKSIZE_MAX, corruption_detected, ""); ++ RETURN_ERROR_IF(litSize > blockSizeMax, corruption_detected, ""); + RETURN_ERROR_IF(expectedWriteSize < litSize, dstSize_tooSmall, ""); + ZSTD_allocateLiteralsBuffer(dctx, dst, dstCapacity, litSize, streaming, expectedWriteSize, 1); + if (dctx->litBufferLocation == ZSTD_split) +@@ -310,6 +340,18 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, + } + } + ++/* Hidden declaration for fullbench */ ++size_t ZSTD_decodeLiteralsBlock_wrapper(ZSTD_DCtx* dctx, ++ const void* src, size_t srcSize, ++ void* dst, size_t dstCapacity); ++size_t ZSTD_decodeLiteralsBlock_wrapper(ZSTD_DCtx* dctx, ++ const void* src, size_t srcSize, ++ void* dst, size_t dstCapacity) ++{ ++ dctx->isFrameDecompression = 0; ++ return 
ZSTD_decodeLiteralsBlock(dctx, src, srcSize, dst, dstCapacity, not_streaming); ++} ++ + /* Default FSE distribution tables. + * These are pre-calculated FSE decoding tables using default distributions as defined in specification : + * https://github.com/facebook/zstd/blob/release/doc/zstd_compression_format.md#default-distributions +@@ -506,14 +548,15 @@ void ZSTD_buildFSETable_body(ZSTD_seqSymbol* dt, + for (i = 8; i < n; i += 8) { + MEM_write64(spread + pos + i, sv); + } +- pos += n; ++ assert(n>=0); ++ pos += (size_t)n; + } + } + /* Now we spread those positions across the table. +- * The benefit of doing it in two stages is that we avoid the the ++ * The benefit of doing it in two stages is that we avoid the + * variable size inner loop, which caused lots of branch misses. + * Now we can run through all the positions without any branch misses. +- * We unroll the loop twice, since that is what emperically worked best. ++ * We unroll the loop twice, since that is what empirically worked best. 
+ */ + { + size_t position = 0; +@@ -540,7 +583,7 @@ void ZSTD_buildFSETable_body(ZSTD_seqSymbol* dt, + for (i=0; i highThreshold) position = (position + step) & tableMask; /* lowprob area */ ++ while (UNLIKELY(position > highThreshold)) position = (position + step) & tableMask; /* lowprob area */ + } } + assert(position == 0); /* position must reach all cells once, otherwise normalizedCounter is incorrect */ + } +@@ -551,7 +594,7 @@ void ZSTD_buildFSETable_body(ZSTD_seqSymbol* dt, + for (u=0; u 0x7F) { + if (nbSeq == 0xFF) { + RETURN_ERROR_IF(ip+2 > iend, srcSize_wrong, ""); +@@ -681,8 +719,16 @@ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr, + } + *nbSeqPtr = nbSeq; + ++ if (nbSeq == 0) { ++ /* No sequence : section ends immediately */ ++ RETURN_ERROR_IF(ip != iend, corruption_detected, ++ "extraneous data present in the Sequences section"); ++ return (size_t)(ip - istart); ++ } ++ + /* FSE table descriptors */ + RETURN_ERROR_IF(ip+1 > iend, srcSize_wrong, ""); /* minimum possible size: 1 byte for symbol encoding types */ ++ RETURN_ERROR_IF(*ip & 3, corruption_detected, ""); /* The last field, Reserved, must be all-zeroes. 
*/ + { symbolEncodingType_e const LLtype = (symbolEncodingType_e)(*ip >> 6); + symbolEncodingType_e const OFtype = (symbolEncodingType_e)((*ip >> 4) & 3); + symbolEncodingType_e const MLtype = (symbolEncodingType_e)((*ip >> 2) & 3); +@@ -829,7 +875,7 @@ static void ZSTD_safecopy(BYTE* op, const BYTE* const oend_w, BYTE const* ip, pt + /* ZSTD_safecopyDstBeforeSrc(): + * This version allows overlap with dst before src, or handles the non-overlap case with dst after src + * Kept separate from more common ZSTD_safecopy case to avoid performance impact to the safecopy common case */ +-static void ZSTD_safecopyDstBeforeSrc(BYTE* op, BYTE const* ip, ptrdiff_t length) { ++static void ZSTD_safecopyDstBeforeSrc(BYTE* op, const BYTE* ip, ptrdiff_t length) { + ptrdiff_t const diff = op - ip; + BYTE* const oend = op + length; + +@@ -858,6 +904,7 @@ static void ZSTD_safecopyDstBeforeSrc(BYTE* op, BYTE const* ip, ptrdiff_t length + * to be optimized for many small sequences, since those fall into ZSTD_execSequence(). + */ + FORCE_NOINLINE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR + size_t ZSTD_execSequenceEnd(BYTE* op, + BYTE* const oend, seq_t sequence, + const BYTE** litPtr, const BYTE* const litLimit, +@@ -905,6 +952,7 @@ size_t ZSTD_execSequenceEnd(BYTE* op, + * This version is intended to be used during instances where the litBuffer is still split. It is kept separate to avoid performance impact for the good case. 
+ */ + FORCE_NOINLINE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR + size_t ZSTD_execSequenceEndSplitLitBuffer(BYTE* op, + BYTE* const oend, const BYTE* const oend_w, seq_t sequence, + const BYTE** litPtr, const BYTE* const litLimit, +@@ -950,6 +998,7 @@ size_t ZSTD_execSequenceEndSplitLitBuffer(BYTE* op, + } + + HINT_INLINE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR + size_t ZSTD_execSequence(BYTE* op, + BYTE* const oend, seq_t sequence, + const BYTE** litPtr, const BYTE* const litLimit, +@@ -964,6 +1013,11 @@ size_t ZSTD_execSequence(BYTE* op, + + assert(op != NULL /* Precondition */); + assert(oend_w < oend /* No underflow */); ++ ++#if defined(__aarch64__) ++ /* prefetch sequence starting from match that will be used for copy later */ ++ PREFETCH_L1(match); ++#endif + /* Handle edge cases in a slow path: + * - Read beyond end of literals + * - Match end is within WILDCOPY_OVERLIMIT of oend +@@ -1043,6 +1097,7 @@ size_t ZSTD_execSequence(BYTE* op, + } + + HINT_INLINE ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR + size_t ZSTD_execSequenceSplitLitBuffer(BYTE* op, + BYTE* const oend, const BYTE* const oend_w, seq_t sequence, + const BYTE** litPtr, const BYTE* const litLimit, +@@ -1154,7 +1209,7 @@ ZSTD_updateFseStateWithDInfo(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD, U16 + } + + /* We need to add at most (ZSTD_WINDOWLOG_MAX_32 - 1) bits to read the maximum +- * offset bits. But we can only read at most (STREAM_ACCUMULATOR_MIN_32 - 1) ++ * offset bits. But we can only read at most STREAM_ACCUMULATOR_MIN_32 + * bits before reloading. This value is the maximum number of bytes we read + * after reloading when we are decoding long offsets. 
+ */ +@@ -1165,13 +1220,37 @@ ZSTD_updateFseStateWithDInfo(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD, U16 + + typedef enum { ZSTD_lo_isRegularOffset, ZSTD_lo_isLongOffset=1 } ZSTD_longOffset_e; + ++/* ++ * ZSTD_decodeSequence(): ++ * @p longOffsets : tells the decoder to reload more bit while decoding large offsets ++ * only used in 32-bit mode ++ * @return : Sequence (litL + matchL + offset) ++ */ + FORCE_INLINE_TEMPLATE seq_t +-ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) ++ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets, const int isLastSeq) + { + seq_t seq; ++ /* ++ * ZSTD_seqSymbol is a 64 bits wide structure. ++ * It can be loaded in one operation ++ * and its fields extracted by simply shifting or bit-extracting on aarch64. ++ * GCC doesn't recognize this and generates more unnecessary ldr/ldrb/ldrh ++ * operations that cause performance drop. This can be avoided by using this ++ * ZSTD_memcpy hack. ++ */ ++#if defined(__aarch64__) && (defined(__GNUC__) && !defined(__clang__)) ++ ZSTD_seqSymbol llDInfoS, mlDInfoS, ofDInfoS; ++ ZSTD_seqSymbol* const llDInfo = &llDInfoS; ++ ZSTD_seqSymbol* const mlDInfo = &mlDInfoS; ++ ZSTD_seqSymbol* const ofDInfo = &ofDInfoS; ++ ZSTD_memcpy(llDInfo, seqState->stateLL.table + seqState->stateLL.state, sizeof(ZSTD_seqSymbol)); ++ ZSTD_memcpy(mlDInfo, seqState->stateML.table + seqState->stateML.state, sizeof(ZSTD_seqSymbol)); ++ ZSTD_memcpy(ofDInfo, seqState->stateOffb.table + seqState->stateOffb.state, sizeof(ZSTD_seqSymbol)); ++#else + const ZSTD_seqSymbol* const llDInfo = seqState->stateLL.table + seqState->stateLL.state; + const ZSTD_seqSymbol* const mlDInfo = seqState->stateML.table + seqState->stateML.state; + const ZSTD_seqSymbol* const ofDInfo = seqState->stateOffb.table + seqState->stateOffb.state; ++#endif + seq.matchLength = mlDInfo->baseValue; + seq.litLength = llDInfo->baseValue; + { U32 const ofBase = ofDInfo->baseValue; +@@ -1186,28 +1265,31 @@ 
ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) + U32 const llnbBits = llDInfo->nbBits; + U32 const mlnbBits = mlDInfo->nbBits; + U32 const ofnbBits = ofDInfo->nbBits; ++ ++ assert(llBits <= MaxLLBits); ++ assert(mlBits <= MaxMLBits); ++ assert(ofBits <= MaxOff); + /* + * As gcc has better branch and block analyzers, sometimes it is only +- * valuable to mark likelyness for clang, it gives around 3-4% of ++ * valuable to mark likeliness for clang, it gives around 3-4% of + * performance. + */ + + /* sequence */ + { size_t offset; +- #if defined(__clang__) +- if (LIKELY(ofBits > 1)) { +- #else + if (ofBits > 1) { +- #endif + ZSTD_STATIC_ASSERT(ZSTD_lo_isLongOffset == 1); + ZSTD_STATIC_ASSERT(LONG_OFFSETS_MAX_EXTRA_BITS_32 == 5); +- assert(ofBits <= MaxOff); ++ ZSTD_STATIC_ASSERT(STREAM_ACCUMULATOR_MIN_32 > LONG_OFFSETS_MAX_EXTRA_BITS_32); ++ ZSTD_STATIC_ASSERT(STREAM_ACCUMULATOR_MIN_32 - LONG_OFFSETS_MAX_EXTRA_BITS_32 >= MaxMLBits); + if (MEM_32bits() && longOffsets && (ofBits >= STREAM_ACCUMULATOR_MIN_32)) { +- U32 const extraBits = ofBits - MIN(ofBits, 32 - seqState->DStream.bitsConsumed); ++ /* Always read extra bits, this keeps the logic simple, ++ * avoids branches, and avoids accidentally reading 0 bits. 
++ */ ++ U32 const extraBits = LONG_OFFSETS_MAX_EXTRA_BITS_32; + offset = ofBase + (BIT_readBitsFast(&seqState->DStream, ofBits - extraBits) << extraBits); + BIT_reloadDStream(&seqState->DStream); +- if (extraBits) offset += BIT_readBitsFast(&seqState->DStream, extraBits); +- assert(extraBits <= LONG_OFFSETS_MAX_EXTRA_BITS_32); /* to avoid another reload */ ++ offset += BIT_readBitsFast(&seqState->DStream, extraBits); + } else { + offset = ofBase + BIT_readBitsFast(&seqState->DStream, ofBits/*>0*/); /* <= (ZSTD_WINDOWLOG_MAX-1) bits */ + if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream); +@@ -1224,7 +1306,7 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) + } else { + offset = ofBase + ll0 + BIT_readBitsFast(&seqState->DStream, 1); + { size_t temp = (offset==3) ? seqState->prevOffset[0] - 1 : seqState->prevOffset[offset]; +- temp += !temp; /* 0 is not valid; input is corrupted; force offset to 1 */ ++ temp -= !temp; /* 0 is not valid: input corrupted => force offset to -1 => corruption detected at execSequence */ + if (offset != 1) seqState->prevOffset[2] = seqState->prevOffset[1]; + seqState->prevOffset[1] = seqState->prevOffset[0]; + seqState->prevOffset[0] = offset = temp; +@@ -1232,11 +1314,7 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) + seq.offset = offset; + } + +- #if defined(__clang__) +- if (UNLIKELY(mlBits > 0)) +- #else + if (mlBits > 0) +- #endif + seq.matchLength += BIT_readBitsFast(&seqState->DStream, mlBits/*>0*/); + + if (MEM_32bits() && (mlBits+llBits >= STREAM_ACCUMULATOR_MIN_32-LONG_OFFSETS_MAX_EXTRA_BITS_32)) +@@ -1246,11 +1324,7 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) + /* Ensure there are enough bits to read the rest of data in 64-bit mode. 
*/ + ZSTD_STATIC_ASSERT(16+LLFSELog+MLFSELog+OffFSELog < STREAM_ACCUMULATOR_MIN_64); + +- #if defined(__clang__) +- if (UNLIKELY(llBits > 0)) +- #else + if (llBits > 0) +- #endif + seq.litLength += BIT_readBitsFast(&seqState->DStream, llBits/*>0*/); + + if (MEM_32bits()) +@@ -1259,17 +1333,22 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) + DEBUGLOG(6, "seq: litL=%u, matchL=%u, offset=%u", + (U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset); + +- ZSTD_updateFseStateWithDInfo(&seqState->stateLL, &seqState->DStream, llNext, llnbBits); /* <= 9 bits */ +- ZSTD_updateFseStateWithDInfo(&seqState->stateML, &seqState->DStream, mlNext, mlnbBits); /* <= 9 bits */ +- if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream); /* <= 18 bits */ +- ZSTD_updateFseStateWithDInfo(&seqState->stateOffb, &seqState->DStream, ofNext, ofnbBits); /* <= 8 bits */ ++ if (!isLastSeq) { ++ /* don't update FSE state for last Sequence */ ++ ZSTD_updateFseStateWithDInfo(&seqState->stateLL, &seqState->DStream, llNext, llnbBits); /* <= 9 bits */ ++ ZSTD_updateFseStateWithDInfo(&seqState->stateML, &seqState->DStream, mlNext, mlnbBits); /* <= 9 bits */ ++ if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream); /* <= 18 bits */ ++ ZSTD_updateFseStateWithDInfo(&seqState->stateOffb, &seqState->DStream, ofNext, ofnbBits); /* <= 8 bits */ ++ BIT_reloadDStream(&seqState->DStream); ++ } + } + + return seq; + } + +-#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION +-MEM_STATIC int ZSTD_dictionaryIsActive(ZSTD_DCtx const* dctx, BYTE const* prefixStart, BYTE const* oLitEnd) ++#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) ++#if DEBUGLEVEL >= 1 ++static int ZSTD_dictionaryIsActive(ZSTD_DCtx const* dctx, BYTE const* prefixStart, BYTE const* oLitEnd) + { + size_t const windowSize = dctx->fParams.windowSize; + /* No dictionary used. 
*/ +@@ -1283,30 +1362,33 @@ MEM_STATIC int ZSTD_dictionaryIsActive(ZSTD_DCtx const* dctx, BYTE const* prefix + /* Dictionary is active. */ + return 1; + } ++#endif + +-MEM_STATIC void ZSTD_assertValidSequence( ++static void ZSTD_assertValidSequence( + ZSTD_DCtx const* dctx, + BYTE const* op, BYTE const* oend, + seq_t const seq, + BYTE const* prefixStart, BYTE const* virtualStart) + { + #if DEBUGLEVEL >= 1 +- size_t const windowSize = dctx->fParams.windowSize; +- size_t const sequenceSize = seq.litLength + seq.matchLength; +- BYTE const* const oLitEnd = op + seq.litLength; +- DEBUGLOG(6, "Checking sequence: litL=%u matchL=%u offset=%u", +- (U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset); +- assert(op <= oend); +- assert((size_t)(oend - op) >= sequenceSize); +- assert(sequenceSize <= ZSTD_BLOCKSIZE_MAX); +- if (ZSTD_dictionaryIsActive(dctx, prefixStart, oLitEnd)) { +- size_t const dictSize = (size_t)((char const*)dctx->dictContentEndForFuzzing - (char const*)dctx->dictContentBeginForFuzzing); +- /* Offset must be within the dictionary. */ +- assert(seq.offset <= (size_t)(oLitEnd - virtualStart)); +- assert(seq.offset <= windowSize + dictSize); +- } else { +- /* Offset must be within our window. */ +- assert(seq.offset <= windowSize); ++ if (dctx->isFrameDecompression) { ++ size_t const windowSize = dctx->fParams.windowSize; ++ size_t const sequenceSize = seq.litLength + seq.matchLength; ++ BYTE const* const oLitEnd = op + seq.litLength; ++ DEBUGLOG(6, "Checking sequence: litL=%u matchL=%u offset=%u", ++ (U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset); ++ assert(op <= oend); ++ assert((size_t)(oend - op) >= sequenceSize); ++ assert(sequenceSize <= ZSTD_blockSizeMax(dctx)); ++ if (ZSTD_dictionaryIsActive(dctx, prefixStart, oLitEnd)) { ++ size_t const dictSize = (size_t)((char const*)dctx->dictContentEndForFuzzing - (char const*)dctx->dictContentBeginForFuzzing); ++ /* Offset must be within the dictionary. 
*/ ++ assert(seq.offset <= (size_t)(oLitEnd - virtualStart)); ++ assert(seq.offset <= windowSize + dictSize); ++ } else { ++ /* Offset must be within our window. */ ++ assert(seq.offset <= windowSize); ++ } + } + #else + (void)dctx, (void)op, (void)oend, (void)seq, (void)prefixStart, (void)virtualStart; +@@ -1322,23 +1404,21 @@ DONT_VECTORIZE + ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, + void* dst, size_t maxDstSize, + const void* seqStart, size_t seqSize, int nbSeq, +- const ZSTD_longOffset_e isLongOffset, +- const int frame) ++ const ZSTD_longOffset_e isLongOffset) + { + const BYTE* ip = (const BYTE*)seqStart; + const BYTE* const iend = ip + seqSize; + BYTE* const ostart = (BYTE*)dst; +- BYTE* const oend = ostart + maxDstSize; ++ BYTE* const oend = ZSTD_maybeNullPtrAdd(ostart, maxDstSize); + BYTE* op = ostart; + const BYTE* litPtr = dctx->litPtr; + const BYTE* litBufferEnd = dctx->litBufferEnd; + const BYTE* const prefixStart = (const BYTE*) (dctx->prefixStart); + const BYTE* const vBase = (const BYTE*) (dctx->virtualStart); + const BYTE* const dictEnd = (const BYTE*) (dctx->dictEnd); +- DEBUGLOG(5, "ZSTD_decompressSequences_bodySplitLitBuffer"); +- (void)frame; ++ DEBUGLOG(5, "ZSTD_decompressSequences_bodySplitLitBuffer (%i seqs)", nbSeq); + +- /* Regen sequences */ ++ /* Literals are split between internal buffer & output buffer */ + if (nbSeq) { + seqState_t seqState; + dctx->fseEntropy = 1; +@@ -1357,8 +1437,7 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, + BIT_DStream_completed < BIT_DStream_overflow); + + /* decompress without overrunning litPtr begins */ +- { +- seq_t sequence = ZSTD_decodeSequence(&seqState, isLongOffset); ++ { seq_t sequence = {0,0,0}; /* some static analyzer believe that @sequence is not initialized (it necessarily is, since for(;;) loop as at least one iteration) */ + /* Align the decompression loop to 32 + 16 bytes. 
+ * + * zstd compiled with gcc-9 on an Intel i9-9900k shows 10% decompression +@@ -1420,27 +1499,26 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, + #endif + + /* Handle the initial state where litBuffer is currently split between dst and litExtraBuffer */ +- for (; litPtr + sequence.litLength <= dctx->litBufferEnd; ) { +- size_t const oneSeqSize = ZSTD_execSequenceSplitLitBuffer(op, oend, litPtr + sequence.litLength - WILDCOPY_OVERLENGTH, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); ++ for ( ; nbSeq; nbSeq--) { ++ sequence = ZSTD_decodeSequence(&seqState, isLongOffset, nbSeq==1); ++ if (litPtr + sequence.litLength > dctx->litBufferEnd) break; ++ { size_t const oneSeqSize = ZSTD_execSequenceSplitLitBuffer(op, oend, litPtr + sequence.litLength - WILDCOPY_OVERLENGTH, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); + #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) +- assert(!ZSTD_isError(oneSeqSize)); +- if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); ++ assert(!ZSTD_isError(oneSeqSize)); ++ ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); + #endif +- if (UNLIKELY(ZSTD_isError(oneSeqSize))) +- return oneSeqSize; +- DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); +- op += oneSeqSize; +- if (UNLIKELY(!--nbSeq)) +- break; +- BIT_reloadDStream(&(seqState.DStream)); +- sequence = ZSTD_decodeSequence(&seqState, isLongOffset); +- } ++ if (UNLIKELY(ZSTD_isError(oneSeqSize))) ++ return oneSeqSize; ++ DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); ++ op += oneSeqSize; ++ } } ++ DEBUGLOG(6, "reached: (litPtr + sequence.litLength > dctx->litBufferEnd)"); + + /* If there are more sequences, they will need to read literals from litExtraBuffer; copy over the remainder from dst and update litPtr and litEnd */ + if (nbSeq > 0) { + const size_t leftoverLit = dctx->litBufferEnd - litPtr; +- if 
(leftoverLit) +- { ++ DEBUGLOG(6, "There are %i sequences left, and %zu/%zu literals left in buffer", nbSeq, leftoverLit, sequence.litLength); ++ if (leftoverLit) { + RETURN_ERROR_IF(leftoverLit > (size_t)(oend - op), dstSize_tooSmall, "remaining lit must fit within dstBuffer"); + ZSTD_safecopyDstBeforeSrc(op, litPtr, leftoverLit); + sequence.litLength -= leftoverLit; +@@ -1449,24 +1527,22 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, + litPtr = dctx->litExtraBuffer; + litBufferEnd = dctx->litExtraBuffer + ZSTD_LITBUFFEREXTRASIZE; + dctx->litBufferLocation = ZSTD_not_in_dst; +- { +- size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); ++ { size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); + #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) + assert(!ZSTD_isError(oneSeqSize)); +- if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); ++ ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); + #endif + if (UNLIKELY(ZSTD_isError(oneSeqSize))) + return oneSeqSize; + DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); + op += oneSeqSize; +- if (--nbSeq) +- BIT_reloadDStream(&(seqState.DStream)); + } ++ nbSeq--; + } + } + +- if (nbSeq > 0) /* there is remaining lit from extra buffer */ +- { ++ if (nbSeq > 0) { ++ /* there is remaining lit from extra buffer */ + + #if defined(__x86_64__) + __asm__(".p2align 6"); +@@ -1485,35 +1561,34 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, + # endif + #endif + +- for (; ; ) { +- seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset); ++ for ( ; nbSeq ; nbSeq--) { ++ seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset, nbSeq==1); + size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litBufferEnd, prefixStart, vBase, 
dictEnd); + #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) + assert(!ZSTD_isError(oneSeqSize)); +- if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); ++ ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); + #endif + if (UNLIKELY(ZSTD_isError(oneSeqSize))) + return oneSeqSize; + DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); + op += oneSeqSize; +- if (UNLIKELY(!--nbSeq)) +- break; +- BIT_reloadDStream(&(seqState.DStream)); + } + } + + /* check if reached exact end */ + DEBUGLOG(5, "ZSTD_decompressSequences_bodySplitLitBuffer: after decode loop, remaining nbSeq : %i", nbSeq); + RETURN_ERROR_IF(nbSeq, corruption_detected, ""); +- RETURN_ERROR_IF(BIT_reloadDStream(&seqState.DStream) < BIT_DStream_completed, corruption_detected, ""); ++ DEBUGLOG(5, "bitStream : start=%p, ptr=%p, bitsConsumed=%u", seqState.DStream.start, seqState.DStream.ptr, seqState.DStream.bitsConsumed); ++ RETURN_ERROR_IF(!BIT_endOfDStream(&seqState.DStream), corruption_detected, ""); + /* save reps for next block */ + { U32 i; for (i=0; ientropy.rep[i] = (U32)(seqState.prevOffset[i]); } + } + + /* last literal segment */ +- if (dctx->litBufferLocation == ZSTD_split) /* split hasn't been reached yet, first get dst then copy litExtraBuffer */ +- { +- size_t const lastLLSize = litBufferEnd - litPtr; ++ if (dctx->litBufferLocation == ZSTD_split) { ++ /* split hasn't been reached yet, first get dst then copy litExtraBuffer */ ++ size_t const lastLLSize = (size_t)(litBufferEnd - litPtr); ++ DEBUGLOG(6, "copy last literals from segment : %u", (U32)lastLLSize); + RETURN_ERROR_IF(lastLLSize > (size_t)(oend - op), dstSize_tooSmall, ""); + if (op != NULL) { + ZSTD_memmove(op, litPtr, lastLLSize); +@@ -1523,15 +1598,17 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, + litBufferEnd = dctx->litExtraBuffer + ZSTD_LITBUFFEREXTRASIZE; + dctx->litBufferLocation = ZSTD_not_in_dst; + } 
+- { size_t const lastLLSize = litBufferEnd - litPtr; ++ /* copy last literals from internal buffer */ ++ { size_t const lastLLSize = (size_t)(litBufferEnd - litPtr); ++ DEBUGLOG(6, "copy last literals from internal buffer : %u", (U32)lastLLSize); + RETURN_ERROR_IF(lastLLSize > (size_t)(oend-op), dstSize_tooSmall, ""); + if (op != NULL) { + ZSTD_memcpy(op, litPtr, lastLLSize); + op += lastLLSize; +- } +- } ++ } } + +- return op-ostart; ++ DEBUGLOG(6, "decoded block of size %u bytes", (U32)(op - ostart)); ++ return (size_t)(op - ostart); + } + + FORCE_INLINE_TEMPLATE size_t +@@ -1539,21 +1616,19 @@ DONT_VECTORIZE + ZSTD_decompressSequences_body(ZSTD_DCtx* dctx, + void* dst, size_t maxDstSize, + const void* seqStart, size_t seqSize, int nbSeq, +- const ZSTD_longOffset_e isLongOffset, +- const int frame) ++ const ZSTD_longOffset_e isLongOffset) + { + const BYTE* ip = (const BYTE*)seqStart; + const BYTE* const iend = ip + seqSize; + BYTE* const ostart = (BYTE*)dst; +- BYTE* const oend = dctx->litBufferLocation == ZSTD_not_in_dst ? ostart + maxDstSize : dctx->litBuffer; ++ BYTE* const oend = dctx->litBufferLocation == ZSTD_not_in_dst ? 
ZSTD_maybeNullPtrAdd(ostart, maxDstSize) : dctx->litBuffer; + BYTE* op = ostart; + const BYTE* litPtr = dctx->litPtr; + const BYTE* const litEnd = litPtr + dctx->litSize; + const BYTE* const prefixStart = (const BYTE*)(dctx->prefixStart); + const BYTE* const vBase = (const BYTE*)(dctx->virtualStart); + const BYTE* const dictEnd = (const BYTE*)(dctx->dictEnd); +- DEBUGLOG(5, "ZSTD_decompressSequences_body"); +- (void)frame; ++ DEBUGLOG(5, "ZSTD_decompressSequences_body: nbSeq = %d", nbSeq); + + /* Regen sequences */ + if (nbSeq) { +@@ -1568,11 +1643,6 @@ ZSTD_decompressSequences_body(ZSTD_DCtx* dctx, + ZSTD_initFseState(&seqState.stateML, &seqState.DStream, dctx->MLTptr); + assert(dst != NULL); + +- ZSTD_STATIC_ASSERT( +- BIT_DStream_unfinished < BIT_DStream_completed && +- BIT_DStream_endOfBuffer < BIT_DStream_completed && +- BIT_DStream_completed < BIT_DStream_overflow); +- + #if defined(__x86_64__) + __asm__(".p2align 6"); + __asm__("nop"); +@@ -1587,73 +1657,70 @@ ZSTD_decompressSequences_body(ZSTD_DCtx* dctx, + # endif + #endif + +- for ( ; ; ) { +- seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset); ++ for ( ; nbSeq ; nbSeq--) { ++ seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset, nbSeq==1); + size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litEnd, prefixStart, vBase, dictEnd); + #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) + assert(!ZSTD_isError(oneSeqSize)); +- if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); ++ ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); + #endif + if (UNLIKELY(ZSTD_isError(oneSeqSize))) + return oneSeqSize; + DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); + op += oneSeqSize; +- if (UNLIKELY(!--nbSeq)) +- break; +- BIT_reloadDStream(&(seqState.DStream)); + } + + /* check if reached exact end */ +- DEBUGLOG(5, "ZSTD_decompressSequences_body: after decode 
loop, remaining nbSeq : %i", nbSeq); +- RETURN_ERROR_IF(nbSeq, corruption_detected, ""); +- RETURN_ERROR_IF(BIT_reloadDStream(&seqState.DStream) < BIT_DStream_completed, corruption_detected, ""); ++ assert(nbSeq == 0); ++ RETURN_ERROR_IF(!BIT_endOfDStream(&seqState.DStream), corruption_detected, ""); + /* save reps for next block */ + { U32 i; for (i=0; ientropy.rep[i] = (U32)(seqState.prevOffset[i]); } + } + + /* last literal segment */ +- { size_t const lastLLSize = litEnd - litPtr; ++ { size_t const lastLLSize = (size_t)(litEnd - litPtr); ++ DEBUGLOG(6, "copy last literals : %u", (U32)lastLLSize); + RETURN_ERROR_IF(lastLLSize > (size_t)(oend-op), dstSize_tooSmall, ""); + if (op != NULL) { + ZSTD_memcpy(op, litPtr, lastLLSize); + op += lastLLSize; +- } +- } ++ } } + +- return op-ostart; ++ DEBUGLOG(6, "decoded block of size %u bytes", (U32)(op - ostart)); ++ return (size_t)(op - ostart); + } + + static size_t + ZSTD_decompressSequences_default(ZSTD_DCtx* dctx, + void* dst, size_t maxDstSize, + const void* seqStart, size_t seqSize, int nbSeq, +- const ZSTD_longOffset_e isLongOffset, +- const int frame) ++ const ZSTD_longOffset_e isLongOffset) + { +- return ZSTD_decompressSequences_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequences_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); + } + + static size_t + ZSTD_decompressSequencesSplitLitBuffer_default(ZSTD_DCtx* dctx, + void* dst, size_t maxDstSize, + const void* seqStart, size_t seqSize, int nbSeq, +- const ZSTD_longOffset_e isLongOffset, +- const int frame) ++ const ZSTD_longOffset_e isLongOffset) + { +- return ZSTD_decompressSequences_bodySplitLitBuffer(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequences_bodySplitLitBuffer(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); + } + #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG */ + + #ifndef 
ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT + +-FORCE_INLINE_TEMPLATE size_t +-ZSTD_prefetchMatch(size_t prefetchPos, seq_t const sequence, ++FORCE_INLINE_TEMPLATE ++ ++size_t ZSTD_prefetchMatch(size_t prefetchPos, seq_t const sequence, + const BYTE* const prefixStart, const BYTE* const dictEnd) + { + prefetchPos += sequence.litLength; + { const BYTE* const matchBase = (sequence.offset > prefetchPos) ? dictEnd : prefixStart; +- const BYTE* const match = matchBase + prefetchPos - sequence.offset; /* note : this operation can overflow when seq.offset is really too large, which can only happen when input is corrupted. +- * No consequence though : memory address is only used for prefetching, not for dereferencing */ ++ /* note : this operation can overflow when seq.offset is really too large, which can only happen when input is corrupted. ++ * No consequence though : memory address is only used for prefetching, not for dereferencing */ ++ const BYTE* const match = ZSTD_wrappedPtrSub(ZSTD_wrappedPtrAdd(matchBase, prefetchPos), sequence.offset); + PREFETCH_L1(match); PREFETCH_L1(match+CACHELINE_SIZE); /* note : it's safe to invoke PREFETCH() on any memory address, including invalid ones */ + } + return prefetchPos + sequence.matchLength; +@@ -1668,20 +1735,18 @@ ZSTD_decompressSequencesLong_body( + ZSTD_DCtx* dctx, + void* dst, size_t maxDstSize, + const void* seqStart, size_t seqSize, int nbSeq, +- const ZSTD_longOffset_e isLongOffset, +- const int frame) ++ const ZSTD_longOffset_e isLongOffset) + { + const BYTE* ip = (const BYTE*)seqStart; + const BYTE* const iend = ip + seqSize; + BYTE* const ostart = (BYTE*)dst; +- BYTE* const oend = dctx->litBufferLocation == ZSTD_in_dst ? dctx->litBuffer : ostart + maxDstSize; ++ BYTE* const oend = dctx->litBufferLocation == ZSTD_in_dst ? 
dctx->litBuffer : ZSTD_maybeNullPtrAdd(ostart, maxDstSize); + BYTE* op = ostart; + const BYTE* litPtr = dctx->litPtr; + const BYTE* litBufferEnd = dctx->litBufferEnd; + const BYTE* const prefixStart = (const BYTE*) (dctx->prefixStart); + const BYTE* const dictStart = (const BYTE*) (dctx->virtualStart); + const BYTE* const dictEnd = (const BYTE*) (dctx->dictEnd); +- (void)frame; + + /* Regen sequences */ + if (nbSeq) { +@@ -1706,20 +1771,17 @@ ZSTD_decompressSequencesLong_body( + ZSTD_initFseState(&seqState.stateML, &seqState.DStream, dctx->MLTptr); + + /* prepare in advance */ +- for (seqNb=0; (BIT_reloadDStream(&seqState.DStream) <= BIT_DStream_completed) && (seqNblitBufferLocation == ZSTD_split && litPtr + sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK].litLength > dctx->litBufferEnd) +- { ++ if (dctx->litBufferLocation == ZSTD_split && litPtr + sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK].litLength > dctx->litBufferEnd) { + /* lit buffer is reaching split point, empty out the first buffer and transition to litExtraBuffer */ + const size_t leftoverLit = dctx->litBufferEnd - litPtr; + if (leftoverLit) +@@ -1732,26 +1794,26 @@ ZSTD_decompressSequencesLong_body( + litPtr = dctx->litExtraBuffer; + litBufferEnd = dctx->litExtraBuffer + ZSTD_LITBUFFEREXTRASIZE; + dctx->litBufferLocation = ZSTD_not_in_dst; +- oneSeqSize = ZSTD_execSequence(op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); ++ { size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); + #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) +- assert(!ZSTD_isError(oneSeqSize)); +- if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart); ++ assert(!ZSTD_isError(oneSeqSize)); ++ 
ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart); + #endif +- if (ZSTD_isError(oneSeqSize)) return oneSeqSize; ++ if (ZSTD_isError(oneSeqSize)) return oneSeqSize; + +- prefetchPos = ZSTD_prefetchMatch(prefetchPos, sequence, prefixStart, dictEnd); +- sequences[seqNb & STORED_SEQS_MASK] = sequence; +- op += oneSeqSize; +- } ++ prefetchPos = ZSTD_prefetchMatch(prefetchPos, sequence, prefixStart, dictEnd); ++ sequences[seqNb & STORED_SEQS_MASK] = sequence; ++ op += oneSeqSize; ++ } } + else + { + /* lit buffer is either wholly contained in first or second split, or not split at all*/ +- oneSeqSize = dctx->litBufferLocation == ZSTD_split ? ++ size_t const oneSeqSize = dctx->litBufferLocation == ZSTD_split ? + ZSTD_execSequenceSplitLitBuffer(op, oend, litPtr + sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK].litLength - WILDCOPY_OVERLENGTH, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd) : + ZSTD_execSequence(op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); + #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) + assert(!ZSTD_isError(oneSeqSize)); +- if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart); ++ ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart); + #endif + if (ZSTD_isError(oneSeqSize)) return oneSeqSize; + +@@ -1760,17 +1822,15 @@ ZSTD_decompressSequencesLong_body( + op += oneSeqSize; + } + } +- RETURN_ERROR_IF(seqNblitBufferLocation == ZSTD_split && litPtr + sequence->litLength > dctx->litBufferEnd) +- { ++ if (dctx->litBufferLocation == ZSTD_split && litPtr + sequence->litLength > dctx->litBufferEnd) { + const size_t leftoverLit = dctx->litBufferEnd - litPtr; +- 
if (leftoverLit) +- { ++ if (leftoverLit) { + RETURN_ERROR_IF(leftoverLit > (size_t)(oend - op), dstSize_tooSmall, "remaining lit must fit within dstBuffer"); + ZSTD_safecopyDstBeforeSrc(op, litPtr, leftoverLit); + sequence->litLength -= leftoverLit; +@@ -1779,11 +1839,10 @@ ZSTD_decompressSequencesLong_body( + litPtr = dctx->litExtraBuffer; + litBufferEnd = dctx->litExtraBuffer + ZSTD_LITBUFFEREXTRASIZE; + dctx->litBufferLocation = ZSTD_not_in_dst; +- { +- size_t const oneSeqSize = ZSTD_execSequence(op, oend, *sequence, &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); ++ { size_t const oneSeqSize = ZSTD_execSequence(op, oend, *sequence, &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); + #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) + assert(!ZSTD_isError(oneSeqSize)); +- if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[seqNb&STORED_SEQS_MASK], prefixStart, dictStart); ++ ZSTD_assertValidSequence(dctx, op, oend, sequences[seqNb&STORED_SEQS_MASK], prefixStart, dictStart); + #endif + if (ZSTD_isError(oneSeqSize)) return oneSeqSize; + op += oneSeqSize; +@@ -1796,7 +1855,7 @@ ZSTD_decompressSequencesLong_body( + ZSTD_execSequence(op, oend, *sequence, &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); + #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) + assert(!ZSTD_isError(oneSeqSize)); +- if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[seqNb&STORED_SEQS_MASK], prefixStart, dictStart); ++ ZSTD_assertValidSequence(dctx, op, oend, sequences[seqNb&STORED_SEQS_MASK], prefixStart, dictStart); + #endif + if (ZSTD_isError(oneSeqSize)) return oneSeqSize; + op += oneSeqSize; +@@ -1808,8 +1867,7 @@ ZSTD_decompressSequencesLong_body( + } + + /* last literal segment */ +- if (dctx->litBufferLocation == ZSTD_split) /* first deplete literal buffer in dst, then copy litExtraBuffer */ +- { ++ if (dctx->litBufferLocation == ZSTD_split) { /* 
first deplete literal buffer in dst, then copy litExtraBuffer */ + size_t const lastLLSize = litBufferEnd - litPtr; + RETURN_ERROR_IF(lastLLSize > (size_t)(oend - op), dstSize_tooSmall, ""); + if (op != NULL) { +@@ -1827,17 +1885,16 @@ ZSTD_decompressSequencesLong_body( + } + } + +- return op-ostart; ++ return (size_t)(op - ostart); + } + + static size_t + ZSTD_decompressSequencesLong_default(ZSTD_DCtx* dctx, + void* dst, size_t maxDstSize, + const void* seqStart, size_t seqSize, int nbSeq, +- const ZSTD_longOffset_e isLongOffset, +- const int frame) ++ const ZSTD_longOffset_e isLongOffset) + { +- return ZSTD_decompressSequencesLong_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequencesLong_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); + } + #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT */ + +@@ -1851,20 +1908,18 @@ DONT_VECTORIZE + ZSTD_decompressSequences_bmi2(ZSTD_DCtx* dctx, + void* dst, size_t maxDstSize, + const void* seqStart, size_t seqSize, int nbSeq, +- const ZSTD_longOffset_e isLongOffset, +- const int frame) ++ const ZSTD_longOffset_e isLongOffset) + { +- return ZSTD_decompressSequences_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequences_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); + } + static BMI2_TARGET_ATTRIBUTE size_t + DONT_VECTORIZE + ZSTD_decompressSequencesSplitLitBuffer_bmi2(ZSTD_DCtx* dctx, + void* dst, size_t maxDstSize, + const void* seqStart, size_t seqSize, int nbSeq, +- const ZSTD_longOffset_e isLongOffset, +- const int frame) ++ const ZSTD_longOffset_e isLongOffset) + { +- return ZSTD_decompressSequences_bodySplitLitBuffer(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequences_bodySplitLitBuffer(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); + } + #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG */ + +@@ 
-1873,10 +1928,9 @@ static BMI2_TARGET_ATTRIBUTE size_t + ZSTD_decompressSequencesLong_bmi2(ZSTD_DCtx* dctx, + void* dst, size_t maxDstSize, + const void* seqStart, size_t seqSize, int nbSeq, +- const ZSTD_longOffset_e isLongOffset, +- const int frame) ++ const ZSTD_longOffset_e isLongOffset) + { +- return ZSTD_decompressSequencesLong_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequencesLong_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); + } + #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT */ + +@@ -1886,37 +1940,34 @@ typedef size_t (*ZSTD_decompressSequences_t)( + ZSTD_DCtx* dctx, + void* dst, size_t maxDstSize, + const void* seqStart, size_t seqSize, int nbSeq, +- const ZSTD_longOffset_e isLongOffset, +- const int frame); ++ const ZSTD_longOffset_e isLongOffset); + + #ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG + static size_t + ZSTD_decompressSequences(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, + const void* seqStart, size_t seqSize, int nbSeq, +- const ZSTD_longOffset_e isLongOffset, +- const int frame) ++ const ZSTD_longOffset_e isLongOffset) + { + DEBUGLOG(5, "ZSTD_decompressSequences"); + #if DYNAMIC_BMI2 + if (ZSTD_DCtx_get_bmi2(dctx)) { +- return ZSTD_decompressSequences_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequences_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); + } + #endif +- return ZSTD_decompressSequences_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequences_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); + } + static size_t + ZSTD_decompressSequencesSplitLitBuffer(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, + const void* seqStart, size_t seqSize, int nbSeq, +- const ZSTD_longOffset_e isLongOffset, +- const int frame) ++ const ZSTD_longOffset_e isLongOffset) + { + DEBUGLOG(5, 
"ZSTD_decompressSequencesSplitLitBuffer"); + #if DYNAMIC_BMI2 + if (ZSTD_DCtx_get_bmi2(dctx)) { +- return ZSTD_decompressSequencesSplitLitBuffer_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequencesSplitLitBuffer_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); + } + #endif +- return ZSTD_decompressSequencesSplitLitBuffer_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequencesSplitLitBuffer_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); + } + #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG */ + +@@ -1931,69 +1982,114 @@ static size_t + ZSTD_decompressSequencesLong(ZSTD_DCtx* dctx, + void* dst, size_t maxDstSize, + const void* seqStart, size_t seqSize, int nbSeq, +- const ZSTD_longOffset_e isLongOffset, +- const int frame) ++ const ZSTD_longOffset_e isLongOffset) + { + DEBUGLOG(5, "ZSTD_decompressSequencesLong"); + #if DYNAMIC_BMI2 + if (ZSTD_DCtx_get_bmi2(dctx)) { +- return ZSTD_decompressSequencesLong_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequencesLong_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); + } + #endif +- return ZSTD_decompressSequencesLong_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequencesLong_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); + } + #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT */ + + ++/* ++ * @returns The total size of the history referenceable by zstd, including ++ * both the prefix and the extDict. At @p op any offset larger than this ++ * is invalid. 
++ */ ++static size_t ZSTD_totalHistorySize(BYTE* op, BYTE const* virtualStart) ++{ ++ return (size_t)(op - virtualStart); ++} ++ ++typedef struct { ++ unsigned longOffsetShare; ++ unsigned maxNbAdditionalBits; ++} ZSTD_OffsetInfo; + +-#if !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT) && \ +- !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG) +-/* ZSTD_getLongOffsetsShare() : ++/* ZSTD_getOffsetInfo() : + * condition : offTable must be valid + * @return : "share" of long offsets (arbitrarily defined as > (1<<23)) +- * compared to maximum possible of (1< 22) total += 1; ++ ZSTD_OffsetInfo info = {0, 0}; ++ /* If nbSeq == 0, then the offTable is uninitialized, but we have ++ * no sequences, so both values should be 0. ++ */ ++ if (nbSeq != 0) { ++ const void* ptr = offTable; ++ U32 const tableLog = ((const ZSTD_seqSymbol_header*)ptr)[0].tableLog; ++ const ZSTD_seqSymbol* table = offTable + 1; ++ U32 const max = 1 << tableLog; ++ U32 u; ++ DEBUGLOG(5, "ZSTD_getLongOffsetsShare: (tableLog=%u)", tableLog); ++ ++ assert(max <= (1 << OffFSELog)); /* max not too large */ ++ for (u=0; u 22) info.longOffsetShare += 1; ++ } ++ ++ assert(tableLog <= OffFSELog); ++ info.longOffsetShare <<= (OffFSELog - tableLog); /* scale to OffFSELog */ + } + +- assert(tableLog <= OffFSELog); +- total <<= (OffFSELog - tableLog); /* scale to OffFSELog */ ++ return info; ++} + +- return total; ++/* ++ * @returns The maximum offset we can decode in one read of our bitstream, without ++ * reloading more bits in the middle of the offset bits read. Any offsets larger ++ * than this must use the long offset decoder. ++ */ ++static size_t ZSTD_maxShortOffset(void) ++{ ++ if (MEM_64bits()) { ++ /* We can decode any offset without reloading bits. ++ * This might change if the max window size grows. ++ */ ++ ZSTD_STATIC_ASSERT(ZSTD_WINDOWLOG_MAX <= 31); ++ return (size_t)-1; ++ } else { ++ /* The maximum offBase is (1 << (STREAM_ACCUMULATOR_MIN + 1)) - 1. 
++ * This offBase would require STREAM_ACCUMULATOR_MIN extra bits. ++ * Then we have to subtract ZSTD_REP_NUM to get the maximum possible offset. ++ */ ++ size_t const maxOffbase = ((size_t)1 << (STREAM_ACCUMULATOR_MIN + 1)) - 1; ++ size_t const maxOffset = maxOffbase - ZSTD_REP_NUM; ++ assert(ZSTD_highbit32((U32)maxOffbase) == STREAM_ACCUMULATOR_MIN); ++ return maxOffset; ++ } + } +-#endif + + size_t + ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, + void* dst, size_t dstCapacity, +- const void* src, size_t srcSize, const int frame, const streaming_operation streaming) ++ const void* src, size_t srcSize, const streaming_operation streaming) + { /* blockType == blockCompressed */ + const BYTE* ip = (const BYTE*)src; +- /* isLongOffset must be true if there are long offsets. +- * Offsets are long if they are larger than 2^STREAM_ACCUMULATOR_MIN. +- * We don't expect that to be the case in 64-bit mode. +- * In block mode, window size is not known, so we have to be conservative. +- * (note: but it could be evaluated from current-lowLimit) +- */ +- ZSTD_longOffset_e const isLongOffset = (ZSTD_longOffset_e)(MEM_32bits() && (!frame || (dctx->fParams.windowSize > (1ULL << STREAM_ACCUMULATOR_MIN)))); +- DEBUGLOG(5, "ZSTD_decompressBlock_internal (size : %u)", (U32)srcSize); +- +- RETURN_ERROR_IF(srcSize >= ZSTD_BLOCKSIZE_MAX, srcSize_wrong, ""); ++ DEBUGLOG(5, "ZSTD_decompressBlock_internal (cSize : %u)", (unsigned)srcSize); ++ ++ /* Note : the wording of the specification ++ * allows compressed block to be sized exactly ZSTD_blockSizeMax(dctx). ++ * This generally does not happen, as it makes little sense, ++ * since an uncompressed block would feature same size and have no decompression cost. ++ * Also, note that decoder from reference libzstd before < v1.5.4 ++ * would consider this edge case as an error. 
++ * As a consequence, avoid generating compressed blocks of size ZSTD_blockSizeMax(dctx) ++ * for broader compatibility with the deployed ecosystem of zstd decoders */ ++ RETURN_ERROR_IF(srcSize > ZSTD_blockSizeMax(dctx), srcSize_wrong, ""); + + /* Decode literals section */ + { size_t const litCSize = ZSTD_decodeLiteralsBlock(dctx, src, srcSize, dst, dstCapacity, streaming); +- DEBUGLOG(5, "ZSTD_decodeLiteralsBlock : %u", (U32)litCSize); ++ DEBUGLOG(5, "ZSTD_decodeLiteralsBlock : cSize=%u, nbLiterals=%zu", (U32)litCSize, dctx->litSize); + if (ZSTD_isError(litCSize)) return litCSize; + ip += litCSize; + srcSize -= litCSize; +@@ -2001,6 +2097,23 @@ ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, + + /* Build Decoding Tables */ + { ++ /* Compute the maximum block size, which must also work when !frame and fParams are unset. ++ * Additionally, take the min with dstCapacity to ensure that the totalHistorySize fits in a size_t. ++ */ ++ size_t const blockSizeMax = MIN(dstCapacity, ZSTD_blockSizeMax(dctx)); ++ size_t const totalHistorySize = ZSTD_totalHistorySize(ZSTD_maybeNullPtrAdd((BYTE*)dst, blockSizeMax), (BYTE const*)dctx->virtualStart); ++ /* isLongOffset must be true if there are long offsets. ++ * Offsets are long if they are larger than ZSTD_maxShortOffset(). ++ * We don't expect that to be the case in 64-bit mode. ++ * ++ * We check here to see if our history is large enough to allow long offsets. ++ * If it isn't, then we can't possible have (valid) long offsets. If the offset ++ * is invalid, then it is okay to read it incorrectly. ++ * ++ * If isLongOffsets is true, then we will later check our decoding table to see ++ * if it is even possible to generate long offsets. ++ */ ++ ZSTD_longOffset_e isLongOffset = (ZSTD_longOffset_e)(MEM_32bits() && (totalHistorySize > ZSTD_maxShortOffset())); + /* These macros control at build-time which decompressor implementation + * we use. If neither is defined, we do some inspection and dispatch at + * runtime. 
+@@ -2008,6 +2121,11 @@ ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, + #if !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT) && \ + !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG) + int usePrefetchDecoder = dctx->ddictIsCold; ++#else ++ /* Set to 1 to avoid computing offset info if we don't need to. ++ * Otherwise this value is ignored. ++ */ ++ int usePrefetchDecoder = 1; + #endif + int nbSeq; + size_t const seqHSize = ZSTD_decodeSeqHeaders(dctx, &nbSeq, ip, srcSize); +@@ -2015,40 +2133,55 @@ ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, + ip += seqHSize; + srcSize -= seqHSize; + +- RETURN_ERROR_IF(dst == NULL && nbSeq > 0, dstSize_tooSmall, "NULL not handled"); ++ RETURN_ERROR_IF((dst == NULL || dstCapacity == 0) && nbSeq > 0, dstSize_tooSmall, "NULL not handled"); ++ RETURN_ERROR_IF(MEM_64bits() && sizeof(size_t) == sizeof(void*) && (size_t)(-1) - (size_t)dst < (size_t)(1 << 20), dstSize_tooSmall, ++ "invalid dst"); + +-#if !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT) && \ +- !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG) +- if ( !usePrefetchDecoder +- && (!frame || (dctx->fParams.windowSize > (1<<24))) +- && (nbSeq>ADVANCED_SEQS) ) { /* could probably use a larger nbSeq limit */ +- U32 const shareLongOffsets = ZSTD_getLongOffsetsShare(dctx->OFTptr); +- U32 const minShare = MEM_64bits() ? 7 : 20; /* heuristic values, correspond to 2.73% and 7.81% */ +- usePrefetchDecoder = (shareLongOffsets >= minShare); ++ /* If we could potentially have long offsets, or we might want to use the prefetch decoder, ++ * compute information about the share of long offsets, and the maximum nbAdditionalBits. 
++ * NOTE: could probably use a larger nbSeq limit ++ */ ++ if (isLongOffset || (!usePrefetchDecoder && (totalHistorySize > (1u << 24)) && (nbSeq > 8))) { ++ ZSTD_OffsetInfo const info = ZSTD_getOffsetInfo(dctx->OFTptr, nbSeq); ++ if (isLongOffset && info.maxNbAdditionalBits <= STREAM_ACCUMULATOR_MIN) { ++ /* If isLongOffset, but the maximum number of additional bits that we see in our table is small ++ * enough, then we know it is impossible to have too long an offset in this block, so we can ++ * use the regular offset decoder. ++ */ ++ isLongOffset = ZSTD_lo_isRegularOffset; ++ } ++ if (!usePrefetchDecoder) { ++ U32 const minShare = MEM_64bits() ? 7 : 20; /* heuristic values, correspond to 2.73% and 7.81% */ ++ usePrefetchDecoder = (info.longOffsetShare >= minShare); ++ } + } +-#endif + + dctx->ddictIsCold = 0; + + #if !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT) && \ + !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG) +- if (usePrefetchDecoder) ++ if (usePrefetchDecoder) { ++#else ++ (void)usePrefetchDecoder; ++ { + #endif + #ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT +- return ZSTD_decompressSequencesLong(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequencesLong(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset); + #endif ++ } + + #ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG + /* else */ + if (dctx->litBufferLocation == ZSTD_split) +- return ZSTD_decompressSequencesSplitLitBuffer(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequencesSplitLitBuffer(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset); + else +- return ZSTD_decompressSequences(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset, frame); ++ return ZSTD_decompressSequences(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset); + #endif + } + } + + ++ZSTD_ALLOW_POINTER_OVERFLOW_ATTR + void ZSTD_checkContinuity(ZSTD_DCtx* dctx, const void* dst, size_t dstSize) + { + if (dst != 
dctx->previousDstEnd && dstSize > 0) { /* not contiguous */ +@@ -2060,13 +2193,24 @@ void ZSTD_checkContinuity(ZSTD_DCtx* dctx, const void* dst, size_t dstSize) + } + + +-size_t ZSTD_decompressBlock(ZSTD_DCtx* dctx, +- void* dst, size_t dstCapacity, +- const void* src, size_t srcSize) ++size_t ZSTD_decompressBlock_deprecated(ZSTD_DCtx* dctx, ++ void* dst, size_t dstCapacity, ++ const void* src, size_t srcSize) + { + size_t dSize; ++ dctx->isFrameDecompression = 0; + ZSTD_checkContinuity(dctx, dst, dstCapacity); +- dSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, /* frame */ 0, not_streaming); ++ dSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, not_streaming); ++ FORWARD_IF_ERROR(dSize, ""); + dctx->previousDstEnd = (char*)dst + dSize; + return dSize; + } ++ ++ ++/* NOTE: Must just wrap ZSTD_decompressBlock_deprecated() */ ++size_t ZSTD_decompressBlock(ZSTD_DCtx* dctx, ++ void* dst, size_t dstCapacity, ++ const void* src, size_t srcSize) ++{ ++ return ZSTD_decompressBlock_deprecated(dctx, dst, dstCapacity, src, srcSize); ++} +diff --git a/lib/zstd/decompress/zstd_decompress_block.h b/lib/zstd/decompress/zstd_decompress_block.h +index 3d2d57a5d25a..becffbd89364 100644 +--- a/lib/zstd/decompress/zstd_decompress_block.h ++++ b/lib/zstd/decompress/zstd_decompress_block.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -47,7 +48,7 @@ typedef enum { + */ + size_t ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, + void* dst, size_t dstCapacity, +- const void* src, size_t srcSize, const int frame, const streaming_operation streaming); ++ const void* src, size_t srcSize, const streaming_operation streaming); + + /* ZSTD_buildFSETable() : + * generate FSE decoding table for one symbol (ll, ml or off) +@@ -64,5 +65,10 @@ void ZSTD_buildFSETable(ZSTD_seqSymbol* dt, + unsigned tableLog, void* wksp, size_t wkspSize, + int bmi2); + ++/* Internal definition of ZSTD_decompressBlock() to avoid deprecation warnings. */ ++size_t ZSTD_decompressBlock_deprecated(ZSTD_DCtx* dctx, ++ void* dst, size_t dstCapacity, ++ const void* src, size_t srcSize); ++ + + #endif /* ZSTD_DEC_BLOCK_H */ +diff --git a/lib/zstd/decompress/zstd_decompress_internal.h b/lib/zstd/decompress/zstd_decompress_internal.h +index 98102edb6a83..0f02526be774 100644 +--- a/lib/zstd/decompress/zstd_decompress_internal.h ++++ b/lib/zstd/decompress/zstd_decompress_internal.h +@@ -1,5 +1,6 @@ ++/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Yann Collet, Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +@@ -75,12 +76,13 @@ static UNUSED_ATTR const U32 ML_base[MaxML+1] = { + + #define ZSTD_BUILD_FSE_TABLE_WKSP_SIZE (sizeof(S16) * (MaxSeq + 1) + (1u << MaxFSELog) + sizeof(U64)) + #define ZSTD_BUILD_FSE_TABLE_WKSP_SIZE_U32 ((ZSTD_BUILD_FSE_TABLE_WKSP_SIZE + sizeof(U32) - 1) / sizeof(U32)) ++#define ZSTD_HUFFDTABLE_CAPACITY_LOG 12 + + typedef struct { + ZSTD_seqSymbol LLTable[SEQSYMBOL_TABLE_SIZE(LLFSELog)]; /* Note : Space reserved for FSE Tables */ + ZSTD_seqSymbol OFTable[SEQSYMBOL_TABLE_SIZE(OffFSELog)]; /* is also used as temporary workspace while building hufTable during DDict creation */ + ZSTD_seqSymbol MLTable[SEQSYMBOL_TABLE_SIZE(MLFSELog)]; /* and therefore must be at least HUF_DECOMPRESS_WORKSPACE_SIZE large */ +- HUF_DTable hufTable[HUF_DTABLE_SIZE(HufLog)]; /* can accommodate HUF_decompress4X */ ++ HUF_DTable hufTable[HUF_DTABLE_SIZE(ZSTD_HUFFDTABLE_CAPACITY_LOG)]; /* can accommodate HUF_decompress4X */ + U32 rep[ZSTD_REP_NUM]; + U32 workspace[ZSTD_BUILD_FSE_TABLE_WKSP_SIZE_U32]; + } ZSTD_entropyDTables_t; +@@ -152,6 +154,7 @@ struct ZSTD_DCtx_s + size_t litSize; + size_t rleSize; + size_t staticSize; ++ int isFrameDecompression; + #if DYNAMIC_BMI2 != 0 + int bmi2; /* == 1 if the CPU supports BMI2 and 0 otherwise. CPU support is determined dynamically once per context lifetime. */ + #endif +@@ -164,6 +167,8 @@ struct ZSTD_DCtx_s + ZSTD_dictUses_e dictUses; + ZSTD_DDictHashSet* ddictSet; /* Hash set for multiple ddicts */ + ZSTD_refMultipleDDicts_e refMultipleDDicts; /* User specified: if == 1, will allow references to multiple DDicts. 
Default == 0 (disabled) */ ++ int disableHufAsm; ++ int maxBlockSizeParam; + + /* streaming */ + ZSTD_dStreamStage streamStage; +diff --git a/lib/zstd/decompress_sources.h b/lib/zstd/decompress_sources.h +index a06ca187aab5..8a47eb2a4514 100644 +--- a/lib/zstd/decompress_sources.h ++++ b/lib/zstd/decompress_sources.h +@@ -1,6 +1,6 @@ + /* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ + /* +- * Copyright (c) Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +diff --git a/lib/zstd/zstd_common_module.c b/lib/zstd/zstd_common_module.c +index 22686e367e6f..466828e35752 100644 +--- a/lib/zstd/zstd_common_module.c ++++ b/lib/zstd/zstd_common_module.c +@@ -1,6 +1,6 @@ + // SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -24,9 +24,6 @@ EXPORT_SYMBOL_GPL(HUF_readStats_wksp); + EXPORT_SYMBOL_GPL(ZSTD_isError); + EXPORT_SYMBOL_GPL(ZSTD_getErrorName); + EXPORT_SYMBOL_GPL(ZSTD_getErrorCode); +-EXPORT_SYMBOL_GPL(ZSTD_customMalloc); +-EXPORT_SYMBOL_GPL(ZSTD_customCalloc); +-EXPORT_SYMBOL_GPL(ZSTD_customFree); + + MODULE_LICENSE("Dual BSD/GPL"); + MODULE_DESCRIPTION("Zstd Common"); +diff --git a/lib/zstd/zstd_compress_module.c b/lib/zstd/zstd_compress_module.c +index 04e1b5c01d9b..8ecf43226af2 100644 +--- a/lib/zstd/zstd_compress_module.c ++++ b/lib/zstd/zstd_compress_module.c +@@ -1,6 +1,6 @@ + // SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under both the BSD-style license (found in the +diff --git a/lib/zstd/zstd_decompress_module.c b/lib/zstd/zstd_decompress_module.c +index f4ed952ed485..7d31518e9d5a 100644 +--- a/lib/zstd/zstd_decompress_module.c ++++ b/lib/zstd/zstd_decompress_module.c +@@ -1,6 +1,6 @@ + // SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause + /* +- * Copyright (c) Facebook, Inc. ++ * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the +@@ -77,7 +77,7 @@ EXPORT_SYMBOL(zstd_init_dstream); + + size_t zstd_reset_dstream(zstd_dstream *dstream) + { +- return ZSTD_resetDStream(dstream); ++ return ZSTD_DCtx_reset(dstream, ZSTD_reset_session_only); + } + EXPORT_SYMBOL(zstd_reset_dstream); + +-- +2.46.0 diff --git a/patches/0002-sched-ext.patch b/patches/0002-sched-ext.patch new file mode 100644 index 0000000..77a6871 --- /dev/null +++ b/patches/0002-sched-ext.patch @@ -0,0 +1,15304 @@ +From a357af32f8cf89c5e5c51afd21ae57011ac02f19 Mon Sep 17 00:00:00 2001 +From: Peter Jung +Date: Sat, 3 Aug 2024 22:05:51 +0200 +Subject: [PATCH] sched-ext + +Signed-off-by: Peter Jung +--- + Documentation/scheduler/index.rst | 1 + + Documentation/scheduler/sched-ext.rst | 316 + + MAINTAINERS | 13 + + drivers/tty/sysrq.c | 1 + + include/asm-generic/vmlinux.lds.h | 1 + + include/linux/cgroup.h | 4 +- + include/linux/sched.h | 5 + + include/linux/sched/ext.h | 204 + + include/linux/sched/task.h | 3 +- + include/trace/events/sched_ext.h | 32 + + include/uapi/linux/sched.h | 1 + + init/init_task.c | 12 + + kernel/Kconfig.preempt | 26 +- + kernel/fork.c | 17 +- + kernel/sched/build_policy.c | 10 + + kernel/sched/core.c | 231 +- + kernel/sched/cpufreq_schedutil.c | 50 +- + kernel/sched/debug.c | 3 + + kernel/sched/ext.c | 6532 +++++++++++++++++ + kernel/sched/ext.h | 69 + + kernel/sched/fair.c | 22 +- + kernel/sched/idle.c | 2 + + kernel/sched/sched.h | 171 +- + 
lib/dump_stack.c | 1 + + tools/Makefile | 10 +- + tools/sched_ext/.gitignore | 2 + + tools/sched_ext/Makefile | 246 + + tools/sched_ext/README.md | 258 + + .../sched_ext/include/bpf-compat/gnu/stubs.h | 11 + + tools/sched_ext/include/scx/common.bpf.h | 401 + + tools/sched_ext/include/scx/common.h | 75 + + tools/sched_ext/include/scx/compat.bpf.h | 28 + + tools/sched_ext/include/scx/compat.h | 187 + + tools/sched_ext/include/scx/user_exit_info.h | 111 + + tools/sched_ext/scx_central.bpf.c | 361 + + tools/sched_ext/scx_central.c | 135 + + tools/sched_ext/scx_qmap.bpf.c | 706 ++ + tools/sched_ext/scx_qmap.c | 144 + + tools/sched_ext/scx_show_state.py | 39 + + tools/sched_ext/scx_simple.bpf.c | 156 + + tools/sched_ext/scx_simple.c | 107 + + tools/testing/selftests/sched_ext/.gitignore | 6 + + tools/testing/selftests/sched_ext/Makefile | 218 + + tools/testing/selftests/sched_ext/config | 9 + + .../selftests/sched_ext/create_dsq.bpf.c | 58 + + .../testing/selftests/sched_ext/create_dsq.c | 57 + + .../sched_ext/ddsp_bogus_dsq_fail.bpf.c | 42 + + .../selftests/sched_ext/ddsp_bogus_dsq_fail.c | 57 + + .../sched_ext/ddsp_vtimelocal_fail.bpf.c | 39 + + .../sched_ext/ddsp_vtimelocal_fail.c | 56 + + .../selftests/sched_ext/dsp_local_on.bpf.c | 65 + + .../selftests/sched_ext/dsp_local_on.c | 58 + + .../sched_ext/enq_last_no_enq_fails.bpf.c | 21 + + .../sched_ext/enq_last_no_enq_fails.c | 60 + + .../sched_ext/enq_select_cpu_fails.bpf.c | 43 + + .../sched_ext/enq_select_cpu_fails.c | 61 + + tools/testing/selftests/sched_ext/exit.bpf.c | 84 + + tools/testing/selftests/sched_ext/exit.c | 55 + + tools/testing/selftests/sched_ext/exit_test.h | 20 + + .../testing/selftests/sched_ext/hotplug.bpf.c | 61 + + tools/testing/selftests/sched_ext/hotplug.c | 168 + + .../selftests/sched_ext/hotplug_test.h | 15 + + .../sched_ext/init_enable_count.bpf.c | 53 + + .../selftests/sched_ext/init_enable_count.c | 166 + + .../testing/selftests/sched_ext/maximal.bpf.c | 132 + + 
tools/testing/selftests/sched_ext/maximal.c | 51 + + .../selftests/sched_ext/maybe_null.bpf.c | 36 + + .../testing/selftests/sched_ext/maybe_null.c | 49 + + .../sched_ext/maybe_null_fail_dsp.bpf.c | 25 + + .../sched_ext/maybe_null_fail_yld.bpf.c | 28 + + .../testing/selftests/sched_ext/minimal.bpf.c | 21 + + tools/testing/selftests/sched_ext/minimal.c | 58 + + .../selftests/sched_ext/prog_run.bpf.c | 33 + + tools/testing/selftests/sched_ext/prog_run.c | 78 + + .../testing/selftests/sched_ext/reload_loop.c | 75 + + tools/testing/selftests/sched_ext/runner.c | 201 + + tools/testing/selftests/sched_ext/scx_test.h | 131 + + .../selftests/sched_ext/select_cpu_dfl.bpf.c | 40 + + .../selftests/sched_ext/select_cpu_dfl.c | 72 + + .../sched_ext/select_cpu_dfl_nodispatch.bpf.c | 89 + + .../sched_ext/select_cpu_dfl_nodispatch.c | 72 + + .../sched_ext/select_cpu_dispatch.bpf.c | 41 + + .../selftests/sched_ext/select_cpu_dispatch.c | 70 + + .../select_cpu_dispatch_bad_dsq.bpf.c | 37 + + .../sched_ext/select_cpu_dispatch_bad_dsq.c | 56 + + .../select_cpu_dispatch_dbl_dsp.bpf.c | 38 + + .../sched_ext/select_cpu_dispatch_dbl_dsp.c | 56 + + .../sched_ext/select_cpu_vtime.bpf.c | 92 + + .../selftests/sched_ext/select_cpu_vtime.c | 59 + + .../selftests/sched_ext/test_example.c | 49 + + tools/testing/selftests/sched_ext/util.c | 71 + + tools/testing/selftests/sched_ext/util.h | 13 + + 92 files changed, 13845 insertions(+), 104 deletions(-) + create mode 100644 Documentation/scheduler/sched-ext.rst + create mode 100644 include/linux/sched/ext.h + create mode 100644 include/trace/events/sched_ext.h + create mode 100644 kernel/sched/ext.c + create mode 100644 kernel/sched/ext.h + create mode 100644 tools/sched_ext/.gitignore + create mode 100644 tools/sched_ext/Makefile + create mode 100644 tools/sched_ext/README.md + create mode 100644 tools/sched_ext/include/bpf-compat/gnu/stubs.h + create mode 100644 tools/sched_ext/include/scx/common.bpf.h + create mode 100644 
tools/sched_ext/include/scx/common.h + create mode 100644 tools/sched_ext/include/scx/compat.bpf.h + create mode 100644 tools/sched_ext/include/scx/compat.h + create mode 100644 tools/sched_ext/include/scx/user_exit_info.h + create mode 100644 tools/sched_ext/scx_central.bpf.c + create mode 100644 tools/sched_ext/scx_central.c + create mode 100644 tools/sched_ext/scx_qmap.bpf.c + create mode 100644 tools/sched_ext/scx_qmap.c + create mode 100644 tools/sched_ext/scx_show_state.py + create mode 100644 tools/sched_ext/scx_simple.bpf.c + create mode 100644 tools/sched_ext/scx_simple.c + create mode 100644 tools/testing/selftests/sched_ext/.gitignore + create mode 100644 tools/testing/selftests/sched_ext/Makefile + create mode 100644 tools/testing/selftests/sched_ext/config + create mode 100644 tools/testing/selftests/sched_ext/create_dsq.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/create_dsq.c + create mode 100644 tools/testing/selftests/sched_ext/ddsp_bogus_dsq_fail.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/ddsp_bogus_dsq_fail.c + create mode 100644 tools/testing/selftests/sched_ext/ddsp_vtimelocal_fail.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/ddsp_vtimelocal_fail.c + create mode 100644 tools/testing/selftests/sched_ext/dsp_local_on.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/dsp_local_on.c + create mode 100644 tools/testing/selftests/sched_ext/enq_last_no_enq_fails.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/enq_last_no_enq_fails.c + create mode 100644 tools/testing/selftests/sched_ext/enq_select_cpu_fails.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/enq_select_cpu_fails.c + create mode 100644 tools/testing/selftests/sched_ext/exit.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/exit.c + create mode 100644 tools/testing/selftests/sched_ext/exit_test.h + create mode 100644 tools/testing/selftests/sched_ext/hotplug.bpf.c + create mode 100644 
tools/testing/selftests/sched_ext/hotplug.c + create mode 100644 tools/testing/selftests/sched_ext/hotplug_test.h + create mode 100644 tools/testing/selftests/sched_ext/init_enable_count.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/init_enable_count.c + create mode 100644 tools/testing/selftests/sched_ext/maximal.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/maximal.c + create mode 100644 tools/testing/selftests/sched_ext/maybe_null.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/maybe_null.c + create mode 100644 tools/testing/selftests/sched_ext/maybe_null_fail_dsp.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/maybe_null_fail_yld.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/minimal.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/minimal.c + create mode 100644 tools/testing/selftests/sched_ext/prog_run.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/prog_run.c + create mode 100644 tools/testing/selftests/sched_ext/reload_loop.c + create mode 100644 tools/testing/selftests/sched_ext/runner.c + create mode 100644 tools/testing/selftests/sched_ext/scx_test.h + create mode 100644 tools/testing/selftests/sched_ext/select_cpu_dfl.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/select_cpu_dfl.c + create mode 100644 tools/testing/selftests/sched_ext/select_cpu_dfl_nodispatch.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/select_cpu_dfl_nodispatch.c + create mode 100644 tools/testing/selftests/sched_ext/select_cpu_dispatch.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/select_cpu_dispatch.c + create mode 100644 tools/testing/selftests/sched_ext/select_cpu_dispatch_bad_dsq.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/select_cpu_dispatch_bad_dsq.c + create mode 100644 tools/testing/selftests/sched_ext/select_cpu_dispatch_dbl_dsp.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/select_cpu_dispatch_dbl_dsp.c + 
create mode 100644 tools/testing/selftests/sched_ext/select_cpu_vtime.bpf.c + create mode 100644 tools/testing/selftests/sched_ext/select_cpu_vtime.c + create mode 100644 tools/testing/selftests/sched_ext/test_example.c + create mode 100644 tools/testing/selftests/sched_ext/util.c + create mode 100644 tools/testing/selftests/sched_ext/util.h + +diff --git a/Documentation/scheduler/index.rst b/Documentation/scheduler/index.rst +index 43bd8a145b7a..0611dc3dda8e 100644 +--- a/Documentation/scheduler/index.rst ++++ b/Documentation/scheduler/index.rst +@@ -20,6 +20,7 @@ Scheduler + sched-nice-design + sched-rt-group + sched-stats ++ sched-ext + sched-debug + + text_files +diff --git a/Documentation/scheduler/sched-ext.rst b/Documentation/scheduler/sched-ext.rst +new file mode 100644 +index 000000000000..a707d2181a77 +--- /dev/null ++++ b/Documentation/scheduler/sched-ext.rst +@@ -0,0 +1,316 @@ ++========================== ++Extensible Scheduler Class ++========================== ++ ++sched_ext is a scheduler class whose behavior can be defined by a set of BPF ++programs - the BPF scheduler. ++ ++* sched_ext exports a full scheduling interface so that any scheduling ++ algorithm can be implemented on top. ++ ++* The BPF scheduler can group CPUs however it sees fit and schedule them ++ together, as tasks aren't tied to specific CPUs at the time of wakeup. ++ ++* The BPF scheduler can be turned on and off dynamically anytime. ++ ++* The system integrity is maintained no matter what the BPF scheduler does. ++ The default scheduling behavior is restored anytime an error is detected, ++ a runnable task stalls, or on invoking the SysRq key sequence ++ :kbd:`SysRq-S`. ++ ++* When the BPF scheduler triggers an error, debug information is dumped to ++ aid debugging. The debug dump is passed to and printed out by the ++ scheduler binary. The debug dump can also be accessed through the ++ `sched_ext_dump` tracepoint. The SysRq key sequence :kbd:`SysRq-D` ++ triggers a debug dump. 
This doesn't terminate the BPF scheduler and can ++ only be read through the tracepoint. ++ ++Switching to and from sched_ext ++=============================== ++ ++``CONFIG_SCHED_CLASS_EXT`` is the config option to enable sched_ext and ++``tools/sched_ext`` contains the example schedulers. The following config ++options should be enabled to use sched_ext: ++ ++.. code-block:: none ++ ++ CONFIG_BPF=y ++ CONFIG_SCHED_CLASS_EXT=y ++ CONFIG_BPF_SYSCALL=y ++ CONFIG_BPF_JIT=y ++ CONFIG_DEBUG_INFO_BTF=y ++ CONFIG_BPF_JIT_ALWAYS_ON=y ++ CONFIG_BPF_JIT_DEFAULT_ON=y ++ CONFIG_PAHOLE_HAS_SPLIT_BTF=y ++ CONFIG_PAHOLE_HAS_BTF_TAG=y ++ ++sched_ext is used only when the BPF scheduler is loaded and running. ++ ++If a task explicitly sets its scheduling policy to ``SCHED_EXT``, it will be ++treated as ``SCHED_NORMAL`` and scheduled by CFS until the BPF scheduler is ++loaded. ++ ++When the BPF scheduler is loaded and ``SCX_OPS_SWITCH_PARTIAL`` is not set ++in ``ops->flags``, all ``SCHED_NORMAL``, ``SCHED_BATCH``, ``SCHED_IDLE``, and ++``SCHED_EXT`` tasks are scheduled by sched_ext. ++ ++However, when the BPF scheduler is loaded and ``SCX_OPS_SWITCH_PARTIAL`` is ++set in ``ops->flags``, only tasks with the ``SCHED_EXT`` policy are scheduled ++by sched_ext, while tasks with ``SCHED_NORMAL``, ``SCHED_BATCH`` and ++``SCHED_IDLE`` policies are scheduled by CFS. ++ ++Terminating the sched_ext scheduler program, triggering :kbd:`SysRq-S`, or ++detection of any internal error including stalled runnable tasks aborts the ++BPF scheduler and reverts all tasks back to CFS. ++ ++.. code-block:: none ++ ++ # make -j16 -C tools/sched_ext ++ # tools/sched_ext/scx_simple ++ local=0 global=3 ++ local=5 global=24 ++ local=9 global=44 ++ local=13 global=56 ++ local=17 global=72 ++ ^CEXIT: BPF scheduler unregistered ++ ++The current status of the BPF scheduler can be determined as follows: ++ ++.. 
code-block:: none ++ ++ # cat /sys/kernel/sched_ext/state ++ enabled ++ # cat /sys/kernel/sched_ext/root/ops ++ simple ++ ++``tools/sched_ext/scx_show_state.py`` is a drgn script which shows more ++detailed information: ++ ++.. code-block:: none ++ ++ # tools/sched_ext/scx_show_state.py ++ ops : simple ++ enabled : 1 ++ switching_all : 1 ++ switched_all : 1 ++ enable_state : enabled (2) ++ bypass_depth : 0 ++ nr_rejected : 0 ++ ++If ``CONFIG_SCHED_DEBUG`` is set, whether a given task is on sched_ext can ++be determined as follows: ++ ++.. code-block:: none ++ ++ # grep ext /proc/self/sched ++ ext.enabled : 1 ++ ++The Basics ++========== ++ ++Userspace can implement an arbitrary BPF scheduler by loading a set of BPF ++programs that implement ``struct sched_ext_ops``. The only mandatory field ++is ``ops.name`` which must be a valid BPF object name. All operations are ++optional. The following modified excerpt is from ++``tools/sched_ext/scx_simple.bpf.c`` showing a minimal global FIFO scheduler. ++ ++.. code-block:: c ++ ++ /* ++ * Decide which CPU a task should be migrated to before being ++ * enqueued (either at wakeup, fork time, or exec time). If an ++ * idle core is found by the default ops.select_cpu() implementation, ++ * then dispatch the task directly to SCX_DSQ_LOCAL and skip the ++ * ops.enqueue() callback. ++ * ++ * Note that this implementation has exactly the same behavior as the ++ * default ops.select_cpu implementation. The behavior of the scheduler ++ * would be exactly same if the implementation just didn't define the ++ * simple_select_cpu() struct_ops prog. 
++ */ ++ s32 BPF_STRUCT_OPS(simple_select_cpu, struct task_struct *p, ++ s32 prev_cpu, u64 wake_flags) ++ { ++ s32 cpu; ++ /* Need to initialize or the BPF verifier will reject the program */ ++ bool direct = false; ++ ++ cpu = scx_bpf_select_cpu_dfl(p, prev_cpu, wake_flags, &direct); ++ ++ if (direct) ++ scx_bpf_dispatch(p, SCX_DSQ_LOCAL, SCX_SLICE_DFL, 0); ++ ++ return cpu; ++ } ++ ++ /* ++ * Do a direct dispatch of a task to the global DSQ. This ops.enqueue() ++ * callback will only be invoked if we failed to find a core to dispatch ++ * to in ops.select_cpu() above. ++ * ++ * Note that this implementation has exactly the same behavior as the ++ * default ops.enqueue implementation, which just dispatches the task ++ * to SCX_DSQ_GLOBAL. The behavior of the scheduler would be exactly same ++ * if the implementation just didn't define the simple_enqueue struct_ops ++ * prog. ++ */ ++ void BPF_STRUCT_OPS(simple_enqueue, struct task_struct *p, u64 enq_flags) ++ { ++ scx_bpf_dispatch(p, SCX_DSQ_GLOBAL, SCX_SLICE_DFL, enq_flags); ++ } ++ ++ s32 BPF_STRUCT_OPS_SLEEPABLE(simple_init) ++ { ++ /* ++ * By default, all SCHED_EXT, SCHED_OTHER, SCHED_IDLE, and ++ * SCHED_BATCH tasks should use sched_ext. ++ */ ++ return 0; ++ } ++ ++ void BPF_STRUCT_OPS(simple_exit, struct scx_exit_info *ei) ++ { ++ exit_type = ei->type; ++ } ++ ++ SEC(".struct_ops") ++ struct sched_ext_ops simple_ops = { ++ .select_cpu = (void *)simple_select_cpu, ++ .enqueue = (void *)simple_enqueue, ++ .init = (void *)simple_init, ++ .exit = (void *)simple_exit, ++ .name = "simple", ++ }; ++ ++Dispatch Queues ++--------------- ++ ++To match the impedance between the scheduler core and the BPF scheduler, ++sched_ext uses DSQs (dispatch queues) which can operate as both a FIFO and a ++priority queue. By default, there is one global FIFO (``SCX_DSQ_GLOBAL``), ++and one local dsq per CPU (``SCX_DSQ_LOCAL``). 
The BPF scheduler can manage ++an arbitrary number of dsq's using ``scx_bpf_create_dsq()`` and ++``scx_bpf_destroy_dsq()``. ++ ++A CPU always executes a task from its local DSQ. A task is "dispatched" to a ++DSQ. A non-local DSQ is "consumed" to transfer a task to the consuming CPU's ++local DSQ. ++ ++When a CPU is looking for the next task to run, if the local DSQ is not ++empty, the first task is picked. Otherwise, the CPU tries to consume the ++global DSQ. If that doesn't yield a runnable task either, ``ops.dispatch()`` ++is invoked. ++ ++Scheduling Cycle ++---------------- ++ ++The following briefly shows how a waking task is scheduled and executed. ++ ++1. When a task is waking up, ``ops.select_cpu()`` is the first operation ++ invoked. This serves two purposes. First, CPU selection optimization ++ hint. Second, waking up the selected CPU if idle. ++ ++ The CPU selected by ``ops.select_cpu()`` is an optimization hint and not ++ binding. The actual decision is made at the last step of scheduling. ++ However, there is a small performance gain if the CPU ++ ``ops.select_cpu()`` returns matches the CPU the task eventually runs on. ++ ++ A side-effect of selecting a CPU is waking it up from idle. While a BPF ++ scheduler can wake up any cpu using the ``scx_bpf_kick_cpu()`` helper, ++ using ``ops.select_cpu()`` judiciously can be simpler and more efficient. ++ ++ A task can be immediately dispatched to a DSQ from ``ops.select_cpu()`` by ++ calling ``scx_bpf_dispatch()``. If the task is dispatched to ++ ``SCX_DSQ_LOCAL`` from ``ops.select_cpu()``, it will be dispatched to the ++ local DSQ of whichever CPU is returned from ``ops.select_cpu()``. ++ Additionally, dispatching directly from ``ops.select_cpu()`` will cause the ++ ``ops.enqueue()`` callback to be skipped. ++ ++ Note that the scheduler core will ignore an invalid CPU selection, for ++ example, if it's outside the allowed cpumask of the task. ++ ++2. 
Once the target CPU is selected, ``ops.enqueue()`` is invoked (unless the ++ task was dispatched directly from ``ops.select_cpu()``). ``ops.enqueue()`` ++ can make one of the following decisions: ++ ++ * Immediately dispatch the task to either the global or local DSQ by ++ calling ``scx_bpf_dispatch()`` with ``SCX_DSQ_GLOBAL`` or ++ ``SCX_DSQ_LOCAL``, respectively. ++ ++ * Immediately dispatch the task to a custom DSQ by calling ++ ``scx_bpf_dispatch()`` with a DSQ ID which is smaller than 2^63. ++ ++ * Queue the task on the BPF side. ++ ++3. When a CPU is ready to schedule, it first looks at its local DSQ. If ++ empty, it then looks at the global DSQ. If there still isn't a task to ++ run, ``ops.dispatch()`` is invoked which can use the following two ++ functions to populate the local DSQ. ++ ++ * ``scx_bpf_dispatch()`` dispatches a task to a DSQ. Any target DSQ can ++ be used - ``SCX_DSQ_LOCAL``, ``SCX_DSQ_LOCAL_ON | cpu``, ++ ``SCX_DSQ_GLOBAL`` or a custom DSQ. While ``scx_bpf_dispatch()`` ++ currently can't be called with BPF locks held, this is being worked on ++ and will be supported. ``scx_bpf_dispatch()`` schedules dispatching ++ rather than performing them immediately. There can be up to ++ ``ops.dispatch_max_batch`` pending tasks. ++ ++ * ``scx_bpf_consume()`` transfers a task from the specified non-local DSQ ++ to the dispatching DSQ. This function cannot be called with any BPF ++ locks held. ``scx_bpf_consume()`` flushes the pending dispatched tasks ++ before trying to consume the specified DSQ. ++ ++4. After ``ops.dispatch()`` returns, if there are tasks in the local DSQ, ++ the CPU runs the first one. If empty, the following steps are taken: ++ ++ * Try to consume the global DSQ. If successful, run the task. ++ ++ * If ``ops.dispatch()`` has dispatched any tasks, retry #3. ++ ++ * If the previous task is an SCX task and still runnable, keep executing ++ it (see ``SCX_OPS_ENQ_LAST``). ++ ++ * Go idle.
++ ++Note that the BPF scheduler can always choose to dispatch tasks immediately ++in ``ops.enqueue()`` as illustrated in the above simple example. If only the ++built-in DSQs are used, there is no need to implement ``ops.dispatch()`` as ++a task is never queued on the BPF scheduler and both the local and global ++DSQs are consumed automatically. ++ ++``scx_bpf_dispatch()`` queues the task on the FIFO of the target DSQ. Use ++``scx_bpf_dispatch_vtime()`` for the priority queue. Internal DSQs such as ++``SCX_DSQ_LOCAL`` and ``SCX_DSQ_GLOBAL`` do not support priority-queue ++dispatching, and must be dispatched to with ``scx_bpf_dispatch()``. See the ++function documentation and usage in ``tools/sched_ext/scx_simple.bpf.c`` for ++more information. ++ ++Where to Look ++============= ++ ++* ``include/linux/sched/ext.h`` defines the core data structures, ops table ++ and constants. ++ ++* ``kernel/sched/ext.c`` contains sched_ext core implementation and helpers. ++ The functions prefixed with ``scx_bpf_`` can be called from the BPF ++ scheduler. ++ ++* ``tools/sched_ext/`` hosts example BPF scheduler implementations. ++ ++ * ``scx_simple[.bpf].c``: Minimal global FIFO scheduler example using a ++ custom DSQ. ++ ++ * ``scx_qmap[.bpf].c``: A multi-level FIFO scheduler supporting five ++ levels of priority implemented with ``BPF_MAP_TYPE_QUEUE``. ++ ++ABI Instability ++=============== ++ ++The APIs provided by sched_ext to BPF scheduler programs have no stability ++guarantees. This includes the ops table callbacks and constants defined in ++``include/linux/sched/ext.h``, as well as the ``scx_bpf_`` kfuncs defined in ++``kernel/sched/ext.c``. ++ ++While we will attempt to provide a relatively stable API surface when ++possible, they are subject to change without warning between kernel ++versions.
+diff --git a/MAINTAINERS b/MAINTAINERS +index 064156d69e75..78815a193485 100644 +--- a/MAINTAINERS ++++ b/MAINTAINERS +@@ -19945,6 +19945,19 @@ F: include/linux/wait.h + F: include/uapi/linux/sched.h + F: kernel/sched/ + ++SCHEDULER - SCHED_EXT ++R: Tejun Heo ++R: David Vernet ++L: linux-kernel@vger.kernel.org ++S: Maintained ++W: https://github.com/sched-ext/scx ++T: git://git.kernel.org/pub/scm/linux/kernel/git/tj/sched_ext.git ++F: include/linux/sched/ext.h ++F: kernel/sched/ext.h ++F: kernel/sched/ext.c ++F: tools/sched_ext/ ++F: tools/testing/selftests/sched_ext ++ + SCSI LIBSAS SUBSYSTEM + R: John Garry + R: Jason Yan +diff --git a/drivers/tty/sysrq.c b/drivers/tty/sysrq.c +index e5974b8239c9..167e877b8bef 100644 +--- a/drivers/tty/sysrq.c ++++ b/drivers/tty/sysrq.c +@@ -531,6 +531,7 @@ static const struct sysrq_key_op *sysrq_key_table[62] = { + NULL, /* P */ + NULL, /* Q */ + &sysrq_replay_logs_op, /* R */ ++ /* S: May be registered by sched_ext for resetting */ + NULL, /* S */ + NULL, /* T */ + NULL, /* U */ +diff --git a/include/asm-generic/vmlinux.lds.h b/include/asm-generic/vmlinux.lds.h +index 70bf1004076b..a8417d31e348 100644 +--- a/include/asm-generic/vmlinux.lds.h ++++ b/include/asm-generic/vmlinux.lds.h +@@ -133,6 +133,7 @@ + *(__dl_sched_class) \ + *(__rt_sched_class) \ + *(__fair_sched_class) \ ++ *(__ext_sched_class) \ + *(__idle_sched_class) \ + __sched_class_lowest = .; + +diff --git a/include/linux/cgroup.h b/include/linux/cgroup.h +index 2150ca60394b..3cdaec701600 100644 +--- a/include/linux/cgroup.h ++++ b/include/linux/cgroup.h +@@ -29,8 +29,6 @@ + + struct kernel_clone_args; + +-#ifdef CONFIG_CGROUPS +- + /* + * All weight knobs on the default hierarchy should use the following min, + * default and max values. 
The default value is the logarithmic center of +@@ -40,6 +38,8 @@ struct kernel_clone_args; + #define CGROUP_WEIGHT_DFL 100 + #define CGROUP_WEIGHT_MAX 10000 + ++#ifdef CONFIG_CGROUPS ++ + enum { + CSS_TASK_ITER_PROCS = (1U << 0), /* walk only threadgroup leaders */ + CSS_TASK_ITER_THREADED = (1U << 1), /* walk all threaded css_sets in the domain */ +diff --git a/include/linux/sched.h b/include/linux/sched.h +index 76214d7c819d..0f3a107bcd02 100644 +--- a/include/linux/sched.h ++++ b/include/linux/sched.h +@@ -80,6 +80,8 @@ struct task_group; + struct task_struct; + struct user_event_mm; + ++#include ++ + /* + * Task state bitmask. NOTE! These bits are also + * encoded in fs/proc/array.c: get_task_state(). +@@ -802,6 +804,9 @@ struct task_struct { + struct sched_rt_entity rt; + struct sched_dl_entity dl; + struct sched_dl_entity *dl_server; ++#ifdef CONFIG_SCHED_CLASS_EXT ++ struct sched_ext_entity scx; ++#endif + const struct sched_class *sched_class; + + #ifdef CONFIG_SCHED_CORE +diff --git a/include/linux/sched/ext.h b/include/linux/sched/ext.h +new file mode 100644 +index 000000000000..26e1c33bc844 +--- /dev/null ++++ b/include/linux/sched/ext.h +@@ -0,0 +1,204 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * BPF extensible scheduler class: Documentation/scheduler/sched-ext.rst ++ * ++ * Copyright (c) 2022 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2022 Tejun Heo ++ * Copyright (c) 2022 David Vernet ++ */ ++#ifndef _LINUX_SCHED_EXT_H ++#define _LINUX_SCHED_EXT_H ++ ++#ifdef CONFIG_SCHED_CLASS_EXT ++ ++#include ++#include ++ ++enum scx_public_consts { ++ SCX_OPS_NAME_LEN = 128, ++ ++ SCX_SLICE_DFL = 20 * 1000000, /* 20ms */ ++ SCX_SLICE_INF = U64_MAX, /* infinite, implies nohz */ ++}; ++ ++/* ++ * DSQ (dispatch queue) IDs are 64bit of the format: ++ * ++ * Bits: [63] [62 .. 
0] ++ * [ B] [ ID ] ++ * ++ * B: 1 for IDs for built-in DSQs, 0 for ops-created user DSQs ++ * ID: 63 bit ID ++ * ++ * Built-in IDs: ++ * ++ * Bits: [63] [62] [61..32] [31 .. 0] ++ * [ 1] [ L] [ R ] [ V ] ++ * ++ * 1: 1 for built-in DSQs. ++ * L: 1 for LOCAL_ON DSQ IDs, 0 for others ++ * V: For LOCAL_ON DSQ IDs, a CPU number. For others, a pre-defined value. ++ */ ++enum scx_dsq_id_flags { ++ SCX_DSQ_FLAG_BUILTIN = 1LLU << 63, ++ SCX_DSQ_FLAG_LOCAL_ON = 1LLU << 62, ++ ++ SCX_DSQ_INVALID = SCX_DSQ_FLAG_BUILTIN | 0, ++ SCX_DSQ_GLOBAL = SCX_DSQ_FLAG_BUILTIN | 1, ++ SCX_DSQ_LOCAL = SCX_DSQ_FLAG_BUILTIN | 2, ++ SCX_DSQ_LOCAL_ON = SCX_DSQ_FLAG_BUILTIN | SCX_DSQ_FLAG_LOCAL_ON, ++ SCX_DSQ_LOCAL_CPU_MASK = 0xffffffffLLU, ++}; ++ ++/* ++ * A dispatch queue (DSQ) can be either a FIFO or p->scx.dsq_vtime ordered ++ * queue. A built-in DSQ is always a FIFO. The built-in local DSQs are used to ++ * buffer between the scheduler core and the BPF scheduler. See the ++ * documentation for more details. ++ */ ++struct scx_dispatch_q { ++ raw_spinlock_t lock; ++ struct list_head list; /* tasks in dispatch order */ ++ struct rb_root priq; /* used to order by p->scx.dsq_vtime */ ++ u32 nr; ++ u32 seq; /* used by BPF iter */ ++ u64 id; ++ struct rhash_head hash_node; ++ struct llist_node free_node; ++ struct rcu_head rcu; ++}; ++ ++/* scx_entity.flags */ ++enum scx_ent_flags { ++ SCX_TASK_QUEUED = 1 << 0, /* on ext runqueue */ ++ SCX_TASK_BAL_KEEP = 1 << 1, /* balance decided to keep current */ ++ SCX_TASK_RESET_RUNNABLE_AT = 1 << 2, /* runnable_at should be reset */ ++ SCX_TASK_DEQD_FOR_SLEEP = 1 << 3, /* last dequeue was for SLEEP */ ++ ++ SCX_TASK_STATE_SHIFT = 8, /* bit 8 and 9 are used to carry scx_task_state */ ++ SCX_TASK_STATE_BITS = 2, ++ SCX_TASK_STATE_MASK = ((1 << SCX_TASK_STATE_BITS) - 1) << SCX_TASK_STATE_SHIFT, ++ ++ SCX_TASK_CURSOR = 1 << 31, /* iteration cursor, not a task */ ++}; ++ ++/* scx_entity.flags & SCX_TASK_STATE_MASK */ ++enum scx_task_state { ++ SCX_TASK_NONE, 
/* ops.init_task() not called yet */ ++ SCX_TASK_INIT, /* ops.init_task() succeeded, but task can be cancelled */ ++ SCX_TASK_READY, /* fully initialized, but not in sched_ext */ ++ SCX_TASK_ENABLED, /* fully initialized and in sched_ext */ ++ ++ SCX_TASK_NR_STATES, ++}; ++ ++/* scx_entity.dsq_flags */ ++enum scx_ent_dsq_flags { ++ SCX_TASK_DSQ_ON_PRIQ = 1 << 0, /* task is queued on the priority queue of a dsq */ ++}; ++ ++/* ++ * Mask bits for scx_entity.kf_mask. Not all kfuncs can be called from ++ * everywhere and the following bits track which kfunc sets are currently ++ * allowed for %current. This simple per-task tracking works because SCX ops ++ * nest in a limited way. BPF will likely implement a way to allow and disallow ++ * kfuncs depending on the calling context which will replace this manual ++ * mechanism. See scx_kf_allow(). ++ */ ++enum scx_kf_mask { ++ SCX_KF_UNLOCKED = 0, /* sleepable and not rq locked */ ++ /* ENQUEUE and DISPATCH may be nested inside CPU_RELEASE */ ++ SCX_KF_CPU_RELEASE = 1 << 0, /* ops.cpu_release() */ ++ /* ops.dequeue (in REST) may be nested inside DISPATCH */ ++ SCX_KF_DISPATCH = 1 << 1, /* ops.dispatch() */ ++ SCX_KF_ENQUEUE = 1 << 2, /* ops.enqueue() and ops.select_cpu() */ ++ SCX_KF_SELECT_CPU = 1 << 3, /* ops.select_cpu() */ ++ SCX_KF_REST = 1 << 4, /* other rq-locked operations */ ++ ++ __SCX_KF_RQ_LOCKED = SCX_KF_CPU_RELEASE | SCX_KF_DISPATCH | ++ SCX_KF_ENQUEUE | SCX_KF_SELECT_CPU | SCX_KF_REST, ++ __SCX_KF_TERMINAL = SCX_KF_ENQUEUE | SCX_KF_SELECT_CPU | SCX_KF_REST, ++}; ++ ++struct scx_dsq_list_node { ++ struct list_head node; ++ bool is_bpf_iter_cursor; ++}; ++ ++/* ++ * The following is embedded in task_struct and contains all fields necessary ++ * for a task to be scheduled by SCX. 
++ */ ++struct sched_ext_entity { ++ struct scx_dispatch_q *dsq; ++ struct scx_dsq_list_node dsq_list; /* dispatch order */ ++ struct rb_node dsq_priq; /* p->scx.dsq_vtime order */ ++ u32 dsq_seq; ++ u32 dsq_flags; /* protected by DSQ lock */ ++ u32 flags; /* protected by rq lock */ ++ u32 weight; ++ s32 sticky_cpu; ++ s32 holding_cpu; ++ u32 kf_mask; /* see scx_kf_mask above */ ++ struct task_struct *kf_tasks[2]; /* see SCX_CALL_OP_TASK() */ ++ atomic_long_t ops_state; ++ ++ struct list_head runnable_node; /* rq->scx.runnable_list */ ++ unsigned long runnable_at; ++ ++#ifdef CONFIG_SCHED_CORE ++ u64 core_sched_at; /* see scx_prio_less() */ ++#endif ++ u64 ddsp_dsq_id; ++ u64 ddsp_enq_flags; ++ ++ /* BPF scheduler modifiable fields */ ++ ++ /* ++ * Runtime budget in nsecs. This is usually set through ++ * scx_bpf_dispatch() but can also be modified directly by the BPF ++ * scheduler. Automatically decreased by SCX as the task executes. On ++ * depletion, a scheduling event is triggered. ++ * ++ * This value is cleared to zero if the task is preempted by ++ * %SCX_KICK_PREEMPT and shouldn't be used to determine how long the ++ * task ran. Use p->se.sum_exec_runtime instead. ++ */ ++ u64 slice; ++ ++ /* ++ * Used to order tasks when dispatching to the vtime-ordered priority ++ * queue of a dsq. This is usually set through scx_bpf_dispatch_vtime() ++ * but can also be modified directly by the BPF scheduler. Modifying it ++ * while a task is queued on a dsq may mangle the ordering and is not ++ * recommended. ++ */ ++ u64 dsq_vtime; ++ ++ /* ++ * If set, reject future sched_setscheduler(2) calls updating the policy ++ * to %SCHED_EXT with -%EACCES. ++ * ++ * If set from ops.init_task() and the task's policy is already ++ * %SCHED_EXT, which can happen while the BPF scheduler is being loaded ++ * or by inheriting the parent's policy during fork, the task's policy is ++ * rejected and forcefully reverted to %SCHED_NORMAL.
The number of ++ * such events are reported through /sys/kernel/debug/sched_ext::nr_rejected. ++ */ ++ bool disallow; /* reject switching into SCX */ ++ ++ /* cold fields */ ++ /* must be the last field, see init_scx_entity() */ ++ struct list_head tasks_node; ++}; ++ ++void sched_ext_free(struct task_struct *p); ++void print_scx_info(const char *log_lvl, struct task_struct *p); ++ ++#else /* !CONFIG_SCHED_CLASS_EXT */ ++ ++static inline void sched_ext_free(struct task_struct *p) {} ++static inline void print_scx_info(const char *log_lvl, struct task_struct *p) {} ++ ++#endif /* CONFIG_SCHED_CLASS_EXT */ ++#endif /* _LINUX_SCHED_EXT_H */ +diff --git a/include/linux/sched/task.h b/include/linux/sched/task.h +index d362aacf9f89..4df2f9055587 100644 +--- a/include/linux/sched/task.h ++++ b/include/linux/sched/task.h +@@ -63,7 +63,8 @@ extern asmlinkage void schedule_tail(struct task_struct *prev); + extern void init_idle(struct task_struct *idle, int cpu); + + extern int sched_fork(unsigned long clone_flags, struct task_struct *p); +-extern void sched_cgroup_fork(struct task_struct *p, struct kernel_clone_args *kargs); ++extern int sched_cgroup_fork(struct task_struct *p, struct kernel_clone_args *kargs); ++extern void sched_cancel_fork(struct task_struct *p); + extern void sched_post_fork(struct task_struct *p); + extern void sched_dead(struct task_struct *p); + +diff --git a/include/trace/events/sched_ext.h b/include/trace/events/sched_ext.h +new file mode 100644 +index 000000000000..fe19da7315a9 +--- /dev/null ++++ b/include/trace/events/sched_ext.h +@@ -0,0 +1,32 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++#undef TRACE_SYSTEM ++#define TRACE_SYSTEM sched_ext ++ ++#if !defined(_TRACE_SCHED_EXT_H) || defined(TRACE_HEADER_MULTI_READ) ++#define _TRACE_SCHED_EXT_H ++ ++#include ++ ++TRACE_EVENT(sched_ext_dump, ++ ++ TP_PROTO(const char *line), ++ ++ TP_ARGS(line), ++ ++ TP_STRUCT__entry( ++ __string(line, line) ++ ), ++ ++ TP_fast_assign( ++ __assign_str(line); ++ ), 
++ ++ TP_printk("%s", ++ __get_str(line) ++ ) ++); ++ ++#endif /* _TRACE_SCHED_EXT_H */ ++ ++/* This part must be outside protection */ ++#include +diff --git a/include/uapi/linux/sched.h b/include/uapi/linux/sched.h +index 3bac0a8ceab2..359a14cc76a4 100644 +--- a/include/uapi/linux/sched.h ++++ b/include/uapi/linux/sched.h +@@ -118,6 +118,7 @@ struct clone_args { + /* SCHED_ISO: reserved but not implemented yet */ + #define SCHED_IDLE 5 + #define SCHED_DEADLINE 6 ++#define SCHED_EXT 7 + + /* Can be ORed in to make sure the process is reverted back to SCHED_NORMAL on fork */ + #define SCHED_RESET_ON_FORK 0x40000000 +diff --git a/init/init_task.c b/init/init_task.c +index eeb110c65fe2..e222722e790b 100644 +--- a/init/init_task.c ++++ b/init/init_task.c +@@ -6,6 +6,7 @@ + #include + #include + #include ++#include + #include + #include + #include +@@ -98,6 +99,17 @@ struct task_struct init_task __aligned(L1_CACHE_BYTES) = { + #endif + #ifdef CONFIG_CGROUP_SCHED + .sched_task_group = &root_task_group, ++#endif ++#ifdef CONFIG_SCHED_CLASS_EXT ++ .scx = { ++ .dsq_list.node = LIST_HEAD_INIT(init_task.scx.dsq_list.node), ++ .sticky_cpu = -1, ++ .holding_cpu = -1, ++ .runnable_node = LIST_HEAD_INIT(init_task.scx.runnable_node), ++ .runnable_at = INITIAL_JIFFIES, ++ .ddsp_dsq_id = SCX_DSQ_INVALID, ++ .slice = SCX_SLICE_DFL, ++ }, + #endif + .ptraced = LIST_HEAD_INIT(init_task.ptraced), + .ptrace_entry = LIST_HEAD_INIT(init_task.ptrace_entry), +diff --git a/kernel/Kconfig.preempt b/kernel/Kconfig.preempt +index c2f1fd95a821..f3d140c3acc1 100644 +--- a/kernel/Kconfig.preempt ++++ b/kernel/Kconfig.preempt +@@ -133,4 +133,28 @@ config SCHED_CORE + which is the likely usage by Linux distributions, there should + be no measurable impact on performance. 
+ +- ++config SCHED_CLASS_EXT ++ bool "Extensible Scheduling Class" ++ depends on BPF_SYSCALL && BPF_JIT && DEBUG_INFO_BTF ++ help ++ This option enables a new scheduler class sched_ext (SCX), which ++ allows scheduling policies to be implemented as BPF programs to ++ achieve the following: ++ ++ - Ease of experimentation and exploration: Enabling rapid ++ iteration of new scheduling policies. ++ - Customization: Building application-specific schedulers which ++ implement policies that are not applicable to general-purpose ++ schedulers. ++ - Rapid scheduler deployments: Non-disruptive swap outs of ++ scheduling policies in production environments. ++ ++ sched_ext leverages BPF struct_ops feature to define a structure ++ which exports function callbacks and flags to BPF programs that ++ wish to implement scheduling policies. The struct_ops structure ++ exported by sched_ext is struct sched_ext_ops, and is conceptually ++ similar to struct sched_class. ++ ++ For more information: ++ Documentation/scheduler/sched-ext.rst ++ https://github.com/sched-ext/scx +diff --git a/kernel/fork.c b/kernel/fork.c +index 18750b83c564..d973d23b3768 100644 +--- a/kernel/fork.c ++++ b/kernel/fork.c +@@ -23,6 +23,7 @@ + #include + #include + #include ++#include + #include + #include + #include +@@ -975,6 +976,7 @@ void __put_task_struct(struct task_struct *tsk) + WARN_ON(refcount_read(&tsk->usage)); + WARN_ON(tsk == current); + ++ sched_ext_free(tsk); + io_uring_free(tsk); + cgroup_free(tsk); + task_numa_free(tsk, true); +@@ -2371,7 +2373,7 @@ __latent_entropy struct task_struct *copy_process( + + retval = perf_event_init_task(p, clone_flags); + if (retval) +- goto bad_fork_cleanup_policy; ++ goto bad_fork_sched_cancel_fork; + retval = audit_alloc(p); + if (retval) + goto bad_fork_cleanup_perf; +@@ -2504,7 +2506,9 @@ __latent_entropy struct task_struct *copy_process( + * cgroup specific, it unconditionally needs to place the task on a + * runqueue. 
+ */ +- sched_cgroup_fork(p, args); ++ retval = sched_cgroup_fork(p, args); ++ if (retval) ++ goto bad_fork_cancel_cgroup; + + /* + * From this point on we must avoid any synchronous user-space +@@ -2550,13 +2554,13 @@ __latent_entropy struct task_struct *copy_process( + /* Don't start children in a dying pid namespace */ + if (unlikely(!(ns_of_pid(pid)->pid_allocated & PIDNS_ADDING))) { + retval = -ENOMEM; +- goto bad_fork_cancel_cgroup; ++ goto bad_fork_core_free; + } + + /* Let kill terminate clone/fork in the middle */ + if (fatal_signal_pending(current)) { + retval = -EINTR; +- goto bad_fork_cancel_cgroup; ++ goto bad_fork_core_free; + } + + /* No more failure paths after this point. */ +@@ -2630,10 +2634,11 @@ __latent_entropy struct task_struct *copy_process( + + return p; + +-bad_fork_cancel_cgroup: ++bad_fork_core_free: + sched_core_free(p); + spin_unlock(&current->sighand->siglock); + write_unlock_irq(&tasklist_lock); ++bad_fork_cancel_cgroup: + cgroup_cancel_fork(p, args); + bad_fork_put_pidfd: + if (clone_flags & CLONE_PIDFD) { +@@ -2672,6 +2677,8 @@ __latent_entropy struct task_struct *copy_process( + audit_free(p); + bad_fork_cleanup_perf: + perf_event_free_task(p); ++bad_fork_sched_cancel_fork: ++ sched_cancel_fork(p); + bad_fork_cleanup_policy: + lockdep_free_task(p); + #ifdef CONFIG_NUMA +diff --git a/kernel/sched/build_policy.c b/kernel/sched/build_policy.c +index d9dc9ab3773f..e7d539bb721e 100644 +--- a/kernel/sched/build_policy.c ++++ b/kernel/sched/build_policy.c +@@ -16,18 +16,25 @@ + #include + #include + #include ++#include + #include + #include + + #include + #include ++#include + #include ++#include + #include ++#include ++#include + #include + #include + #include + #include + #include ++#include ++#include + + #include + +@@ -52,3 +59,6 @@ + #include "cputime.c" + #include "deadline.c" + ++#ifdef CONFIG_SCHED_CLASS_EXT ++# include "ext.c" ++#endif +diff --git a/kernel/sched/core.c b/kernel/sched/core.c +index ebf21373f663..fb6276f74ee6 100644
+--- a/kernel/sched/core.c ++++ b/kernel/sched/core.c +@@ -168,7 +168,10 @@ static inline int __task_prio(const struct task_struct *p) + if (p->sched_class == &idle_sched_class) + return MAX_RT_PRIO + NICE_WIDTH; /* 140 */ + +- return MAX_RT_PRIO + MAX_NICE; /* 120, squash fair */ ++ if (task_on_scx(p)) ++ return MAX_RT_PRIO + MAX_NICE + 1; /* 120, squash ext */ ++ ++ return MAX_RT_PRIO + MAX_NICE; /* 119, squash fair */ + } + + /* +@@ -197,6 +200,11 @@ static inline bool prio_less(const struct task_struct *a, + if (pa == MAX_RT_PRIO + MAX_NICE) /* fair */ + return cfs_prio_less(a, b, in_fi); + ++#ifdef CONFIG_SCHED_CLASS_EXT ++ if (pa == MAX_RT_PRIO + MAX_NICE + 1) /* ext */ ++ return scx_prio_less(a, b, in_fi); ++#endif ++ + return false; + } + +@@ -1254,11 +1262,14 @@ bool sched_can_stop_tick(struct rq *rq) + return true; + + /* +- * If there are no DL,RR/FIFO tasks, there must only be CFS tasks left; +- * if there's more than one we need the tick for involuntary +- * preemption. ++ * If there are no DL,RR/FIFO tasks, there must only be CFS or SCX tasks ++ * left. For CFS, if there's more than one we need the tick for ++ * involuntary preemption. For SCX, ask. + */ +- if (rq->nr_running > 1) ++ if (!scx_switched_all() && rq->nr_running > 1) ++ return false; ++ ++ if (scx_enabled() && !scx_can_stop_tick(rq)) + return false; + + /* +@@ -1340,8 +1351,8 @@ static void set_load_weight(struct task_struct *p, bool update_load) + * SCHED_OTHER tasks have to update their load when changing their + * weight + */ +- if (update_load && p->sched_class == &fair_sched_class) +- reweight_task(p, &lw); ++ if (update_load && p->sched_class->reweight_task) ++ p->sched_class->reweight_task(task_rq(p), p, &lw); + else + p->se.load = lw; + } +@@ -2210,6 +2221,17 @@ inline int task_curr(const struct task_struct *p) + return cpu_curr(task_cpu(p)) == p; + } + ++/* ++ * ->switching_to() is called with the pi_lock and rq_lock held and must not ++ * mess with locking. 
++ */ ++void check_class_changing(struct rq *rq, struct task_struct *p, ++ const struct sched_class *prev_class) ++{ ++ if (prev_class != p->sched_class && p->sched_class->switching_to) ++ p->sched_class->switching_to(rq, p); ++} ++ + /* + * switched_from, switched_to and prio_changed must _NOT_ drop rq->lock, + * use the balance_callback list if you want balancing. +@@ -2217,9 +2239,9 @@ inline int task_curr(const struct task_struct *p) + * this means any call to check_class_changed() must be followed by a call to + * balance_callback(). + */ +-static inline void check_class_changed(struct rq *rq, struct task_struct *p, +- const struct sched_class *prev_class, +- int oldprio) ++void check_class_changed(struct rq *rq, struct task_struct *p, ++ const struct sched_class *prev_class, ++ int oldprio) + { + if (prev_class != p->sched_class) { + if (prev_class->switched_from) +@@ -3982,6 +4004,15 @@ bool cpus_share_resources(int this_cpu, int that_cpu) + + static inline bool ttwu_queue_cond(struct task_struct *p, int cpu) + { ++ /* ++ * The BPF scheduler may depend on select_task_rq() being invoked during ++ * wakeups. In addition, @p may end up executing on a different CPU ++ * regardless of what happens in the wakeup path making the ttwu_queue ++ * optimization less meaningful. Skip if on SCX. ++ */ ++ if (task_on_scx(p)) ++ return false; ++ + /* + * Do not complicate things with the async wake_list while the CPU is + * in hotplug state. 
+@@ -4549,6 +4580,10 @@ static void __sched_fork(unsigned long clone_flags, struct task_struct *p) + p->rt.on_rq = 0; + p->rt.on_list = 0; + ++#ifdef CONFIG_SCHED_CLASS_EXT ++ init_scx_entity(&p->scx); ++#endif ++ + #ifdef CONFIG_PREEMPT_NOTIFIERS + INIT_HLIST_HEAD(&p->preempt_notifiers); + #endif +@@ -4789,10 +4824,18 @@ int sched_fork(unsigned long clone_flags, struct task_struct *p) + + if (dl_prio(p->prio)) + return -EAGAIN; +- else if (rt_prio(p->prio)) ++ ++ scx_pre_fork(p); ++ ++ if (rt_prio(p->prio)) { + p->sched_class = &rt_sched_class; +- else ++#ifdef CONFIG_SCHED_CLASS_EXT ++ } else if (task_should_scx(p)) { ++ p->sched_class = &ext_sched_class; ++#endif ++ } else { + p->sched_class = &fair_sched_class; ++ } + + init_entity_runnable_average(&p->se); + +@@ -4812,7 +4855,7 @@ int sched_fork(unsigned long clone_flags, struct task_struct *p) + return 0; + } + +-void sched_cgroup_fork(struct task_struct *p, struct kernel_clone_args *kargs) ++int sched_cgroup_fork(struct task_struct *p, struct kernel_clone_args *kargs) + { + unsigned long flags; + +@@ -4839,11 +4882,19 @@ void sched_cgroup_fork(struct task_struct *p, struct kernel_clone_args *kargs) + if (p->sched_class->task_fork) + p->sched_class->task_fork(p); + raw_spin_unlock_irqrestore(&p->pi_lock, flags); ++ ++ return scx_fork(p); ++} ++ ++void sched_cancel_fork(struct task_struct *p) ++{ ++ scx_cancel_fork(p); + } + + void sched_post_fork(struct task_struct *p) + { + uclamp_post_fork(p); ++ scx_post_fork(p); + } + + unsigned long to_ratio(u64 period, u64 runtime) +@@ -5685,6 +5736,7 @@ void sched_tick(void) + calc_global_load_tick(rq); + sched_core_tick(rq); + task_tick_mm_cid(rq, curr); ++ scx_tick(rq); + + rq_unlock(rq, &rf); + +@@ -5697,8 +5749,10 @@ void sched_tick(void) + wq_worker_tick(curr); + + #ifdef CONFIG_SMP +- rq->idle_balance = idle_cpu(cpu); +- sched_balance_trigger(rq); ++ if (!scx_switched_all()) { ++ rq->idle_balance = idle_cpu(cpu); ++ sched_balance_trigger(rq); ++ } + #endif + } + 
+@@ -5989,7 +6043,19 @@ static void put_prev_task_balance(struct rq *rq, struct task_struct *prev, + struct rq_flags *rf) + { + #ifdef CONFIG_SMP ++ const struct sched_class *start_class = prev->sched_class; + const struct sched_class *class; ++ ++#ifdef CONFIG_SCHED_CLASS_EXT ++ /* ++ * SCX requires a balance() call before every pick_next_task() including ++ * when waking up from SCHED_IDLE. If @start_class is below SCX, start ++ * from SCX instead. ++ */ ++ if (sched_class_above(&ext_sched_class, start_class)) ++ start_class = &ext_sched_class; ++#endif ++ + /* + * We must do the balancing pass before put_prev_task(), such + * that when we release the rq->lock the task is in the same +@@ -5998,7 +6064,7 @@ static void put_prev_task_balance(struct rq *rq, struct task_struct *prev, + * We can terminate the balance pass as soon as we know there is + * a runnable task of @class priority or higher. + */ +- for_class_range(class, prev->sched_class, &idle_sched_class) { ++ for_active_class_range(class, start_class, &idle_sched_class) { + if (class->balance(rq, prev, rf)) + break; + } +@@ -6016,6 +6082,9 @@ __pick_next_task(struct rq *rq, struct task_struct *prev, struct rq_flags *rf) + const struct sched_class *class; + struct task_struct *p; + ++ if (scx_enabled()) ++ goto restart; ++ + /* + * Optimization: we know that if all tasks are in the fair class we can + * call that function directly, but only if the @prev task wasn't of a +@@ -6056,10 +6125,15 @@ __pick_next_task(struct rq *rq, struct task_struct *prev, struct rq_flags *rf) + if (prev->dl_server) + prev->dl_server = NULL; + +- for_each_class(class) { ++ for_each_active_class(class) { + p = class->pick_next_task(rq); +- if (p) ++ if (p) { ++ const struct sched_class *prev_class = prev->sched_class; ++ ++ if (class != prev_class && prev_class->switch_class) ++ prev_class->switch_class(rq, p); + return p; ++ } + } + + BUG(); /* The idle class should always have a runnable task. 
*/ +@@ -6089,7 +6163,7 @@ static inline struct task_struct *pick_task(struct rq *rq) + const struct sched_class *class; + struct task_struct *p; + +- for_each_class(class) { ++ for_each_active_class(class) { + p = class->pick_task(rq); + if (p) + return p; +@@ -7080,12 +7154,16 @@ int default_wake_function(wait_queue_entry_t *curr, unsigned mode, int wake_flag + } + EXPORT_SYMBOL(default_wake_function); + +-static void __setscheduler_prio(struct task_struct *p, int prio) ++void __setscheduler_prio(struct task_struct *p, int prio) + { + if (dl_prio(prio)) + p->sched_class = &dl_sched_class; + else if (rt_prio(prio)) + p->sched_class = &rt_sched_class; ++#ifdef CONFIG_SCHED_CLASS_EXT ++ else if (task_should_scx(p)) ++ p->sched_class = &ext_sched_class; ++#endif + else + p->sched_class = &fair_sched_class; + +@@ -7246,6 +7324,7 @@ void rt_mutex_setprio(struct task_struct *p, struct task_struct *pi_task) + } + + __setscheduler_prio(p, prio); ++ check_class_changing(rq, p, prev_class); + + if (queued) + enqueue_task(rq, p, queue_flag); +@@ -7467,6 +7546,25 @@ int sched_core_idle_cpu(int cpu) + #endif + + #ifdef CONFIG_SMP ++/* ++ * Load avg and utilization metrics need to be updated periodically and before ++ * consumption. This function updates the metrics for all subsystems except for ++ * the fair class. @rq must be locked and have its clock updated.
++ */ ++bool update_other_load_avgs(struct rq *rq) ++{ ++ u64 now = rq_clock_pelt(rq); ++ const struct sched_class *curr_class = rq->curr->sched_class; ++ unsigned long hw_pressure = arch_scale_hw_pressure(cpu_of(rq)); ++ ++ lockdep_assert_rq_held(rq); ++ ++ return update_rt_rq_load_avg(now, rq, curr_class == &rt_sched_class) | ++ update_dl_rq_load_avg(now, rq, curr_class == &dl_sched_class) | ++ update_hw_load_avg(now, rq, hw_pressure) | ++ update_irq_load_avg(rq, 0); ++} ++ + /* + * This function computes an effective utilization for the given CPU, to be + * used for frequency selection given the linear relation: f = u * f_max. +@@ -7789,6 +7887,10 @@ static int __sched_setscheduler(struct task_struct *p, + goto unlock; + } + ++ retval = scx_check_setscheduler(p, policy); ++ if (retval) ++ goto unlock; ++ + /* + * If not changing anything there's no need to proceed further, + * but store a possible modification of reset_on_fork. +@@ -7891,6 +7993,7 @@ static int __sched_setscheduler(struct task_struct *p, + __setscheduler_prio(p, newprio); + } + __setscheduler_uclamp(p, attr); ++ check_class_changing(rq, p, prev_class); + + if (queued) { + /* +@@ -9066,6 +9169,7 @@ SYSCALL_DEFINE1(sched_get_priority_max, int, policy) + case SCHED_NORMAL: + case SCHED_BATCH: + case SCHED_IDLE: ++ case SCHED_EXT: + ret = 0; + break; + } +@@ -9093,6 +9197,7 @@ SYSCALL_DEFINE1(sched_get_priority_min, int, policy) + case SCHED_NORMAL: + case SCHED_BATCH: + case SCHED_IDLE: ++ case SCHED_EXT: + ret = 0; + } + return ret; +@@ -9188,6 +9293,7 @@ void sched_show_task(struct task_struct *p) + + print_worker_info(KERN_INFO, p); + print_stop_info(KERN_INFO, p); ++ print_scx_info(KERN_INFO, p); + show_stack(p, NULL, KERN_INFO); + put_task_stack(p); + } +@@ -9680,6 +9786,8 @@ int sched_cpu_activate(unsigned int cpu) + cpuset_cpu_active(); + } + ++ scx_rq_activate(rq); ++ + /* + * Put the rq online, if not already. 
This happens: + * +@@ -9740,6 +9848,8 @@ int sched_cpu_deactivate(unsigned int cpu) + } + rq_unlock_irqrestore(rq, &rf); + ++ scx_rq_deactivate(rq); ++ + #ifdef CONFIG_SCHED_SMT + /* + * When going down, decrement the number of cores with SMT present. +@@ -9923,11 +10033,15 @@ void __init sched_init(void) + int i; + + /* Make sure the linker didn't screw up */ +- BUG_ON(&idle_sched_class != &fair_sched_class + 1 || +- &fair_sched_class != &rt_sched_class + 1 || +- &rt_sched_class != &dl_sched_class + 1); + #ifdef CONFIG_SMP +- BUG_ON(&dl_sched_class != &stop_sched_class + 1); ++ BUG_ON(!sched_class_above(&stop_sched_class, &dl_sched_class)); ++#endif ++ BUG_ON(!sched_class_above(&dl_sched_class, &rt_sched_class)); ++ BUG_ON(!sched_class_above(&rt_sched_class, &fair_sched_class)); ++ BUG_ON(!sched_class_above(&fair_sched_class, &idle_sched_class)); ++#ifdef CONFIG_SCHED_CLASS_EXT ++ BUG_ON(!sched_class_above(&fair_sched_class, &ext_sched_class)); ++ BUG_ON(!sched_class_above(&ext_sched_class, &idle_sched_class)); + #endif + + wait_bit_init(); +@@ -10096,6 +10210,7 @@ void __init sched_init(void) + balance_push_set(smp_processor_id(), false); + #endif + init_sched_fair_class(); ++ init_sched_ext_class(); + + psi_init(); + +@@ -10522,11 +10637,6 @@ void sched_move_task(struct task_struct *tsk) + } + } + +-static inline struct task_group *css_tg(struct cgroup_subsys_state *css) +-{ +- return css ? 
container_of(css, struct task_group, css) : NULL; +-} +- + static struct cgroup_subsys_state * + cpu_cgroup_css_alloc(struct cgroup_subsys_state *parent_css) + { +@@ -11293,29 +11403,27 @@ static int cpu_local_stat_show(struct seq_file *sf, + } + + #ifdef CONFIG_FAIR_GROUP_SCHED ++ ++static unsigned long tg_weight(struct task_group *tg) ++{ ++ return scale_load_down(tg->shares); ++} ++ + static u64 cpu_weight_read_u64(struct cgroup_subsys_state *css, + struct cftype *cft) + { +- struct task_group *tg = css_tg(css); +- u64 weight = scale_load_down(tg->shares); +- +- return DIV_ROUND_CLOSEST_ULL(weight * CGROUP_WEIGHT_DFL, 1024); ++ return sched_weight_to_cgroup(tg_weight(css_tg(css))); + } + + static int cpu_weight_write_u64(struct cgroup_subsys_state *css, +- struct cftype *cft, u64 weight) ++ struct cftype *cft, u64 cgrp_weight) + { +- /* +- * cgroup weight knobs should use the common MIN, DFL and MAX +- * values which are 1, 100 and 10000 respectively. While it loses +- * a bit of range on both ends, it maps pretty well onto the shares +- * value used by scheduler and the round-trip conversions preserve +- * the original value over the entire range. 
+- */ +- if (weight < CGROUP_WEIGHT_MIN || weight > CGROUP_WEIGHT_MAX) ++ unsigned long weight; ++ ++ if (cgrp_weight < CGROUP_WEIGHT_MIN || cgrp_weight > CGROUP_WEIGHT_MAX) + return -ERANGE; + +- weight = DIV_ROUND_CLOSEST_ULL(weight * 1024, CGROUP_WEIGHT_DFL); ++ weight = sched_weight_from_cgroup(cgrp_weight); + + return sched_group_set_shares(css_tg(css), scale_load(weight)); + } +@@ -11323,7 +11431,7 @@ static int cpu_weight_write_u64(struct cgroup_subsys_state *css, + static s64 cpu_weight_nice_read_s64(struct cgroup_subsys_state *css, + struct cftype *cft) + { +- unsigned long weight = scale_load_down(css_tg(css)->shares); ++ unsigned long weight = tg_weight(css_tg(css)); + int last_delta = INT_MAX; + int prio, delta; + +@@ -12064,3 +12172,38 @@ void sched_mm_cid_fork(struct task_struct *t) + t->mm_cid_active = 1; + } + #endif ++ ++#ifdef CONFIG_SCHED_CLASS_EXT ++void sched_deq_and_put_task(struct task_struct *p, int queue_flags, ++ struct sched_enq_and_set_ctx *ctx) ++{ ++ struct rq *rq = task_rq(p); ++ ++ lockdep_assert_rq_held(rq); ++ ++ *ctx = (struct sched_enq_and_set_ctx){ ++ .p = p, ++ .queue_flags = queue_flags, ++ .queued = task_on_rq_queued(p), ++ .running = task_current(rq, p), ++ }; ++ ++ update_rq_clock(rq); ++ if (ctx->queued) ++ dequeue_task(rq, p, queue_flags | DEQUEUE_NOCLOCK); ++ if (ctx->running) ++ put_prev_task(rq, p); ++} ++ ++void sched_enq_and_set_task(struct sched_enq_and_set_ctx *ctx) ++{ ++ struct rq *rq = task_rq(ctx->p); ++ ++ lockdep_assert_rq_held(rq); ++ ++ if (ctx->queued) ++ enqueue_task(rq, ctx->p, ctx->queue_flags | ENQUEUE_NOCLOCK); ++ if (ctx->running) ++ set_next_task(rq, ctx->p); ++} ++#endif /* CONFIG_SCHED_CLASS_EXT */ +diff --git a/kernel/sched/cpufreq_schedutil.c b/kernel/sched/cpufreq_schedutil.c +index eece6244f9d2..e683e5d08daa 100644 +--- a/kernel/sched/cpufreq_schedutil.c ++++ b/kernel/sched/cpufreq_schedutil.c +@@ -197,8 +197,10 @@ unsigned long sugov_effective_cpu_perf(int cpu, unsigned long actual, + + 
static void sugov_get_util(struct sugov_cpu *sg_cpu, unsigned long boost) + { +- unsigned long min, max, util = cpu_util_cfs_boost(sg_cpu->cpu); ++ unsigned long min, max, util = scx_cpuperf_target(sg_cpu->cpu); + ++ if (!scx_switched_all()) ++ util += cpu_util_cfs_boost(sg_cpu->cpu); + util = effective_cpu_util(sg_cpu->cpu, util, &min, &max); + util = max(util, boost); + sg_cpu->bw_min = min; +@@ -325,16 +327,35 @@ static unsigned long sugov_iowait_apply(struct sugov_cpu *sg_cpu, u64 time, + } + + #ifdef CONFIG_NO_HZ_COMMON +-static bool sugov_cpu_is_busy(struct sugov_cpu *sg_cpu) ++static bool sugov_hold_freq(struct sugov_cpu *sg_cpu) + { +- unsigned long idle_calls = tick_nohz_get_idle_calls_cpu(sg_cpu->cpu); +- bool ret = idle_calls == sg_cpu->saved_idle_calls; ++ unsigned long idle_calls; ++ bool ret; ++ ++ /* ++ * The heuristics in this function is for the fair class. For SCX, the ++ * performance target comes directly from the BPF scheduler. Let's just ++ * follow it. ++ */ ++ if (scx_switched_all()) ++ return false; ++ ++ /* if capped by uclamp_max, always update to be in compliance */ ++ if (uclamp_rq_is_capped(cpu_rq(sg_cpu->cpu))) ++ return false; ++ ++ /* ++ * Maintain the frequency if the CPU has not been idle recently, as ++ * reduction is likely to be premature. ++ */ ++ idle_calls = tick_nohz_get_idle_calls_cpu(sg_cpu->cpu); ++ ret = idle_calls == sg_cpu->saved_idle_calls; + + sg_cpu->saved_idle_calls = idle_calls; + return ret; + } + #else +-static inline bool sugov_cpu_is_busy(struct sugov_cpu *sg_cpu) { return false; } ++static inline bool sugov_hold_freq(struct sugov_cpu *sg_cpu) { return false; } + #endif /* CONFIG_NO_HZ_COMMON */ + + /* +@@ -382,14 +403,8 @@ static void sugov_update_single_freq(struct update_util_data *hook, u64 time, + return; + + next_f = get_next_freq(sg_policy, sg_cpu->util, max_cap); +- /* +- * Do not reduce the frequency if the CPU has not been idle +- * recently, as the reduction is likely to be premature then. 
+- * +- * Except when the rq is capped by uclamp_max. +- */ +- if (!uclamp_rq_is_capped(cpu_rq(sg_cpu->cpu)) && +- sugov_cpu_is_busy(sg_cpu) && next_f < sg_policy->next_freq && ++ ++ if (sugov_hold_freq(sg_cpu) && next_f < sg_policy->next_freq && + !sg_policy->need_freq_update) { + next_f = sg_policy->next_freq; + +@@ -436,14 +451,7 @@ static void sugov_update_single_perf(struct update_util_data *hook, u64 time, + if (!sugov_update_single_common(sg_cpu, time, max_cap, flags)) + return; + +- /* +- * Do not reduce the target performance level if the CPU has not been +- * idle recently, as the reduction is likely to be premature then. +- * +- * Except when the rq is capped by uclamp_max. +- */ +- if (!uclamp_rq_is_capped(cpu_rq(sg_cpu->cpu)) && +- sugov_cpu_is_busy(sg_cpu) && sg_cpu->util < prev_util) ++ if (sugov_hold_freq(sg_cpu) && sg_cpu->util < prev_util) + sg_cpu->util = prev_util; + + cpufreq_driver_adjust_perf(sg_cpu->cpu, sg_cpu->bw_min, +diff --git a/kernel/sched/debug.c b/kernel/sched/debug.c +index c1eb9a1afd13..c057ef46c5f8 100644 +--- a/kernel/sched/debug.c ++++ b/kernel/sched/debug.c +@@ -1090,6 +1090,9 @@ void proc_sched_show_task(struct task_struct *p, struct pid_namespace *ns, + P(dl.runtime); + P(dl.deadline); + } ++#ifdef CONFIG_SCHED_CLASS_EXT ++ __PS("ext.enabled", task_on_scx(p)); ++#endif + #undef PN_SCHEDSTAT + #undef P_SCHEDSTAT + +diff --git a/kernel/sched/ext.c b/kernel/sched/ext.c +new file mode 100644 +index 000000000000..0dac88d0e578 +--- /dev/null ++++ b/kernel/sched/ext.c +@@ -0,0 +1,6532 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * BPF extensible scheduler class: Documentation/scheduler/sched-ext.rst ++ * ++ * Copyright (c) 2022 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2022 Tejun Heo ++ * Copyright (c) 2022 David Vernet ++ */ ++#define SCX_OP_IDX(op) (offsetof(struct sched_ext_ops, op) / sizeof(void (*)(void))) ++ ++enum scx_consts { ++ SCX_DSP_DFL_MAX_BATCH = 32, ++ SCX_DSP_MAX_LOOPS = 32, ++ SCX_WATCHDOG_MAX_TIMEOUT = 30 * HZ, ++ ++ SCX_EXIT_BT_LEN = 64, ++ SCX_EXIT_MSG_LEN = 1024, ++ SCX_EXIT_DUMP_DFL_LEN = 32768, ++ ++ SCX_CPUPERF_ONE = SCHED_CAPACITY_SCALE, ++}; ++ ++enum scx_exit_kind { ++ SCX_EXIT_NONE, ++ SCX_EXIT_DONE, ++ ++ SCX_EXIT_UNREG = 64, /* user-space initiated unregistration */ ++ SCX_EXIT_UNREG_BPF, /* BPF-initiated unregistration */ ++ SCX_EXIT_UNREG_KERN, /* kernel-initiated unregistration */ ++ SCX_EXIT_SYSRQ, /* requested by 'S' sysrq */ ++ ++ SCX_EXIT_ERROR = 1024, /* runtime error, error msg contains details */ ++ SCX_EXIT_ERROR_BPF, /* ERROR but triggered through scx_bpf_error() */ ++ SCX_EXIT_ERROR_STALL, /* watchdog detected stalled runnable tasks */ ++}; ++ ++/* ++ * An exit code can be specified when exiting with scx_bpf_exit() or ++ * scx_ops_exit(), corresponding to exit_kind UNREG_BPF and UNREG_KERN ++ * respectively. The codes are 64bit of the format: ++ * ++ * Bits: [63 .. 48 47 .. 32 31 .. 0] ++ * [ SYS ACT ] [ SYS RSN ] [ USR ] ++ * ++ * SYS ACT: System-defined exit actions ++ * SYS RSN: System-defined exit reasons ++ * USR : User-defined exit codes and reasons ++ * ++ * Using the above, users may communicate intention and context by ORing system ++ * actions and/or system reasons with a user-defined exit code. ++ */ ++enum scx_exit_code { ++ /* Reasons */ ++ SCX_ECODE_RSN_HOTPLUG = 1LLU << 32, ++ ++ /* Actions */ ++ SCX_ECODE_ACT_RESTART = 1LLU << 48, ++}; ++ ++/* ++ * scx_exit_info is passed to ops.exit() to describe why the BPF scheduler is ++ * being disabled. 
++ */ ++struct scx_exit_info { ++ /* %SCX_EXIT_* - broad category of the exit reason */ ++ enum scx_exit_kind kind; ++ ++ /* exit code if gracefully exiting */ ++ s64 exit_code; ++ ++ /* textual representation of the above */ ++ const char *reason; ++ ++ /* backtrace if exiting due to an error */ ++ unsigned long *bt; ++ u32 bt_len; ++ ++ /* informational message */ ++ char *msg; ++ ++ /* debug dump */ ++ char *dump; ++}; ++ ++/* sched_ext_ops.flags */ ++enum scx_ops_flags { ++ /* ++ * Keep built-in idle tracking even if ops.update_idle() is implemented. ++ */ ++ SCX_OPS_KEEP_BUILTIN_IDLE = 1LLU << 0, ++ ++ /* ++ * By default, if there are no other task to run on the CPU, ext core ++ * keeps running the current task even after its slice expires. If this ++ * flag is specified, such tasks are passed to ops.enqueue() with ++ * %SCX_ENQ_LAST. See the comment above %SCX_ENQ_LAST for more info. ++ */ ++ SCX_OPS_ENQ_LAST = 1LLU << 1, ++ ++ /* ++ * An exiting task may schedule after PF_EXITING is set. In such cases, ++ * bpf_task_from_pid() may not be able to find the task and if the BPF ++ * scheduler depends on pid lookup for dispatching, the task will be ++ * lost leading to various issues including RCU grace period stalls. ++ * ++ * To mask this problem, by default, unhashed tasks are automatically ++ * dispatched to the local DSQ on enqueue. If the BPF scheduler doesn't ++ * depend on pid lookups and wants to handle these tasks directly, the ++ * following flag can be used. ++ */ ++ SCX_OPS_ENQ_EXITING = 1LLU << 2, ++ ++ /* ++ * If set, only tasks with policy set to SCHED_EXT are attached to ++ * sched_ext. If clear, SCHED_NORMAL tasks are also included. 
++ */ ++ SCX_OPS_SWITCH_PARTIAL = 1LLU << 3, ++ ++ SCX_OPS_ALL_FLAGS = SCX_OPS_KEEP_BUILTIN_IDLE | ++ SCX_OPS_ENQ_LAST | ++ SCX_OPS_ENQ_EXITING | ++ SCX_OPS_SWITCH_PARTIAL, ++}; ++ ++/* argument container for ops.init_task() */ ++struct scx_init_task_args { ++ /* ++ * Set if ops.init_task() is being invoked on the fork path, as opposed ++ * to the scheduler transition path. ++ */ ++ bool fork; ++}; ++ ++/* argument container for ops.exit_task() */ ++struct scx_exit_task_args { ++ /* Whether the task exited before running on sched_ext. */ ++ bool cancelled; ++}; ++ ++enum scx_cpu_preempt_reason { ++ /* next task is being scheduled by &sched_class_rt */ ++ SCX_CPU_PREEMPT_RT, ++ /* next task is being scheduled by &sched_class_dl */ ++ SCX_CPU_PREEMPT_DL, ++ /* next task is being scheduled by &sched_class_stop */ ++ SCX_CPU_PREEMPT_STOP, ++ /* unknown reason for SCX being preempted */ ++ SCX_CPU_PREEMPT_UNKNOWN, ++}; ++ ++/* ++ * Argument container for ops->cpu_acquire(). Currently empty, but may be ++ * expanded in the future. ++ */ ++struct scx_cpu_acquire_args {}; ++ ++/* argument container for ops->cpu_release() */ ++struct scx_cpu_release_args { ++ /* the reason the CPU was preempted */ ++ enum scx_cpu_preempt_reason reason; ++ ++ /* the task that's going to be scheduled on the CPU */ ++ struct task_struct *task; ++}; ++ ++/* ++ * Informational context provided to dump operations. ++ */ ++struct scx_dump_ctx { ++ enum scx_exit_kind kind; ++ s64 exit_code; ++ const char *reason; ++ u64 at_ns; ++ u64 at_jiffies; ++}; ++ ++/** ++ * struct sched_ext_ops - Operation table for BPF scheduler implementation ++ * ++ * Userland can implement an arbitrary scheduling policy by implementing and ++ * loading operations in this table. 
++ */ ++struct sched_ext_ops { ++ /** ++ * select_cpu - Pick the target CPU for a task which is being woken up ++ * @p: task being woken up ++ * @prev_cpu: the cpu @p was on before sleeping ++ * @wake_flags: SCX_WAKE_* ++ * ++ * Decision made here isn't final. @p may be moved to any CPU while it ++ * is getting dispatched for execution later. However, as @p is not on ++ * the rq at this point, getting the eventual execution CPU right here ++ * saves a small bit of overhead down the line. ++ * ++ * If an idle CPU is returned, the CPU is kicked and will try to ++ * dispatch. While an explicit custom mechanism can be added, ++ * select_cpu() serves as the default way to wake up idle CPUs. ++ * ++ * @p may be dispatched directly by calling scx_bpf_dispatch(). If @p ++ * is dispatched, the ops.enqueue() callback will be skipped. Finally, ++ * if @p is dispatched to SCX_DSQ_LOCAL, it will be dispatched to the ++ * local DSQ of whatever CPU is returned by this callback. ++ */ ++ s32 (*select_cpu)(struct task_struct *p, s32 prev_cpu, u64 wake_flags); ++ ++ /** ++ * enqueue - Enqueue a task on the BPF scheduler ++ * @p: task being enqueued ++ * @enq_flags: %SCX_ENQ_* ++ * ++ * @p is ready to run. Dispatch directly by calling scx_bpf_dispatch() ++ * or enqueue on the BPF scheduler. If not directly dispatched, the bpf ++ * scheduler owns @p and if it fails to dispatch @p, the task will ++ * stall. ++ * ++ * If @p was dispatched from ops.select_cpu(), this callback is ++ * skipped. ++ */ ++ void (*enqueue)(struct task_struct *p, u64 enq_flags); ++ ++ /** ++ * dequeue - Remove a task from the BPF scheduler ++ * @p: task being dequeued ++ * @deq_flags: %SCX_DEQ_* ++ * ++ * Remove @p from the BPF scheduler. This is usually called to isolate ++ * the task while updating its scheduling properties (e.g. priority). 
++ * ++ * The ext core keeps track of whether the BPF side owns a given task or ++ * not and can gracefully ignore spurious dispatches from BPF side, ++ * which makes it safe to not implement this method. However, depending ++ * on the scheduling logic, this can lead to confusing behaviors - e.g. ++ * scheduling position not being updated across a priority change. ++ */ ++ void (*dequeue)(struct task_struct *p, u64 deq_flags); ++ ++ /** ++ * dispatch - Dispatch tasks from the BPF scheduler and/or consume DSQs ++ * @cpu: CPU to dispatch tasks for ++ * @prev: previous task being switched out ++ * ++ * Called when a CPU's local dsq is empty. The operation should dispatch ++ * one or more tasks from the BPF scheduler into the DSQs using ++ * scx_bpf_dispatch() and/or consume user DSQs into the local DSQ using ++ * scx_bpf_consume(). ++ * ++ * The maximum number of times scx_bpf_dispatch() can be called without ++ * an intervening scx_bpf_consume() is specified by ++ * ops.dispatch_max_batch. See the comments on top of the two functions ++ * for more details. ++ * ++ * When not %NULL, @prev is an SCX task with its slice depleted. If ++ * @prev is still runnable as indicated by set %SCX_TASK_QUEUED in ++ * @prev->scx.flags, it is not enqueued yet and will be enqueued after ++ * ops.dispatch() returns. To keep executing @prev, return without ++ * dispatching or consuming any tasks. Also see %SCX_OPS_ENQ_LAST. ++ */ ++ void (*dispatch)(s32 cpu, struct task_struct *prev); ++ ++ /** ++ * tick - Periodic tick ++ * @p: task running currently ++ * ++ * This operation is called every 1/HZ seconds on CPUs which are ++ * executing an SCX task. Setting @p->scx.slice to 0 will trigger an ++ * immediate dispatch cycle on the CPU. 
++ */ ++ void (*tick)(struct task_struct *p); ++ ++ /** ++ * runnable - A task is becoming runnable on its associated CPU ++ * @p: task becoming runnable ++ * @enq_flags: %SCX_ENQ_* ++ * ++ * This and the following three functions can be used to track a task's ++ * execution state transitions. A task becomes ->runnable() on a CPU, ++ * and then goes through one or more ->running() and ->stopping() pairs ++ * as it runs on the CPU, and eventually becomes ->quiescent() when it's ++ * done running on the CPU. ++ * ++ * @p is becoming runnable on the CPU because it's ++ * ++ * - waking up (%SCX_ENQ_WAKEUP) ++ * - being moved from another CPU ++ * - being restored after temporarily taken off the queue for an ++ * attribute change. ++ * ++ * This and ->enqueue() are related but not coupled. This operation ++ * notifies @p's state transition and may not be followed by ->enqueue() ++ * e.g. when @p is being dispatched to a remote CPU, or when @p is ++ * being enqueued on a CPU experiencing a hotplug event. Likewise, a ++ * task may be ->enqueue()'d without being preceded by this operation ++ * e.g. after exhausting its slice. ++ */ ++ void (*runnable)(struct task_struct *p, u64 enq_flags); ++ ++ /** ++ * running - A task is starting to run on its associated CPU ++ * @p: task starting to run ++ * ++ * See ->runnable() for explanation on the task state notifiers. ++ */ ++ void (*running)(struct task_struct *p); ++ ++ /** ++ * stopping - A task is stopping execution ++ * @p: task stopping to run ++ * @runnable: is task @p still runnable? ++ * ++ * See ->runnable() for explanation on the task state notifiers. If ++ * !@runnable, ->quiescent() will be invoked after this operation ++ * returns. ++ */ ++ void (*stopping)(struct task_struct *p, bool runnable); ++ ++ /** ++ * quiescent - A task is becoming not runnable on its associated CPU ++ * @p: task becoming not runnable ++ * @deq_flags: %SCX_DEQ_* ++ * ++ * See ->runnable() for explanation on the task state notifiers. 
++ * ++ * @p is becoming quiescent on the CPU because it's ++ * ++ * - sleeping (%SCX_DEQ_SLEEP) ++ * - being moved to another CPU ++ * - being temporarily taken off the queue for an attribute change ++ * (%SCX_DEQ_SAVE) ++ * ++ * This and ->dequeue() are related but not coupled. This operation ++ * notifies @p's state transition and may not be preceded by ->dequeue() ++ * e.g. when @p is being dispatched to a remote CPU. ++ */ ++ void (*quiescent)(struct task_struct *p, u64 deq_flags); ++ ++ /** ++ * yield - Yield CPU ++ * @from: yielding task ++ * @to: optional yield target task ++ * ++ * If @to is NULL, @from is yielding the CPU to other runnable tasks. ++ * The BPF scheduler should ensure that other available tasks are ++ * dispatched before the yielding task. Return value is ignored in this ++ * case. ++ * ++ * If @to is not-NULL, @from wants to yield the CPU to @to. If the bpf ++ * scheduler can implement the request, return %true; otherwise, %false. ++ */ ++ bool (*yield)(struct task_struct *from, struct task_struct *to); ++ ++ /** ++ * core_sched_before - Task ordering for core-sched ++ * @a: task A ++ * @b: task B ++ * ++ * Used by core-sched to determine the ordering between two tasks. See ++ * Documentation/admin-guide/hw-vuln/core-scheduling.rst for details on ++ * core-sched. ++ * ++ * Both @a and @b are runnable and may or may not currently be queued on ++ * the BPF scheduler. Should return %true if @a should run before @b. ++ * %false if there's no required ordering or @b should run before @a. ++ * ++ * If not specified, the default is ordering them according to when they ++ * became runnable. ++ */ ++ bool (*core_sched_before)(struct task_struct *a, struct task_struct *b); ++ ++ /** ++ * set_weight - Set task weight ++ * @p: task to set weight for ++ * @weight: new weight [1..10000] ++ * ++ * Update @p's weight to @weight. 
++ */ ++ void (*set_weight)(struct task_struct *p, u32 weight); ++ ++ /** ++ * set_cpumask - Set CPU affinity ++ * @p: task to set CPU affinity for ++ * @cpumask: cpumask of cpus that @p can run on ++ * ++ * Update @p's CPU affinity to @cpumask. ++ */ ++ void (*set_cpumask)(struct task_struct *p, ++ const struct cpumask *cpumask); ++ ++ /** ++ * update_idle - Update the idle state of a CPU ++ * @cpu: CPU to update the idle state for ++ * @idle: whether entering or exiting the idle state ++ * ++ * This operation is called when @rq's CPU goes or leaves the idle ++ * state. By default, implementing this operation disables the built-in ++ * idle CPU tracking and the following helpers become unavailable: ++ * ++ * - scx_bpf_select_cpu_dfl() ++ * - scx_bpf_test_and_clear_cpu_idle() ++ * - scx_bpf_pick_idle_cpu() ++ * ++ * The user also must implement ops.select_cpu() as the default ++ * implementation relies on scx_bpf_select_cpu_dfl(). ++ * ++ * Specify the %SCX_OPS_KEEP_BUILTIN_IDLE flag to keep the built-in idle ++ * tracking. ++ */ ++ void (*update_idle)(s32 cpu, bool idle); ++ ++ /** ++ * cpu_acquire - A CPU is becoming available to the BPF scheduler ++ * @cpu: The CPU being acquired by the BPF scheduler. ++ * @args: Acquire arguments, see the struct definition. ++ * ++ * A CPU that was previously released from the BPF scheduler is now once ++ * again under its control. ++ */ ++ void (*cpu_acquire)(s32 cpu, struct scx_cpu_acquire_args *args); ++ ++ /** ++ * cpu_release - A CPU is taken away from the BPF scheduler ++ * @cpu: The CPU being released by the BPF scheduler. ++ * @args: Release arguments, see the struct definition. ++ * ++ * The specified CPU is no longer under the control of the BPF ++ * scheduler. This could be because it was preempted by a higher ++ * priority sched_class, though there may be other reasons as well. The ++ * caller should consult @args->reason to determine the cause.
++ */ ++ void (*cpu_release)(s32 cpu, struct scx_cpu_release_args *args); ++ ++ /** ++ * init_task - Initialize a task to run in a BPF scheduler ++ * @p: task to initialize for BPF scheduling ++ * @args: init arguments, see the struct definition ++ * ++ * Either we're loading a BPF scheduler or a new task is being forked. ++ * Initialize @p for BPF scheduling. This operation may block and can ++ * be used for allocations, and is called exactly once for a task. ++ * ++ * Return 0 for success, -errno for failure. An error return while ++ * loading will abort loading of the BPF scheduler. During a fork, it ++ * will abort that specific fork. ++ */ ++ s32 (*init_task)(struct task_struct *p, struct scx_init_task_args *args); ++ ++ /** ++ * exit_task - Exit a previously-running task from the system ++ * @p: task to exit ++ * ++ * @p is exiting or the BPF scheduler is being unloaded. Perform any ++ * necessary cleanup for @p. ++ */ ++ void (*exit_task)(struct task_struct *p, struct scx_exit_task_args *args); ++ ++ /** ++ * enable - Enable BPF scheduling for a task ++ * @p: task to enable BPF scheduling for ++ * ++ * Enable @p for BPF scheduling. enable() is called on @p any time it ++ * enters SCX, and is always paired with a matching disable(). ++ */ ++ void (*enable)(struct task_struct *p); ++ ++ /** ++ * disable - Disable BPF scheduling for a task ++ * @p: task to disable BPF scheduling for ++ * ++ * @p is exiting, leaving SCX or the BPF scheduler is being unloaded. ++ * Disable BPF scheduling for @p. A disable() call is always matched ++ * with a prior enable() call. ++ */ ++ void (*disable)(struct task_struct *p); ++ ++ /** ++ * dump - Dump BPF scheduler state on error ++ * @ctx: debug dump context ++ * ++ * Use scx_bpf_dump() to generate BPF scheduler specific debug dump. 
++ */ ++ void (*dump)(struct scx_dump_ctx *ctx); ++ ++ /** ++ * dump_cpu - Dump BPF scheduler state for a CPU on error ++ * @ctx: debug dump context ++ * @cpu: CPU to generate debug dump for ++ * @idle: @cpu is currently idle without any runnable tasks ++ * ++ * Use scx_bpf_dump() to generate BPF scheduler specific debug dump for ++ * @cpu. If @idle is %true and this operation doesn't produce any ++ * output, @cpu is skipped for dump. ++ */ ++ void (*dump_cpu)(struct scx_dump_ctx *ctx, s32 cpu, bool idle); ++ ++ /** ++ * dump_task - Dump BPF scheduler state for a runnable task on error ++ * @ctx: debug dump context ++ * @p: runnable task to generate debug dump for ++ * ++ * Use scx_bpf_dump() to generate BPF scheduler specific debug dump for ++ * @p. ++ */ ++ void (*dump_task)(struct scx_dump_ctx *ctx, struct task_struct *p); ++ ++ /* ++ * All online ops must come before ops.cpu_online(). ++ */ ++ ++ /** ++ * cpu_online - A CPU became online ++ * @cpu: CPU which just came up ++ * ++ * @cpu just came online. @cpu will not call ops.enqueue() or ++ * ops.dispatch(), nor run tasks associated with other CPUs beforehand. ++ */ ++ void (*cpu_online)(s32 cpu); ++ ++ /** ++ * cpu_offline - A CPU is going offline ++ * @cpu: CPU which is going offline ++ * ++ * @cpu is going offline. @cpu will not call ops.enqueue() or ++ * ops.dispatch(), nor run tasks associated with other CPUs afterwards. ++ */ ++ void (*cpu_offline)(s32 cpu); ++ ++ /* ++ * All CPU hotplug ops must come before ops.init(). 
++ */ ++ ++ /** ++ * init - Initialize the BPF scheduler ++ */ ++ s32 (*init)(void); ++ ++ /** ++ * exit - Clean up after the BPF scheduler ++ * @info: Exit info ++ */ ++ void (*exit)(struct scx_exit_info *info); ++ ++ /** ++ * dispatch_max_batch - Max nr of tasks that dispatch() can dispatch ++ */ ++ u32 dispatch_max_batch; ++ ++ /** ++ * flags - %SCX_OPS_* flags ++ */ ++ u64 flags; ++ ++ /** ++ * timeout_ms - The maximum amount of time, in milliseconds, that a ++ * runnable task should be able to wait before being scheduled. The ++ * maximum timeout may not exceed the default timeout of 30 seconds. ++ * ++ * Defaults to the maximum allowed timeout value of 30 seconds. ++ */ ++ u32 timeout_ms; ++ ++ /** ++ * exit_dump_len - scx_exit_info.dump buffer length. If 0, the default ++ * value of 32768 is used. ++ */ ++ u32 exit_dump_len; ++ ++ /** ++ * hotplug_seq - A sequence number that may be set by the scheduler to ++ * detect when a hotplug event has occurred during the loading process. ++ * If 0, no detection occurs. Otherwise, the scheduler will fail to ++ * load if the sequence number does not match @scx_hotplug_seq on the ++ * enable path. ++ */ ++ u64 hotplug_seq; ++ ++ /** ++ * name - BPF scheduler's name ++ * ++ * Must be a non-zero valid BPF object name including only isalnum(), ++ * '_' and '.' chars. Shows up in kernel.sched_ext_ops sysctl while the ++ * BPF scheduler is enabled. 
++ */ ++ char name[SCX_OPS_NAME_LEN]; ++}; ++ ++enum scx_opi { ++ SCX_OPI_BEGIN = 0, ++ SCX_OPI_NORMAL_BEGIN = 0, ++ SCX_OPI_NORMAL_END = SCX_OP_IDX(cpu_online), ++ SCX_OPI_CPU_HOTPLUG_BEGIN = SCX_OP_IDX(cpu_online), ++ SCX_OPI_CPU_HOTPLUG_END = SCX_OP_IDX(init), ++ SCX_OPI_END = SCX_OP_IDX(init), ++}; ++ ++enum scx_wake_flags { ++ /* expose select WF_* flags as enums */ ++ SCX_WAKE_FORK = WF_FORK, ++ SCX_WAKE_TTWU = WF_TTWU, ++ SCX_WAKE_SYNC = WF_SYNC, ++}; ++ ++enum scx_enq_flags { ++ /* expose select ENQUEUE_* flags as enums */ ++ SCX_ENQ_WAKEUP = ENQUEUE_WAKEUP, ++ SCX_ENQ_HEAD = ENQUEUE_HEAD, ++ ++ /* high 32bits are SCX specific */ ++ ++ /* ++ * Set the following to trigger preemption when calling ++ * scx_bpf_dispatch() with a local dsq as the target. The slice of the ++ * current task is cleared to zero and the CPU is kicked into the ++ * scheduling path. Implies %SCX_ENQ_HEAD. ++ */ ++ SCX_ENQ_PREEMPT = 1LLU << 32, ++ ++ /* ++ * The task being enqueued was previously enqueued on the current CPU's ++ * %SCX_DSQ_LOCAL, but was removed from it in a call to the ++ * bpf_scx_reenqueue_local() kfunc. If bpf_scx_reenqueue_local() was ++ * invoked in a ->cpu_release() callback, and the task is again ++ * dispatched back to %SCX_DSQ_LOCAL by this current ->enqueue(), the ++ * task will not be scheduled on the CPU until at least the next invocation ++ * of the ->cpu_acquire() callback. ++ */ ++ SCX_ENQ_REENQ = 1LLU << 40, ++ ++ /* ++ * The task being enqueued is the only task available for the cpu. By ++ * default, ext core keeps executing such tasks but when ++ * %SCX_OPS_ENQ_LAST is specified, they're ops.enqueue()'d with the ++ * %SCX_ENQ_LAST flag set. ++ * ++ * If the BPF scheduler wants to continue executing the task, ++ * ops.enqueue() should dispatch the task to %SCX_DSQ_LOCAL immediately. ++ * If the task gets queued on a different dsq or the BPF side, the BPF ++ * scheduler is responsible for triggering a follow-up scheduling event.
++ * Otherwise, execution may stall. ++ */ ++ SCX_ENQ_LAST = 1LLU << 41, ++ ++ /* high 8 bits are internal */ ++ __SCX_ENQ_INTERNAL_MASK = 0xffLLU << 56, ++ ++ SCX_ENQ_CLEAR_OPSS = 1LLU << 56, ++ SCX_ENQ_DSQ_PRIQ = 1LLU << 57, ++}; ++ ++enum scx_deq_flags { ++ /* expose select DEQUEUE_* flags as enums */ ++ SCX_DEQ_SLEEP = DEQUEUE_SLEEP, ++ ++ /* high 32bits are SCX specific */ ++ ++ /* ++ * The generic core-sched layer decided to execute the task even though ++ * it hasn't been dispatched yet. Dequeue from the BPF side. ++ */ ++ SCX_DEQ_CORE_SCHED_EXEC = 1LLU << 32, ++}; ++ ++enum scx_pick_idle_cpu_flags { ++ SCX_PICK_IDLE_CORE = 1LLU << 0, /* pick a CPU whose SMT siblings are also idle */ ++}; ++ ++enum scx_kick_flags { ++ /* ++ * Kick the target CPU if idle. Guarantees that the target CPU goes ++ * through at least one full scheduling cycle before going idle. If the ++ * target CPU can be determined to be currently not idle and going to go ++ * through a scheduling cycle before going idle, noop. ++ */ ++ SCX_KICK_IDLE = 1LLU << 0, ++ ++ /* ++ * Preempt the current task and execute the dispatch path. If the ++ * current task of the target CPU is an SCX task, its ->scx.slice is ++ * cleared to zero before the scheduling path is invoked so that the ++ * task expires and the dispatch path is invoked. ++ */ ++ SCX_KICK_PREEMPT = 1LLU << 1, ++ ++ /* ++ * Wait for the CPU to be rescheduled. The scx_bpf_kick_cpu() call will ++ * return after the target CPU finishes picking the next task.
++ */ ++ SCX_KICK_WAIT = 1LLU << 2, ++}; ++ ++enum scx_ops_enable_state { ++ SCX_OPS_PREPPING, ++ SCX_OPS_ENABLING, ++ SCX_OPS_ENABLED, ++ SCX_OPS_DISABLING, ++ SCX_OPS_DISABLED, ++}; ++ ++static const char *scx_ops_enable_state_str[] = { ++ [SCX_OPS_PREPPING] = "prepping", ++ [SCX_OPS_ENABLING] = "enabling", ++ [SCX_OPS_ENABLED] = "enabled", ++ [SCX_OPS_DISABLING] = "disabling", ++ [SCX_OPS_DISABLED] = "disabled", ++}; ++ ++/* ++ * sched_ext_entity->ops_state ++ * ++ * Used to track the task ownership between the SCX core and the BPF scheduler. ++ * State transitions look as follows: ++ * ++ * NONE -> QUEUEING -> QUEUED -> DISPATCHING ++ * ^ | | ++ * | v v ++ * \-------------------------------/ ++ * ++ * QUEUEING and DISPATCHING states can be waited upon. See wait_ops_state() call ++ * sites for explanations on the conditions being waited upon and why they are ++ * safe. Transitions out of them into NONE or QUEUED must store_release and the ++ * waiters should load_acquire. ++ * ++ * Tracking scx_ops_state enables sched_ext core to reliably determine whether ++ * any given task can be dispatched by the BPF scheduler at all times and thus ++ * relaxes the requirements on the BPF scheduler. This allows the BPF scheduler ++ * to try to dispatch any task anytime regardless of its state as the SCX core ++ * can safely reject invalid dispatches. ++ */ ++enum scx_ops_state { ++ SCX_OPSS_NONE, /* owned by the SCX core */ ++ SCX_OPSS_QUEUEING, /* in transit to the BPF scheduler */ ++ SCX_OPSS_QUEUED, /* owned by the BPF scheduler */ ++ SCX_OPSS_DISPATCHING, /* in transit back to the SCX core */ ++ ++ /* ++ * QSEQ brands each QUEUED instance so that, when dispatch races ++ * dequeue/requeue, the dispatcher can tell whether it still has a claim ++ * on the task being dispatched. ++ * ++ * As some 32bit archs can't do 64bit store_release/load_acquire, ++ * p->scx.ops_state is atomic_long_t which leaves 30 bits for QSEQ on ++ * 32bit machines. 
The dispatch race window QSEQ protects is very narrow ++ * and runs with IRQ disabled. 30 bits should be sufficient. ++ */ ++ SCX_OPSS_QSEQ_SHIFT = 2, ++}; ++ ++/* Use macros to ensure that the type is unsigned long for the masks */ ++#define SCX_OPSS_STATE_MASK ((1LU << SCX_OPSS_QSEQ_SHIFT) - 1) ++#define SCX_OPSS_QSEQ_MASK (~SCX_OPSS_STATE_MASK) ++ ++/* ++ * During exit, a task may schedule after losing its PIDs. When disabling the ++ * BPF scheduler, we need to be able to iterate tasks in every state to ++ * guarantee system safety. Maintain a dedicated task list which contains every ++ * task between its fork and eventual free. ++ */ ++static DEFINE_SPINLOCK(scx_tasks_lock); ++static LIST_HEAD(scx_tasks); ++ ++/* ops enable/disable */ ++static struct kthread_worker *scx_ops_helper; ++static DEFINE_MUTEX(scx_ops_enable_mutex); ++DEFINE_STATIC_KEY_FALSE(__scx_ops_enabled); ++DEFINE_STATIC_PERCPU_RWSEM(scx_fork_rwsem); ++static atomic_t scx_ops_enable_state_var = ATOMIC_INIT(SCX_OPS_DISABLED); ++static atomic_t scx_ops_bypass_depth = ATOMIC_INIT(0); ++static bool scx_switching_all; ++DEFINE_STATIC_KEY_FALSE(__scx_switched_all); ++ ++static struct sched_ext_ops scx_ops; ++static bool scx_warned_zero_slice; ++ ++static DEFINE_STATIC_KEY_FALSE(scx_ops_enq_last); ++static DEFINE_STATIC_KEY_FALSE(scx_ops_enq_exiting); ++static DEFINE_STATIC_KEY_FALSE(scx_ops_cpu_preempt); ++static DEFINE_STATIC_KEY_FALSE(scx_builtin_idle_enabled); ++ ++struct static_key_false scx_has_op[SCX_OPI_END] = ++ { [0 ... SCX_OPI_END-1] = STATIC_KEY_FALSE_INIT }; ++ ++static atomic_t scx_exit_kind = ATOMIC_INIT(SCX_EXIT_DONE); ++static struct scx_exit_info *scx_exit_info; ++ ++static atomic_long_t scx_nr_rejected = ATOMIC_LONG_INIT(0); ++static atomic_long_t scx_hotplug_seq = ATOMIC_LONG_INIT(0); ++ ++/* ++ * The maximum amount of time in jiffies that a task may be runnable without ++ * being scheduled on a CPU. If this timeout is exceeded, it will trigger ++ * scx_ops_error(). 
++ */ ++static unsigned long scx_watchdog_timeout; ++ ++/* ++ * The last time the delayed work was run. This delayed work relies on ++ * ksoftirqd being able to run to service timer interrupts, so it's possible ++ * that this work itself could get wedged. To account for this, we check that ++ * it's not stalled in the timer tick, and trigger an error if it is. ++ */ ++static unsigned long scx_watchdog_timestamp = INITIAL_JIFFIES; ++ ++static struct delayed_work scx_watchdog_work; ++ ++/* idle tracking */ ++#ifdef CONFIG_SMP ++#ifdef CONFIG_CPUMASK_OFFSTACK ++#define CL_ALIGNED_IF_ONSTACK ++#else ++#define CL_ALIGNED_IF_ONSTACK __cacheline_aligned_in_smp ++#endif ++ ++static struct { ++ cpumask_var_t cpu; ++ cpumask_var_t smt; ++} idle_masks CL_ALIGNED_IF_ONSTACK; ++ ++#endif /* CONFIG_SMP */ ++ ++/* for %SCX_KICK_WAIT */ ++static unsigned long __percpu *scx_kick_cpus_pnt_seqs; ++ ++/* ++ * Direct dispatch marker. ++ * ++ * Non-NULL values are used for direct dispatch from enqueue path. A valid ++ * pointer points to the task currently being enqueued. An ERR_PTR value is used ++ * to indicate that direct dispatch has already happened. 
++ */ ++static DEFINE_PER_CPU(struct task_struct *, direct_dispatch_task); ++ ++/* dispatch queues */ ++static struct scx_dispatch_q __cacheline_aligned_in_smp scx_dsq_global; ++ ++static const struct rhashtable_params dsq_hash_params = { ++ .key_len = 8, ++ .key_offset = offsetof(struct scx_dispatch_q, id), ++ .head_offset = offsetof(struct scx_dispatch_q, hash_node), ++}; ++ ++static struct rhashtable dsq_hash; ++static LLIST_HEAD(dsqs_to_free); ++ ++/* dispatch buf */ ++struct scx_dsp_buf_ent { ++ struct task_struct *task; ++ unsigned long qseq; ++ u64 dsq_id; ++ u64 enq_flags; ++}; ++ ++static u32 scx_dsp_max_batch; ++ ++struct scx_dsp_ctx { ++ struct rq *rq; ++ u32 cursor; ++ u32 nr_tasks; ++ struct scx_dsp_buf_ent buf[]; ++}; ++ ++static struct scx_dsp_ctx __percpu *scx_dsp_ctx; ++ ++/* string formatting from BPF */ ++struct scx_bstr_buf { ++ u64 data[MAX_BPRINTF_VARARGS]; ++ char line[SCX_EXIT_MSG_LEN]; ++}; ++ ++static DEFINE_RAW_SPINLOCK(scx_exit_bstr_buf_lock); ++static struct scx_bstr_buf scx_exit_bstr_buf; ++ ++/* ops debug dump */ ++struct scx_dump_data { ++ s32 cpu; ++ bool first; ++ s32 cursor; ++ struct seq_buf *s; ++ const char *prefix; ++ struct scx_bstr_buf buf; ++}; ++ ++struct scx_dump_data scx_dump_data = { ++ .cpu = -1, ++}; ++ ++/* /sys/kernel/sched_ext interface */ ++static struct kset *scx_kset; ++static struct kobject *scx_root_kobj; ++ ++#define CREATE_TRACE_POINTS ++#include <trace/events/sched_ext.h> ++ ++static void process_ddsp_deferred_locals(struct rq *rq); ++static void scx_bpf_kick_cpu(s32 cpu, u64 flags); ++static __printf(3, 4) void scx_ops_exit_kind(enum scx_exit_kind kind, ++ s64 exit_code, ++ const char *fmt, ...); ++ ++#define scx_ops_error_kind(err, fmt, args...) \ ++ scx_ops_exit_kind((err), 0, fmt, ##args) ++ ++#define scx_ops_exit(code, fmt, args...) \ ++ scx_ops_exit_kind(SCX_EXIT_UNREG_KERN, (code), fmt, ##args) ++ ++#define scx_ops_error(fmt, args...)
\ ++ scx_ops_error_kind(SCX_EXIT_ERROR, fmt, ##args) ++ ++#define SCX_HAS_OP(op) static_branch_likely(&scx_has_op[SCX_OP_IDX(op)]) ++ ++static long jiffies_delta_msecs(unsigned long at, unsigned long now) ++{ ++ if (time_after(at, now)) ++ return jiffies_to_msecs(at - now); ++ else ++ return -(long)jiffies_to_msecs(now - at); ++} ++ ++/* if the highest set bit is N, return a mask with bits [N+1, 31] set */ ++static u32 higher_bits(u32 flags) ++{ ++ return ~((1 << fls(flags)) - 1); ++} ++ ++/* return the mask with only the highest bit set */ ++static u32 highest_bit(u32 flags) ++{ ++ int bit = fls(flags); ++ return ((u64)1 << bit) >> 1; ++} ++ ++static bool u32_before(u32 a, u32 b) ++{ ++ return (s32)(a - b) < 0; ++} ++ ++/* ++ * scx_kf_mask enforcement. Some kfuncs can only be called from specific SCX ++ * ops. When invoking SCX ops, SCX_CALL_OP[_RET]() should be used to indicate ++ * the allowed kfuncs and those kfuncs should use scx_kf_allowed() to check ++ * whether it's running from an allowed context. ++ * ++ * @mask is constant, always inline to cull the mask calculations. ++ */ ++static __always_inline void scx_kf_allow(u32 mask) ++{ ++ /* nesting is allowed only in increasing scx_kf_mask order */ ++ WARN_ONCE((mask | higher_bits(mask)) & current->scx.kf_mask, ++ "invalid nesting current->scx.kf_mask=0x%x mask=0x%x\n", ++ current->scx.kf_mask, mask); ++ current->scx.kf_mask |= mask; ++ barrier(); ++} ++ ++static void scx_kf_disallow(u32 mask) ++{ ++ barrier(); ++ current->scx.kf_mask &= ~mask; ++} ++ ++#define SCX_CALL_OP(mask, op, args...) \ ++do { \ ++ if (mask) { \ ++ scx_kf_allow(mask); \ ++ scx_ops.op(args); \ ++ scx_kf_disallow(mask); \ ++ } else { \ ++ scx_ops.op(args); \ ++ } \ ++} while (0) ++ ++#define SCX_CALL_OP_RET(mask, op, args...) 
\ ++({ \ ++ __typeof__(scx_ops.op(args)) __ret; \ ++ if (mask) { \ ++ scx_kf_allow(mask); \ ++ __ret = scx_ops.op(args); \ ++ scx_kf_disallow(mask); \ ++ } else { \ ++ __ret = scx_ops.op(args); \ ++ } \ ++ __ret; \ ++}) ++ ++/* ++ * Some kfuncs are allowed only on the tasks that are subjects of the ++ * in-progress scx_ops operation for, e.g., locking guarantees. To enforce such ++ * restrictions, the following SCX_CALL_OP_*() variants should be used when ++ * invoking scx_ops operations that take task arguments. These can only be used ++ * for non-nesting operations due to the way the tasks are tracked. ++ * ++ * kfuncs which can only operate on such tasks can in turn use ++ * scx_kf_allowed_on_arg_tasks() to test whether the invocation is allowed on ++ * the specific task. ++ */ ++#define SCX_CALL_OP_TASK(mask, op, task, args...) \ ++do { \ ++ BUILD_BUG_ON((mask) & ~__SCX_KF_TERMINAL); \ ++ current->scx.kf_tasks[0] = task; \ ++ SCX_CALL_OP(mask, op, task, ##args); \ ++ current->scx.kf_tasks[0] = NULL; \ ++} while (0) ++ ++#define SCX_CALL_OP_TASK_RET(mask, op, task, args...) \ ++({ \ ++ __typeof__(scx_ops.op(task, ##args)) __ret; \ ++ BUILD_BUG_ON((mask) & ~__SCX_KF_TERMINAL); \ ++ current->scx.kf_tasks[0] = task; \ ++ __ret = SCX_CALL_OP_RET(mask, op, task, ##args); \ ++ current->scx.kf_tasks[0] = NULL; \ ++ __ret; \ ++}) ++ ++#define SCX_CALL_OP_2TASKS_RET(mask, op, task0, task1, args...) 
\ ++({ \ ++ __typeof__(scx_ops.op(task0, task1, ##args)) __ret; \ ++ BUILD_BUG_ON((mask) & ~__SCX_KF_TERMINAL); \ ++ current->scx.kf_tasks[0] = task0; \ ++ current->scx.kf_tasks[1] = task1; \ ++ __ret = SCX_CALL_OP_RET(mask, op, task0, task1, ##args); \ ++ current->scx.kf_tasks[0] = NULL; \ ++ current->scx.kf_tasks[1] = NULL; \ ++ __ret; \ ++}) ++ ++/* @mask is constant, always inline to cull unnecessary branches */ ++static __always_inline bool scx_kf_allowed(u32 mask) ++{ ++ if (unlikely(!(current->scx.kf_mask & mask))) { ++ scx_ops_error("kfunc with mask 0x%x called from an operation only allowing 0x%x", ++ mask, current->scx.kf_mask); ++ return false; ++ } ++ ++ /* ++ * Enforce nesting boundaries. e.g. A kfunc which can be called from ++ * DISPATCH must not be called if we're running DEQUEUE which is nested ++ * inside ops.dispatch(). We don't need to check boundaries for any ++ * blocking kfuncs as the verifier ensures they're only called from ++ * sleepable progs. ++ */ ++ if (unlikely(highest_bit(mask) == SCX_KF_CPU_RELEASE && ++ (current->scx.kf_mask & higher_bits(SCX_KF_CPU_RELEASE)))) { ++ scx_ops_error("cpu_release kfunc called from a nested operation"); ++ return false; ++ } ++ ++ if (unlikely(highest_bit(mask) == SCX_KF_DISPATCH && ++ (current->scx.kf_mask & higher_bits(SCX_KF_DISPATCH)))) { ++ scx_ops_error("dispatch kfunc called from a nested operation"); ++ return false; ++ } ++ ++ return true; ++} ++ ++/* see SCX_CALL_OP_TASK() */ ++static __always_inline bool scx_kf_allowed_on_arg_tasks(u32 mask, ++ struct task_struct *p) ++{ ++ if (!scx_kf_allowed(mask)) ++ return false; ++ ++ if (unlikely((p != current->scx.kf_tasks[0] && ++ p != current->scx.kf_tasks[1]))) { ++ scx_ops_error("called on a task not being operated on"); ++ return false; ++ } ++ ++ return true; ++} ++ ++/** ++ * nldsq_next_task - Iterate to the next task in a non-local DSQ ++ * @dsq: user dsq being iterated ++ * @cur: current position, %NULL to start iteration ++ * @rev: walk
backwards ++ * ++ * Returns %NULL when iteration is finished. ++ */ ++static struct task_struct *nldsq_next_task(struct scx_dispatch_q *dsq, ++ struct task_struct *cur, bool rev) ++{ ++ struct list_head *list_node; ++ struct scx_dsq_list_node *dsq_lnode; ++ ++ lockdep_assert_held(&dsq->lock); ++ ++ if (cur) ++ list_node = &cur->scx.dsq_list.node; ++ else ++ list_node = &dsq->list; ++ ++ /* find the next task, need to skip BPF iteration cursors */ ++ do { ++ if (rev) ++ list_node = list_node->prev; ++ else ++ list_node = list_node->next; ++ ++ if (list_node == &dsq->list) ++ return NULL; ++ ++ dsq_lnode = container_of(list_node, struct scx_dsq_list_node, ++ node); ++ } while (dsq_lnode->is_bpf_iter_cursor); ++ ++ return container_of(dsq_lnode, struct task_struct, scx.dsq_list); ++} ++ ++#define nldsq_for_each_task(p, dsq) \ ++ for ((p) = nldsq_next_task((dsq), NULL, false); (p); \ ++ (p) = nldsq_next_task((dsq), (p), false)) ++ ++ ++/* ++ * BPF DSQ iterator. Tasks in a non-local DSQ can be iterated in [reverse] ++ * dispatch order. BPF-visible iterator is opaque and larger to allow future ++ * changes without breaking backward compatibility. Can be used with ++ * bpf_for_each(). See bpf_iter_scx_dsq_*(). ++ */ ++enum scx_dsq_iter_flags { ++ /* iterate in the reverse dispatch order */ ++ SCX_DSQ_ITER_REV = 1U << 0, ++ ++ __SCX_DSQ_ITER_ALL_FLAGS = SCX_DSQ_ITER_REV, ++}; ++ ++struct bpf_iter_scx_dsq_kern { ++ struct scx_dsq_list_node cursor; ++ struct scx_dispatch_q *dsq; ++ u32 dsq_seq; ++ u32 flags; ++} __attribute__((aligned(8))); ++ ++struct bpf_iter_scx_dsq { ++ u64 __opaque[6]; ++} __attribute__((aligned(8))); ++ ++ ++/* ++ * SCX task iterator. ++ */ ++struct scx_task_iter { ++ struct sched_ext_entity cursor; ++ struct task_struct *locked; ++ struct rq *rq; ++ struct rq_flags rf; ++}; ++ ++/** ++ * scx_task_iter_init - Initialize a task iterator ++ * @iter: iterator to init ++ * ++ * Initialize @iter. Must be called with scx_tasks_lock held. 
Once initialized, ++ * @iter must eventually be exited with scx_task_iter_exit(). ++ * ++ * scx_tasks_lock may be released between this and the first next() call or ++ * between any two next() calls. If scx_tasks_lock is released between two ++ * next() calls, the caller is responsible for ensuring that the task being ++ * iterated remains accessible either through RCU read lock or obtaining a ++ * reference count. ++ * ++ * All tasks which existed when the iteration started are guaranteed to be ++ * visited as long as they still exist. ++ */ ++static void scx_task_iter_init(struct scx_task_iter *iter) ++{ ++ lockdep_assert_held(&scx_tasks_lock); ++ ++ iter->cursor = (struct sched_ext_entity){ .flags = SCX_TASK_CURSOR }; ++ list_add(&iter->cursor.tasks_node, &scx_tasks); ++ iter->locked = NULL; ++} ++ ++/** ++ * scx_task_iter_rq_unlock - Unlock rq locked by a task iterator ++ * @iter: iterator to unlock rq for ++ * ++ * If @iter is in the middle of a locked iteration, it may be locking the rq of ++ * the task currently being visited. Unlock the rq if so. This function can be ++ * safely called anytime during an iteration. ++ * ++ * Returns %true if the rq @iter was locking is unlocked. %false if @iter was ++ * not locking an rq. ++ */ ++static bool scx_task_iter_rq_unlock(struct scx_task_iter *iter) ++{ ++ if (iter->locked) { ++ task_rq_unlock(iter->rq, iter->locked, &iter->rf); ++ iter->locked = NULL; ++ return true; ++ } else { ++ return false; ++ } ++} ++ ++/** ++ * scx_task_iter_exit - Exit a task iterator ++ * @iter: iterator to exit ++ * ++ * Exit a previously initialized @iter. Must be called with scx_tasks_lock held. ++ * If the iterator holds a task's rq lock, that rq lock is released. See ++ * scx_task_iter_init() for details. 
++ */ ++static void scx_task_iter_exit(struct scx_task_iter *iter) ++{ ++ lockdep_assert_held(&scx_tasks_lock); ++ ++ scx_task_iter_rq_unlock(iter); ++ list_del_init(&iter->cursor.tasks_node); ++} ++ ++/** ++ * scx_task_iter_next - Next task ++ * @iter: iterator to walk ++ * ++ * Visit the next task. See scx_task_iter_init() for details. ++ */ ++static struct task_struct *scx_task_iter_next(struct scx_task_iter *iter) ++{ ++ struct list_head *cursor = &iter->cursor.tasks_node; ++ struct sched_ext_entity *pos; ++ ++ lockdep_assert_held(&scx_tasks_lock); ++ ++ list_for_each_entry(pos, cursor, tasks_node) { ++ if (&pos->tasks_node == &scx_tasks) ++ return NULL; ++ if (!(pos->flags & SCX_TASK_CURSOR)) { ++ list_move(cursor, &pos->tasks_node); ++ return container_of(pos, struct task_struct, scx); ++ } ++ } ++ ++ /* can't happen, should always terminate at scx_tasks above */ ++ BUG(); ++} ++ ++/** ++ * scx_task_iter_next_locked - Next non-idle task with its rq locked ++ * @iter: iterator to walk ++ * @include_dead: Whether we should include dead tasks in the iteration ++ * ++ * Visit the non-idle task with its rq lock held. Allows callers to specify ++ * whether they would like to filter out dead tasks. See scx_task_iter_init() ++ * for details. ++ */ ++static struct task_struct * ++scx_task_iter_next_locked(struct scx_task_iter *iter, bool include_dead) ++{ ++ struct task_struct *p; ++retry: ++ scx_task_iter_rq_unlock(iter); ++ ++ while ((p = scx_task_iter_next(iter))) { ++ /* ++ * is_idle_task() tests %PF_IDLE which may not be set for CPUs ++ * which haven't yet been onlined. Test sched_class directly. ++ */ ++ if (p->sched_class != &idle_sched_class) ++ break; ++ } ++ if (!p) ++ return NULL; ++ ++ iter->rq = task_rq_lock(p, &iter->rf); ++ iter->locked = p; ++ ++ /* ++ * If we see %TASK_DEAD, @p already disabled preemption, is about to do ++ * the final __schedule(), won't ever need to be scheduled again and can ++ * thus be safely ignored. 
If we don't see %TASK_DEAD, @p can't enter ++ * the final __schedule() while we're locking its rq and thus will stay ++ * alive until the rq is unlocked. ++ */ ++ if (!include_dead && READ_ONCE(p->__state) == TASK_DEAD) ++ goto retry; ++ ++ return p; ++} ++ ++static enum scx_ops_enable_state scx_ops_enable_state(void) ++{ ++ return atomic_read(&scx_ops_enable_state_var); ++} ++ ++static enum scx_ops_enable_state ++scx_ops_set_enable_state(enum scx_ops_enable_state to) ++{ ++ return atomic_xchg(&scx_ops_enable_state_var, to); ++} ++ ++static bool scx_ops_tryset_enable_state(enum scx_ops_enable_state to, ++ enum scx_ops_enable_state from) ++{ ++ int from_v = from; ++ ++ return atomic_try_cmpxchg(&scx_ops_enable_state_var, &from_v, to); ++} ++ ++static bool scx_ops_bypassing(void) ++{ ++ return unlikely(atomic_read(&scx_ops_bypass_depth)); ++} ++ ++/** ++ * wait_ops_state - Busy-wait the specified ops state to end ++ * @p: target task ++ * @opss: state to wait the end of ++ * ++ * Busy-wait for @p to transition out of @opss. This can only be used when the ++ * state part of @opss is %SCX_OPSS_QUEUEING or %SCX_OPSS_DISPATCHING. This function also ++ * has load_acquire semantics to ensure that the caller can see the updates made ++ * in the enqueueing and dispatching paths. ++ */ ++static void wait_ops_state(struct task_struct *p, unsigned long opss) ++{ ++ do { ++ cpu_relax(); ++ } while (atomic_long_read_acquire(&p->scx.ops_state) == opss); ++} ++ ++/** ++ * ops_cpu_valid - Verify a cpu number ++ * @cpu: cpu number which came from a BPF ops ++ * @where: extra information reported on error ++ * ++ * @cpu is a cpu number which came from the BPF scheduler and can be any value. ++ * Verify that it is in range and one of the possible cpus. If invalid, trigger ++ * an ops error.
++ */ ++static bool ops_cpu_valid(s32 cpu, const char *where) ++{ ++ if (likely(cpu >= 0 && cpu < nr_cpu_ids && cpu_possible(cpu))) { ++ return true; ++ } else { ++ scx_ops_error("invalid CPU %d%s%s", cpu, ++ where ? " " : "", where ?: ""); ++ return false; ++ } ++} ++ ++/** ++ * ops_sanitize_err - Sanitize a -errno value ++ * @ops_name: operation to blame on failure ++ * @err: -errno value to sanitize ++ * ++ * Verify @err is a valid -errno. If not, trigger scx_ops_error() and return ++ * -%EPROTO. This is necessary because returning a rogue -errno up the chain can ++ * cause misbehaviors. For an example, a large negative return from ++ * ops.init_task() triggers an oops when passed up the call chain because the ++ * value fails IS_ERR() test after being encoded with ERR_PTR() and then is ++ * handled as a pointer. ++ */ ++static int ops_sanitize_err(const char *ops_name, s32 err) ++{ ++ if (err < 0 && err >= -MAX_ERRNO) ++ return err; ++ ++ scx_ops_error("ops.%s() returned an invalid errno %d", ops_name, err); ++ return -EPROTO; ++} ++ ++static void run_deferred(struct rq *rq) ++{ ++ process_ddsp_deferred_locals(rq); ++} ++ ++#ifdef CONFIG_SMP ++static void deferred_bal_cb_workfn(struct rq *rq) ++{ ++ run_deferred(rq); ++} ++#endif ++ ++static void deferred_irq_workfn(struct irq_work *irq_work) ++{ ++ struct rq *rq = container_of(irq_work, struct rq, scx.deferred_irq_work); ++ ++ raw_spin_rq_lock(rq); ++ run_deferred(rq); ++ raw_spin_rq_unlock(rq); ++} ++ ++/** ++ * schedule_deferred - Schedule execution of deferred actions on an rq ++ * @rq: target rq ++ * ++ * Schedule execution of deferred actions on @rq. Must be called with @rq ++ * locked. Deferred actions are executed with @rq locked but unpinned, and thus ++ * can unlock @rq to e.g. migrate tasks to other rqs. 
++ */ ++static void schedule_deferred(struct rq *rq) ++{ ++ lockdep_assert_rq_held(rq); ++ ++#ifdef CONFIG_SMP ++ /* ++ * If in the middle of waking up a task, task_woken_scx() will be called ++ * afterwards which will then run the deferred actions, no need to ++ * schedule anything. ++ */ ++ if (rq->scx.flags & SCX_RQ_IN_WAKEUP) ++ return; ++ ++ /* ++ * If in balance, the balance callbacks will be called before rq lock is ++ * released. Schedule one. ++ */ ++ if (rq->scx.flags & SCX_RQ_IN_BALANCE) { ++ queue_balance_callback(rq, &rq->scx.deferred_bal_cb, ++ deferred_bal_cb_workfn); ++ return; ++ } ++#endif ++ /* ++ * No scheduler hooks available. Queue an irq work. They are executed on ++ * IRQ re-enable which may take a bit longer than the scheduler hooks. ++ * The above WAKEUP and BALANCE paths should cover most of the cases and ++ * the time to IRQ re-enable shouldn't be long. ++ */ ++ irq_work_queue(&rq->scx.deferred_irq_work); ++} ++ ++/** ++ * touch_core_sched - Update timestamp used for core-sched task ordering ++ * @rq: rq to read clock from, must be locked ++ * @p: task to update the timestamp for ++ * ++ * Update @p->scx.core_sched_at timestamp. This is used by scx_prio_less() to ++ * implement global or local-DSQ FIFO ordering for core-sched. Should be called ++ * when a task becomes runnable and its turn on the CPU ends (e.g. slice ++ * exhaustion). ++ */ ++static void touch_core_sched(struct rq *rq, struct task_struct *p) ++{ ++#ifdef CONFIG_SCHED_CORE ++ /* ++ * It's okay to update the timestamp spuriously. Use ++ * sched_core_disabled() which is cheaper than enabled(). 
++ */ ++ if (!sched_core_disabled()) ++ p->scx.core_sched_at = rq_clock_task(rq); ++#endif ++} ++ ++/** ++ * touch_core_sched_dispatch - Update core-sched timestamp on dispatch ++ * @rq: rq to read clock from, must be locked ++ * @p: task being dispatched ++ * ++ * If the BPF scheduler implements custom core-sched ordering via ++ * ops.core_sched_before(), @p->scx.core_sched_at is used to implement FIFO ++ * ordering within each local DSQ. This function is called from dispatch paths ++ * and updates @p->scx.core_sched_at if custom core-sched ordering is in effect. ++ */ ++static void touch_core_sched_dispatch(struct rq *rq, struct task_struct *p) ++{ ++ lockdep_assert_rq_held(rq); ++ assert_clock_updated(rq); ++ ++#ifdef CONFIG_SCHED_CORE ++ if (SCX_HAS_OP(core_sched_before)) ++ touch_core_sched(rq, p); ++#endif ++} ++ ++static void update_curr_scx(struct rq *rq) ++{ ++ struct task_struct *curr = rq->curr; ++ u64 now = rq_clock_task(rq); ++ u64 delta_exec; ++ ++ if (time_before_eq64(now, curr->se.exec_start)) ++ return; ++ ++ delta_exec = now - curr->se.exec_start; ++ curr->se.exec_start = now; ++ curr->se.sum_exec_runtime += delta_exec; ++ account_group_exec_runtime(curr, delta_exec); ++ cgroup_account_cputime(curr, delta_exec); ++ ++ if (curr->scx.slice != SCX_SLICE_INF) { ++ curr->scx.slice -= min(curr->scx.slice, delta_exec); ++ if (!curr->scx.slice) ++ touch_core_sched(rq, curr); ++ } ++} ++ ++static bool scx_dsq_priq_less(struct rb_node *node_a, ++ const struct rb_node *node_b) ++{ ++ const struct task_struct *a = ++ container_of(node_a, struct task_struct, scx.dsq_priq); ++ const struct task_struct *b = ++ container_of(node_b, struct task_struct, scx.dsq_priq); ++ ++ return time_before64(a->scx.dsq_vtime, b->scx.dsq_vtime); ++} ++ ++static void dsq_mod_nr(struct scx_dispatch_q *dsq, s32 delta) ++{ ++ /* scx_bpf_dsq_nr_queued() reads ->nr without locking, use WRITE_ONCE() */ ++ WRITE_ONCE(dsq->nr, dsq->nr + delta); ++} ++ ++static void dispatch_enqueue(struct 
scx_dispatch_q *dsq, struct task_struct *p, ++ u64 enq_flags) ++{ ++ bool is_local = dsq->id == SCX_DSQ_LOCAL; ++ ++ WARN_ON_ONCE(p->scx.dsq || !list_empty(&p->scx.dsq_list.node)); ++ WARN_ON_ONCE((p->scx.dsq_flags & SCX_TASK_DSQ_ON_PRIQ) || ++ !RB_EMPTY_NODE(&p->scx.dsq_priq)); ++ ++ if (!is_local) { ++ raw_spin_lock(&dsq->lock); ++ if (unlikely(dsq->id == SCX_DSQ_INVALID)) { ++ scx_ops_error("attempting to dispatch to a destroyed dsq"); ++ /* fall back to the global dsq */ ++ raw_spin_unlock(&dsq->lock); ++ dsq = &scx_dsq_global; ++ raw_spin_lock(&dsq->lock); ++ } ++ } ++ ++ if (unlikely((dsq->id & SCX_DSQ_FLAG_BUILTIN) && ++ (enq_flags & SCX_ENQ_DSQ_PRIQ))) { ++ /* ++ * SCX_DSQ_LOCAL and SCX_DSQ_GLOBAL DSQs always consume from ++ * their FIFO queues. To avoid confusion and accidentally ++ * starving vtime-dispatched tasks by FIFO-dispatched tasks, we ++ * disallow any internal DSQ from doing vtime ordering of ++ * tasks. ++ */ ++ scx_ops_error("cannot use vtime ordering for built-in DSQs"); ++ enq_flags &= ~SCX_ENQ_DSQ_PRIQ; ++ } ++ ++ if (enq_flags & SCX_ENQ_DSQ_PRIQ) { ++ struct rb_node *rbp; ++ ++ /* ++ * A PRIQ DSQ shouldn't be using FIFO enqueueing. As tasks are ++ * linked to both the rbtree and list on PRIQs, this can only be ++ * tested easily when adding the first task. ++ */ ++ if (unlikely(RB_EMPTY_ROOT(&dsq->priq) && ++ nldsq_next_task(dsq, NULL, false))) ++ scx_ops_error("DSQ ID 0x%016llx already had FIFO-enqueued tasks", ++ dsq->id); ++ ++ p->scx.dsq_flags |= SCX_TASK_DSQ_ON_PRIQ; ++ rb_add(&p->scx.dsq_priq, &dsq->priq, scx_dsq_priq_less); ++ ++ /* ++ * Find the previous task and insert after it on the list so ++ * that @dsq->list is vtime ordered. 
++ */ ++ rbp = rb_prev(&p->scx.dsq_priq); ++ if (rbp) { ++ struct task_struct *prev = ++ container_of(rbp, struct task_struct, ++ scx.dsq_priq); ++ list_add(&p->scx.dsq_list.node, &prev->scx.dsq_list.node); ++ } else { ++ list_add(&p->scx.dsq_list.node, &dsq->list); ++ } ++ } else { ++ /* a FIFO DSQ shouldn't be using PRIQ enqueuing */ ++ if (unlikely(!RB_EMPTY_ROOT(&dsq->priq))) ++ scx_ops_error("DSQ ID 0x%016llx already had PRIQ-enqueued tasks", ++ dsq->id); ++ ++ if (enq_flags & (SCX_ENQ_HEAD | SCX_ENQ_PREEMPT)) ++ list_add(&p->scx.dsq_list.node, &dsq->list); ++ else ++ list_add_tail(&p->scx.dsq_list.node, &dsq->list); ++ } ++ ++ /* seq records the order tasks are queued, used by BPF DSQ iterator */ ++ dsq->seq++; ++ p->scx.dsq_seq = dsq->seq; ++ ++ dsq_mod_nr(dsq, 1); ++ p->scx.dsq = dsq; ++ ++ /* ++ * scx.ddsp_dsq_id and scx.ddsp_enq_flags are only relevant on the ++ * direct dispatch path, but we clear them here because the direct ++ * dispatch verdict may be overridden on the enqueue path during e.g. ++ * bypass. ++ */ ++ p->scx.ddsp_dsq_id = SCX_DSQ_INVALID; ++ p->scx.ddsp_enq_flags = 0; ++ ++ /* ++ * We're transitioning out of QUEUEING or DISPATCHING. store_release to ++ * match waiters' load_acquire. 
++ */ ++ if (enq_flags & SCX_ENQ_CLEAR_OPSS) ++ atomic_long_set_release(&p->scx.ops_state, SCX_OPSS_NONE); ++ ++ if (is_local) { ++ struct rq *rq = container_of(dsq, struct rq, scx.local_dsq); ++ bool preempt = false; ++ ++ if ((enq_flags & SCX_ENQ_PREEMPT) && p != rq->curr && ++ rq->curr->sched_class == &ext_sched_class) { ++ rq->curr->scx.slice = 0; ++ preempt = true; ++ } ++ ++ if (preempt || sched_class_above(&ext_sched_class, ++ rq->curr->sched_class)) ++ resched_curr(rq); ++ } else { ++ raw_spin_unlock(&dsq->lock); ++ } ++} ++ ++static void task_unlink_from_dsq(struct task_struct *p, ++ struct scx_dispatch_q *dsq) ++{ ++ if (p->scx.dsq_flags & SCX_TASK_DSQ_ON_PRIQ) { ++ rb_erase(&p->scx.dsq_priq, &dsq->priq); ++ RB_CLEAR_NODE(&p->scx.dsq_priq); ++ p->scx.dsq_flags &= ~SCX_TASK_DSQ_ON_PRIQ; ++ } ++ ++ list_del_init(&p->scx.dsq_list.node); ++} ++ ++static void dispatch_dequeue(struct rq *rq, struct task_struct *p) ++{ ++ struct scx_dispatch_q *dsq = p->scx.dsq; ++ bool is_local = dsq == &rq->scx.local_dsq; ++ ++ if (!dsq) { ++ /* ++ * If !dsq && on-list, @p is on @rq's ddsp_deferred_locals. ++ * Unlinking is all that's needed to cancel. ++ */ ++ if (unlikely(!list_empty(&p->scx.dsq_list.node))) ++ list_del_init(&p->scx.dsq_list.node); ++ ++ /* ++ * When dispatching directly from the BPF scheduler to a local ++ * DSQ, the task isn't associated with any DSQ but ++ * @p->scx.holding_cpu may be set under the protection of ++ * %SCX_OPSS_DISPATCHING. ++ */ ++ if (p->scx.holding_cpu >= 0) ++ p->scx.holding_cpu = -1; ++ ++ return; ++ } ++ ++ if (!is_local) ++ raw_spin_lock(&dsq->lock); ++ ++ /* ++ * Now that we hold @dsq->lock, @p->holding_cpu and @p->scx.dsq_* can't ++ * change underneath us. 
++ */ ++ if (p->scx.holding_cpu < 0) { ++ /* @p must still be on @dsq, dequeue */ ++ WARN_ON_ONCE(list_empty(&p->scx.dsq_list.node)); ++ task_unlink_from_dsq(p, dsq); ++ dsq_mod_nr(dsq, -1); ++ } else { ++ /* ++ * We're racing against dispatch_to_local_dsq() which already ++ * removed @p from @dsq and set @p->scx.holding_cpu. Clear the ++ * holding_cpu which tells dispatch_to_local_dsq() that it lost ++ * the race. ++ */ ++ WARN_ON_ONCE(!list_empty(&p->scx.dsq_list.node)); ++ p->scx.holding_cpu = -1; ++ } ++ p->scx.dsq = NULL; ++ ++ if (!is_local) ++ raw_spin_unlock(&dsq->lock); ++} ++ ++static struct scx_dispatch_q *find_user_dsq(u64 dsq_id) ++{ ++ return rhashtable_lookup_fast(&dsq_hash, &dsq_id, dsq_hash_params); ++} ++ ++static struct scx_dispatch_q *find_non_local_dsq(u64 dsq_id) ++{ ++ lockdep_assert(rcu_read_lock_any_held()); ++ ++ if (dsq_id == SCX_DSQ_GLOBAL) ++ return &scx_dsq_global; ++ else ++ return find_user_dsq(dsq_id); ++} ++ ++static struct scx_dispatch_q *find_dsq_for_dispatch(struct rq *rq, u64 dsq_id, ++ struct task_struct *p) ++{ ++ struct scx_dispatch_q *dsq; ++ ++ if (dsq_id == SCX_DSQ_LOCAL) ++ return &rq->scx.local_dsq; ++ ++ dsq = find_non_local_dsq(dsq_id); ++ if (unlikely(!dsq)) { ++ scx_ops_error("non-existent DSQ 0x%llx for %s[%d]", ++ dsq_id, p->comm, p->pid); ++ return &scx_dsq_global; ++ } ++ ++ return dsq; ++} ++ ++static void mark_direct_dispatch(struct task_struct *ddsp_task, ++ struct task_struct *p, u64 dsq_id, ++ u64 enq_flags) ++{ ++ /* ++ * Mark that dispatch already happened from ops.select_cpu() or ++ * ops.enqueue() by spoiling direct_dispatch_task with a non-NULL value ++ * which can never match a valid task pointer. 
++ */ ++ __this_cpu_write(direct_dispatch_task, ERR_PTR(-ESRCH)); ++ ++ /* @p must match the task on the enqueue path */ ++ if (unlikely(p != ddsp_task)) { ++ if (IS_ERR(ddsp_task)) ++ scx_ops_error("%s[%d] already direct-dispatched", ++ p->comm, p->pid); ++ else ++ scx_ops_error("scheduling for %s[%d] but trying to direct-dispatch %s[%d]", ++ ddsp_task->comm, ddsp_task->pid, ++ p->comm, p->pid); ++ return; ++ } ++ ++ WARN_ON_ONCE(p->scx.ddsp_dsq_id != SCX_DSQ_INVALID); ++ WARN_ON_ONCE(p->scx.ddsp_enq_flags); ++ ++ p->scx.ddsp_dsq_id = dsq_id; ++ p->scx.ddsp_enq_flags = enq_flags; ++} ++ ++static void direct_dispatch(struct task_struct *p, u64 enq_flags) ++{ ++ struct rq *rq = task_rq(p); ++ struct scx_dispatch_q *dsq; ++ u64 dsq_id = p->scx.ddsp_dsq_id; ++ ++ touch_core_sched_dispatch(rq, p); ++ ++ p->scx.ddsp_enq_flags |= enq_flags; ++ ++ /* ++ * We are in the enqueue path with @rq locked and pinned, and thus can't ++ * double lock a remote rq and enqueue to its local DSQ. For ++ * DSQ_LOCAL_ON verdicts targeting the local DSQ of a remote CPU, defer ++ * the enqueue so that it's executed when @rq can be unlocked. ++ */ ++ if ((dsq_id & SCX_DSQ_LOCAL_ON) == SCX_DSQ_LOCAL_ON) { ++ s32 cpu = dsq_id & SCX_DSQ_LOCAL_CPU_MASK; ++ unsigned long opss; ++ ++ if (cpu == cpu_of(rq)) { ++ dsq_id = SCX_DSQ_LOCAL; ++ goto dispatch; ++ } ++ ++ opss = atomic_long_read(&p->scx.ops_state) & SCX_OPSS_STATE_MASK; ++ ++ switch (opss & SCX_OPSS_STATE_MASK) { ++ case SCX_OPSS_NONE: ++ break; ++ case SCX_OPSS_QUEUEING: ++ /* ++ * As @p was never passed to the BPF side, _release is ++ * not strictly necessary. Still do it for consistency. 
++ */ ++ atomic_long_set_release(&p->scx.ops_state, SCX_OPSS_NONE); ++ break; ++ default: ++ WARN_ONCE(true, "sched_ext: %s[%d] has invalid ops state 0x%lx in direct_dispatch()", ++ p->comm, p->pid, opss); ++ atomic_long_set_release(&p->scx.ops_state, SCX_OPSS_NONE); ++ break; ++ } ++ ++ WARN_ON_ONCE(p->scx.dsq || !list_empty(&p->scx.dsq_list.node)); ++ list_add_tail(&p->scx.dsq_list.node, ++ &rq->scx.ddsp_deferred_locals); ++ schedule_deferred(rq); ++ return; ++ } ++ ++dispatch: ++ dsq = find_dsq_for_dispatch(rq, dsq_id, p); ++ dispatch_enqueue(dsq, p, p->scx.ddsp_enq_flags | SCX_ENQ_CLEAR_OPSS); ++} ++ ++static bool scx_rq_online(struct rq *rq) ++{ ++ return likely(rq->scx.flags & SCX_RQ_ONLINE); ++} ++ ++static void do_enqueue_task(struct rq *rq, struct task_struct *p, u64 enq_flags, ++ int sticky_cpu) ++{ ++ struct task_struct **ddsp_taskp; ++ unsigned long qseq; ++ ++ WARN_ON_ONCE(!(p->scx.flags & SCX_TASK_QUEUED)); ++ ++ /* rq migration */ ++ if (sticky_cpu == cpu_of(rq)) ++ goto local_norefill; ++ ++ /* ++ * If !scx_rq_online(), we already told the BPF scheduler that the CPU ++ * is offline and are just running the hotplug path. Don't bother the ++ * BPF scheduler. 
++ */ ++ if (!scx_rq_online(rq)) ++ goto local; ++ ++ if (scx_ops_bypassing()) { ++ if (enq_flags & SCX_ENQ_LAST) ++ goto local; ++ else ++ goto global; ++ } ++ ++ if (p->scx.ddsp_dsq_id != SCX_DSQ_INVALID) ++ goto direct; ++ ++ /* see %SCX_OPS_ENQ_EXITING */ ++ if (!static_branch_unlikely(&scx_ops_enq_exiting) && ++ unlikely(p->flags & PF_EXITING)) ++ goto local; ++ ++ /* see %SCX_OPS_ENQ_LAST */ ++ if (!static_branch_unlikely(&scx_ops_enq_last) && ++ (enq_flags & SCX_ENQ_LAST)) ++ goto local; ++ ++ if (!SCX_HAS_OP(enqueue)) ++ goto global; ++ ++ /* DSQ bypass didn't trigger, enqueue on the BPF scheduler */ ++ qseq = rq->scx.ops_qseq++ << SCX_OPSS_QSEQ_SHIFT; ++ ++ WARN_ON_ONCE(atomic_long_read(&p->scx.ops_state) != SCX_OPSS_NONE); ++ atomic_long_set(&p->scx.ops_state, SCX_OPSS_QUEUEING | qseq); ++ ++ ddsp_taskp = this_cpu_ptr(&direct_dispatch_task); ++ WARN_ON_ONCE(*ddsp_taskp); ++ *ddsp_taskp = p; ++ ++ SCX_CALL_OP_TASK(SCX_KF_ENQUEUE, enqueue, p, enq_flags); ++ ++ *ddsp_taskp = NULL; ++ if (p->scx.ddsp_dsq_id != SCX_DSQ_INVALID) ++ goto direct; ++ ++ /* ++ * If not directly dispatched, QUEUEING isn't clear yet and dispatch or ++ * dequeue may be waiting. The store_release matches their load_acquire. ++ */ ++ atomic_long_set_release(&p->scx.ops_state, SCX_OPSS_QUEUED | qseq); ++ return; ++ ++direct: ++ direct_dispatch(p, enq_flags); ++ return; ++ ++local: ++ /* ++ * For task-ordering, slice refill must be treated as implying the end ++ * of the current slice. Otherwise, the longer @p stays on the CPU, the ++ * higher priority it becomes from scx_prio_less()'s POV. 
++ */ ++ touch_core_sched(rq, p); ++ p->scx.slice = SCX_SLICE_DFL; ++local_norefill: ++ dispatch_enqueue(&rq->scx.local_dsq, p, enq_flags); ++ return; ++ ++global: ++ touch_core_sched(rq, p); /* see the comment in local: */ ++ p->scx.slice = SCX_SLICE_DFL; ++ dispatch_enqueue(&scx_dsq_global, p, enq_flags); ++} ++ ++static bool task_runnable(const struct task_struct *p) ++{ ++ return !list_empty(&p->scx.runnable_node); ++} ++ ++static void set_task_runnable(struct rq *rq, struct task_struct *p) ++{ ++ lockdep_assert_rq_held(rq); ++ ++ if (p->scx.flags & SCX_TASK_RESET_RUNNABLE_AT) { ++ p->scx.runnable_at = jiffies; ++ p->scx.flags &= ~SCX_TASK_RESET_RUNNABLE_AT; ++ } ++ ++ /* ++ * list_add_tail() must be used. scx_ops_bypass() depends on tasks being ++ * appended to the runnable_list. ++ */ ++ list_add_tail(&p->scx.runnable_node, &rq->scx.runnable_list); ++} ++ ++static void clr_task_runnable(struct task_struct *p, bool reset_runnable_at) ++{ ++ list_del_init(&p->scx.runnable_node); ++ if (reset_runnable_at) ++ p->scx.flags |= SCX_TASK_RESET_RUNNABLE_AT; ++} ++ ++static void enqueue_task_scx(struct rq *rq, struct task_struct *p, int enq_flags) ++{ ++ int sticky_cpu = p->scx.sticky_cpu; ++ ++ if (enq_flags & ENQUEUE_WAKEUP) ++ rq->scx.flags |= SCX_RQ_IN_WAKEUP; ++ ++ enq_flags |= rq->scx.extra_enq_flags; ++ ++ if (sticky_cpu >= 0) ++ p->scx.sticky_cpu = -1; ++ ++ /* ++ * Restoring a running task will be immediately followed by ++ * set_next_task_scx() which expects the task to not be on the BPF ++ * scheduler as tasks can only start running through local DSQs. Force ++ * direct-dispatch into the local DSQ by setting the sticky_cpu.
++ */ ++ if (unlikely(enq_flags & ENQUEUE_RESTORE) && task_current(rq, p)) ++ sticky_cpu = cpu_of(rq); ++ ++ if (p->scx.flags & SCX_TASK_QUEUED) { ++ WARN_ON_ONCE(!task_runnable(p)); ++ goto out; ++ } ++ ++ set_task_runnable(rq, p); ++ p->scx.flags |= SCX_TASK_QUEUED; ++ rq->scx.nr_running++; ++ add_nr_running(rq, 1); ++ ++ if (SCX_HAS_OP(runnable)) ++ SCX_CALL_OP_TASK(SCX_KF_REST, runnable, p, enq_flags); ++ ++ if (enq_flags & SCX_ENQ_WAKEUP) ++ touch_core_sched(rq, p); ++ ++ do_enqueue_task(rq, p, enq_flags, sticky_cpu); ++out: ++ rq->scx.flags &= ~SCX_RQ_IN_WAKEUP; ++} ++ ++static void ops_dequeue(struct task_struct *p, u64 deq_flags) ++{ ++ unsigned long opss; ++ ++ /* dequeue is always temporary, don't reset runnable_at */ ++ clr_task_runnable(p, false); ++ ++ /* acquire ensures that we see the preceding updates on QUEUED */ ++ opss = atomic_long_read_acquire(&p->scx.ops_state); ++ ++ switch (opss & SCX_OPSS_STATE_MASK) { ++ case SCX_OPSS_NONE: ++ break; ++ case SCX_OPSS_QUEUEING: ++ /* ++ * QUEUEING is started and finished while holding @p's rq lock. ++ * As we're holding the rq lock now, we shouldn't see QUEUEING. ++ */ ++ BUG(); ++ case SCX_OPSS_QUEUED: ++ if (SCX_HAS_OP(dequeue)) ++ SCX_CALL_OP_TASK(SCX_KF_REST, dequeue, p, deq_flags); ++ ++ if (atomic_long_try_cmpxchg(&p->scx.ops_state, &opss, ++ SCX_OPSS_NONE)) ++ break; ++ fallthrough; ++ case SCX_OPSS_DISPATCHING: ++ /* ++ * If @p is being dispatched from the BPF scheduler to a DSQ, ++ * wait for the transfer to complete so that @p doesn't get ++ * added to its DSQ after dequeueing is complete. ++ * ++ * As we're waiting on DISPATCHING with the rq locked, the ++ * dispatching side shouldn't try to lock the rq while ++ * DISPATCHING is set. See dispatch_to_local_dsq(). ++ * ++ * DISPATCHING shouldn't have qseq set and control can reach ++ * here with NONE @opss from the above QUEUED case block. ++ * Explicitly wait on %SCX_OPSS_DISPATCHING instead of @opss. 
++ */ ++ wait_ops_state(p, SCX_OPSS_DISPATCHING); ++ BUG_ON(atomic_long_read(&p->scx.ops_state) != SCX_OPSS_NONE); ++ break; ++ } ++} ++ ++static void dequeue_task_scx(struct rq *rq, struct task_struct *p, int deq_flags) ++{ ++ if (!(p->scx.flags & SCX_TASK_QUEUED)) { ++ WARN_ON_ONCE(task_runnable(p)); ++ return; ++ } ++ ++ ops_dequeue(p, deq_flags); ++ ++ /* ++ * A currently running task which is going off @rq first gets dequeued ++ * and then stops running. As we want running <-> stopping transitions ++ * to be contained within runnable <-> quiescent transitions, trigger ++ * ->stopping() early here instead of in put_prev_task_scx(). ++ * ++ * @p may go through multiple stopping <-> running transitions between ++ * here and put_prev_task_scx() if task attribute changes occur while ++ * balance_scx() leaves @rq unlocked. However, they don't contain any ++ * information meaningful to the BPF scheduler and can be suppressed by ++ * skipping the callbacks if the task is !QUEUED. ++ */ ++ if (SCX_HAS_OP(stopping) && task_current(rq, p)) { ++ update_curr_scx(rq); ++ SCX_CALL_OP_TASK(SCX_KF_REST, stopping, p, false); ++ } ++ ++ if (SCX_HAS_OP(quiescent)) ++ SCX_CALL_OP_TASK(SCX_KF_REST, quiescent, p, deq_flags); ++ ++ if (deq_flags & SCX_DEQ_SLEEP) ++ p->scx.flags |= SCX_TASK_DEQD_FOR_SLEEP; ++ else ++ p->scx.flags &= ~SCX_TASK_DEQD_FOR_SLEEP; ++ ++ p->scx.flags &= ~SCX_TASK_QUEUED; ++ rq->scx.nr_running--; ++ sub_nr_running(rq, 1); ++ ++ dispatch_dequeue(rq, p); ++} ++ ++static void yield_task_scx(struct rq *rq) ++{ ++ struct task_struct *p = rq->curr; ++ ++ if (SCX_HAS_OP(yield)) ++ SCX_CALL_OP_2TASKS_RET(SCX_KF_REST, yield, p, NULL); ++ else ++ p->scx.slice = 0; ++} ++ ++static bool yield_to_task_scx(struct rq *rq, struct task_struct *to) ++{ ++ struct task_struct *from = rq->curr; ++ ++ if (SCX_HAS_OP(yield)) ++ return SCX_CALL_OP_2TASKS_RET(SCX_KF_REST, yield, from, to); ++ else ++ return false; ++} ++ ++#ifdef CONFIG_SMP ++/** ++ * move_task_to_local_dsq - Move a 
task from a different rq to a local DSQ ++ * @rq: rq to move the task into, currently locked ++ * @p: task to move ++ * @enq_flags: %SCX_ENQ_* ++ * ++ * Move @p which is currently on a different rq to @rq's local DSQ. The caller ++ * must: ++ * ++ * 1. Start with exclusive access to @p either through its DSQ lock or ++ * %SCX_OPSS_DISPATCHING flag. ++ * ++ * 2. Set @p->scx.holding_cpu to raw_smp_processor_id(). ++ * ++ * 3. Remember task_rq(@p). Release the exclusive access so that we don't ++ * deadlock with dequeue. ++ * ++ * 4. Lock @rq and the task_rq from #3. ++ * ++ * 5. Call this function. ++ * ++ * Returns %true if @p was successfully moved. %false after racing dequeue and ++ * losing. ++ */ ++static bool move_task_to_local_dsq(struct rq *rq, struct task_struct *p, ++ u64 enq_flags) ++{ ++ struct rq *task_rq; ++ ++ lockdep_assert_rq_held(rq); ++ ++ /* ++ * If dequeue got to @p while we were trying to lock both rq's, it'd ++ * have cleared @p->scx.holding_cpu to -1. While other cpus may have ++ * updated it to different values afterwards, as this operation can't be ++ * preempted or recurse, @p->scx.holding_cpu can never become ++ * raw_smp_processor_id() again before we're done. Thus, we can tell ++ * whether we lost to dequeue by testing whether @p->scx.holding_cpu is ++ * still raw_smp_processor_id(). ++ * ++ * See dispatch_dequeue() for the counterpart. ++ */ ++ if (unlikely(p->scx.holding_cpu != raw_smp_processor_id())) ++ return false; ++ ++ /* @p->rq couldn't have changed if we're still the holding cpu */ ++ task_rq = task_rq(p); ++ lockdep_assert_rq_held(task_rq); ++ ++ WARN_ON_ONCE(!cpumask_test_cpu(cpu_of(rq), p->cpus_ptr)); ++ deactivate_task(task_rq, p, 0); ++ set_task_cpu(p, cpu_of(rq)); ++ p->scx.sticky_cpu = cpu_of(rq); ++ ++ /* ++ * We want to pass scx-specific enq_flags but activate_task() will ++ * truncate the upper 32 bit. As we own @rq, we can pass them through ++ * @rq->scx.extra_enq_flags instead. 
++ */ ++ WARN_ON_ONCE(rq->scx.extra_enq_flags); ++ rq->scx.extra_enq_flags = enq_flags; ++ activate_task(rq, p, 0); ++ rq->scx.extra_enq_flags = 0; ++ ++ return true; ++} ++ ++/** ++ * dispatch_to_local_dsq_lock - Ensure source and destination rq's are locked ++ * @rq: current rq which is locked ++ * @src_rq: rq to move task from ++ * @dst_rq: rq to move task to ++ * ++ * We're holding @rq lock and trying to dispatch a task from @src_rq to ++ * @dst_rq's local DSQ and thus need to lock both @src_rq and @dst_rq. Whether ++ * @rq stays locked isn't important as long as the state is restored after ++ * dispatch_to_local_dsq_unlock(). ++ */ ++static void dispatch_to_local_dsq_lock(struct rq *rq, struct rq *src_rq, ++ struct rq *dst_rq) ++{ ++ if (src_rq == dst_rq) { ++ raw_spin_rq_unlock(rq); ++ raw_spin_rq_lock(dst_rq); ++ } else if (rq == src_rq) { ++ double_lock_balance(rq, dst_rq); ++ } else if (rq == dst_rq) { ++ double_lock_balance(rq, src_rq); ++ } else { ++ raw_spin_rq_unlock(rq); ++ double_rq_lock(src_rq, dst_rq); ++ } ++} ++ ++/** ++ * dispatch_to_local_dsq_unlock - Undo dispatch_to_local_dsq_lock() ++ * @rq: current rq which is locked ++ * @src_rq: rq to move task from ++ * @dst_rq: rq to move task to ++ * ++ * Unlock @src_rq and @dst_rq and ensure that @rq is locked on return. 
++ */ ++static void dispatch_to_local_dsq_unlock(struct rq *rq, struct rq *src_rq, ++ struct rq *dst_rq) ++{ ++ if (src_rq == dst_rq) { ++ raw_spin_rq_unlock(dst_rq); ++ raw_spin_rq_lock(rq); ++ } else if (rq == src_rq) { ++ double_unlock_balance(rq, dst_rq); ++ } else if (rq == dst_rq) { ++ double_unlock_balance(rq, src_rq); ++ } else { ++ double_rq_unlock(src_rq, dst_rq); ++ raw_spin_rq_lock(rq); ++ } ++} ++#endif /* CONFIG_SMP */ ++ ++static void consume_local_task(struct rq *rq, struct scx_dispatch_q *dsq, ++ struct task_struct *p) ++{ ++ lockdep_assert_held(&dsq->lock); /* released on return */ ++ ++ /* @dsq is locked and @p is on this rq */ ++ WARN_ON_ONCE(p->scx.holding_cpu >= 0); ++ task_unlink_from_dsq(p, dsq); ++ list_add_tail(&p->scx.dsq_list.node, &rq->scx.local_dsq.list); ++ dsq_mod_nr(dsq, -1); ++ dsq_mod_nr(&rq->scx.local_dsq, 1); ++ p->scx.dsq = &rq->scx.local_dsq; ++ raw_spin_unlock(&dsq->lock); ++} ++ ++#ifdef CONFIG_SMP ++/* ++ * Similar to kernel/sched/core.c::is_cpu_allowed() but we're testing whether @p ++ * can be pulled to @rq. ++ */ ++static bool task_can_run_on_remote_rq(struct task_struct *p, struct rq *rq) ++{ ++ int cpu = cpu_of(rq); ++ ++ if (!cpumask_test_cpu(cpu, p->cpus_ptr)) ++ return false; ++ if (unlikely(is_migration_disabled(p))) ++ return false; ++ if (!(p->flags & PF_KTHREAD) && unlikely(!task_cpu_possible(cpu, p))) ++ return false; ++ if (!scx_rq_online(rq)) ++ return false; ++ return true; ++} ++ ++static bool consume_remote_task(struct rq *rq, struct scx_dispatch_q *dsq, ++ struct task_struct *p, struct rq *task_rq) ++{ ++ bool moved = false; ++ ++ lockdep_assert_held(&dsq->lock); /* released on return */ ++ ++ /* ++ * @dsq is locked and @p is on a remote rq. @p is currently protected by ++ * @dsq->lock. We want to pull @p to @rq but may deadlock if we grab ++ * @task_rq while holding @dsq and @rq locks. As dequeue can't drop the ++ * rq lock or fail, do a little dancing from our side. See ++ * move_task_to_local_dsq(). 
++ */ ++ WARN_ON_ONCE(p->scx.holding_cpu >= 0); ++ task_unlink_from_dsq(p, dsq); ++ dsq_mod_nr(dsq, -1); ++ p->scx.holding_cpu = raw_smp_processor_id(); ++ raw_spin_unlock(&dsq->lock); ++ ++ double_lock_balance(rq, task_rq); ++ ++ moved = move_task_to_local_dsq(rq, p, 0); ++ ++ double_unlock_balance(rq, task_rq); ++ ++ return moved; ++} ++#else /* CONFIG_SMP */ ++static bool task_can_run_on_remote_rq(struct task_struct *p, struct rq *rq) { return false; } ++static bool consume_remote_task(struct rq *rq, struct scx_dispatch_q *dsq, ++ struct task_struct *p, struct rq *task_rq) { return false; } ++#endif /* CONFIG_SMP */ ++ ++static bool consume_dispatch_q(struct rq *rq, struct scx_dispatch_q *dsq) ++{ ++ struct task_struct *p; ++retry: ++ /* ++ * The caller can't expect to successfully consume a task if the task's ++ * addition to @dsq isn't guaranteed to be visible somehow. Test ++ * @dsq->list without locking and skip if it seems empty. ++ */ ++ if (list_empty(&dsq->list)) ++ return false; ++ ++ raw_spin_lock(&dsq->lock); ++ ++ nldsq_for_each_task(p, dsq) { ++ struct rq *task_rq = task_rq(p); ++ ++ if (rq == task_rq) { ++ consume_local_task(rq, dsq, p); ++ return true; ++ } ++ ++ if (task_can_run_on_remote_rq(p, rq)) { ++ if (likely(consume_remote_task(rq, dsq, p, task_rq))) ++ return true; ++ goto retry; ++ } ++ } ++ ++ raw_spin_unlock(&dsq->lock); ++ return false; ++} ++ ++enum dispatch_to_local_dsq_ret { ++ DTL_DISPATCHED, /* successfully dispatched */ ++ DTL_LOST, /* lost race to dequeue */ ++ DTL_NOT_LOCAL, /* destination is not a local DSQ */ ++ DTL_INVALID, /* invalid local dsq_id */ ++}; ++ ++/** ++ * dispatch_to_local_dsq - Dispatch a task to a local dsq ++ * @rq: current rq which is locked ++ * @dsq_id: destination dsq ID ++ * @p: task to dispatch ++ * @enq_flags: %SCX_ENQ_* ++ * ++ * We're holding @rq lock and want to dispatch @p to the local DSQ identified by ++ * @dsq_id. 
This function performs all the synchronization dancing needed ++ * because local DSQs are protected with rq locks. ++ * ++ * The caller must have exclusive ownership of @p (e.g. through ++ * %SCX_OPSS_DISPATCHING). ++ */ ++static enum dispatch_to_local_dsq_ret ++dispatch_to_local_dsq(struct rq *rq, u64 dsq_id, struct task_struct *p, ++ u64 enq_flags) ++{ ++ struct rq *src_rq = task_rq(p); ++ struct rq *dst_rq; ++ ++ /* ++ * We're synchronized against dequeue through DISPATCHING. As @p can't ++ * be dequeued, its task_rq and cpus_allowed are stable too. ++ */ ++ if (dsq_id == SCX_DSQ_LOCAL) { ++ dst_rq = rq; ++ } else if ((dsq_id & SCX_DSQ_LOCAL_ON) == SCX_DSQ_LOCAL_ON) { ++ s32 cpu = dsq_id & SCX_DSQ_LOCAL_CPU_MASK; ++ ++ if (!ops_cpu_valid(cpu, "in SCX_DSQ_LOCAL_ON dispatch verdict")) ++ return DTL_INVALID; ++ dst_rq = cpu_rq(cpu); ++ } else { ++ return DTL_NOT_LOCAL; ++ } ++ ++ /* if dispatching to @rq that @p is already on, no lock dancing needed */ ++ if (rq == src_rq && rq == dst_rq) { ++ dispatch_enqueue(&dst_rq->scx.local_dsq, p, ++ enq_flags | SCX_ENQ_CLEAR_OPSS); ++ return DTL_DISPATCHED; ++ } ++ ++#ifdef CONFIG_SMP ++ if (cpumask_test_cpu(cpu_of(dst_rq), p->cpus_ptr)) { ++ struct rq *locked_dst_rq = dst_rq; ++ bool dsp; ++ ++ /* ++ * @p is on a possibly remote @src_rq which we need to lock to ++ * move the task. If dequeue is in progress, it'd be locking ++ * @src_rq and waiting on DISPATCHING, so we can't grab @src_rq ++ * lock while holding DISPATCHING. ++ * ++ * As DISPATCHING guarantees that @p is wholly ours, we can ++ * pretend that we're moving from a DSQ and use the same ++ * mechanism - mark the task under transfer with holding_cpu, ++ * release DISPATCHING and then follow the same protocol. 
++ */ ++ p->scx.holding_cpu = raw_smp_processor_id(); ++ ++ /* store_release ensures that dequeue sees the above */ ++ atomic_long_set_release(&p->scx.ops_state, SCX_OPSS_NONE); ++ ++ dispatch_to_local_dsq_lock(rq, src_rq, locked_dst_rq); ++ ++ /* ++ * We don't require the BPF scheduler to avoid dispatching to ++ * offline CPUs mostly for convenience but also because CPUs can ++ * go offline between scx_bpf_dispatch() calls and here. If @p ++ * is destined to an offline CPU, queue it on its current CPU ++ * instead, which should always be safe. As this is an allowed ++ * behavior, don't trigger an ops error. ++ */ ++ if (!scx_rq_online(dst_rq)) ++ dst_rq = src_rq; ++ ++ if (src_rq == dst_rq) { ++ /* ++ * As @p is staying on the same rq, there's no need to ++ * go through the full deactivate/activate cycle. ++ * Optimize by abbreviating the operations in ++ * move_task_to_local_dsq(). ++ */ ++ dsp = p->scx.holding_cpu == raw_smp_processor_id(); ++ if (likely(dsp)) { ++ p->scx.holding_cpu = -1; ++ dispatch_enqueue(&dst_rq->scx.local_dsq, p, ++ enq_flags); ++ } ++ } else { ++ dsp = move_task_to_local_dsq(dst_rq, p, enq_flags); ++ } ++ ++ /* if the destination CPU is idle, wake it up */ ++ if (dsp && sched_class_above(p->sched_class, ++ dst_rq->curr->sched_class)) ++ resched_curr(dst_rq); ++ ++ dispatch_to_local_dsq_unlock(rq, src_rq, locked_dst_rq); ++ ++ return dsp ? DTL_DISPATCHED : DTL_LOST; ++ } ++#endif /* CONFIG_SMP */ ++ ++ scx_ops_error("SCX_DSQ_LOCAL[_ON] verdict target cpu %d not allowed for %s[%d]", ++ cpu_of(dst_rq), p->comm, p->pid); ++ return DTL_INVALID; ++} ++ ++/** ++ * finish_dispatch - Asynchronously finish dispatching a task ++ * @rq: current rq which is locked ++ * @p: task to finish dispatching ++ * @qseq_at_dispatch: qseq when @p started getting dispatched ++ * @dsq_id: destination DSQ ID ++ * @enq_flags: %SCX_ENQ_* ++ * ++ * Dispatching to local DSQs may need to wait for queueing to complete or ++ * require rq lock dancing. 
As we don't wanna do either while inside ++ * ops.dispatch() to avoid locking order inversion, we split dispatching into ++ * two parts. scx_bpf_dispatch() which is called by ops.dispatch() records the ++ * task and its qseq. Once ops.dispatch() returns, this function is called to ++ * finish up. ++ * ++ * There is no guarantee that @p is still valid for dispatching or even that it ++ * was valid in the first place. Make sure that the task is still owned by the ++ * BPF scheduler and claim the ownership before dispatching. ++ */ ++static void finish_dispatch(struct rq *rq, struct task_struct *p, ++ unsigned long qseq_at_dispatch, ++ u64 dsq_id, u64 enq_flags) ++{ ++ struct scx_dispatch_q *dsq; ++ unsigned long opss; ++ ++ touch_core_sched_dispatch(rq, p); ++retry: ++ /* ++ * No need for _acquire here. @p is accessed only after a successful ++ * try_cmpxchg to DISPATCHING. ++ */ ++ opss = atomic_long_read(&p->scx.ops_state); ++ ++ switch (opss & SCX_OPSS_STATE_MASK) { ++ case SCX_OPSS_DISPATCHING: ++ case SCX_OPSS_NONE: ++ /* someone else already got to it */ ++ return; ++ case SCX_OPSS_QUEUED: ++ /* ++ * If qseq doesn't match, @p has gone through at least one ++ * dispatch/dequeue and re-enqueue cycle between ++ * scx_bpf_dispatch() and here and we have no claim on it. ++ */ ++ if ((opss & SCX_OPSS_QSEQ_MASK) != qseq_at_dispatch) ++ return; ++ ++ /* ++ * While we know @p is accessible, we don't yet have a claim on ++ * it - the BPF scheduler is allowed to dispatch tasks ++ * spuriously and there can be a racing dequeue attempt. Let's ++ * claim @p by atomically transitioning it from QUEUED to ++ * DISPATCHING. ++ */ ++ if (likely(atomic_long_try_cmpxchg(&p->scx.ops_state, &opss, ++ SCX_OPSS_DISPATCHING))) ++ break; ++ goto retry; ++ case SCX_OPSS_QUEUEING: ++ /* ++ * do_enqueue_task() is in the process of transferring the task ++ * to the BPF scheduler while holding @p's rq lock. 
As we aren't ++ * holding any kernel or BPF resource that the enqueue path may ++ * depend upon, it's safe to wait. ++ */ ++ wait_ops_state(p, opss); ++ goto retry; ++ } ++ ++ BUG_ON(!(p->scx.flags & SCX_TASK_QUEUED)); ++ ++ switch (dispatch_to_local_dsq(rq, dsq_id, p, enq_flags)) { ++ case DTL_DISPATCHED: ++ break; ++ case DTL_LOST: ++ break; ++ case DTL_INVALID: ++ dsq_id = SCX_DSQ_GLOBAL; ++ fallthrough; ++ case DTL_NOT_LOCAL: ++ dsq = find_dsq_for_dispatch(cpu_rq(raw_smp_processor_id()), ++ dsq_id, p); ++ dispatch_enqueue(dsq, p, enq_flags | SCX_ENQ_CLEAR_OPSS); ++ break; ++ } ++} ++ ++static void flush_dispatch_buf(struct rq *rq) ++{ ++ struct scx_dsp_ctx *dspc = this_cpu_ptr(scx_dsp_ctx); ++ u32 u; ++ ++ for (u = 0; u < dspc->cursor; u++) { ++ struct scx_dsp_buf_ent *ent = &dspc->buf[u]; ++ ++ finish_dispatch(rq, ent->task, ent->qseq, ent->dsq_id, ++ ent->enq_flags); ++ } ++ ++ dspc->nr_tasks += dspc->cursor; ++ dspc->cursor = 0; ++} ++ ++static int balance_one(struct rq *rq, struct task_struct *prev, bool local) ++{ ++ struct scx_dsp_ctx *dspc = this_cpu_ptr(scx_dsp_ctx); ++ bool prev_on_scx = prev->sched_class == &ext_sched_class; ++ int nr_loops = SCX_DSP_MAX_LOOPS; ++ bool has_tasks = false; ++ ++ lockdep_assert_rq_held(rq); ++ rq->scx.flags |= SCX_RQ_IN_BALANCE; ++ ++ if (static_branch_unlikely(&scx_ops_cpu_preempt) && ++ unlikely(rq->scx.cpu_released)) { ++ /* ++ * If the previous sched_class for the current CPU was not SCX, ++ * notify the BPF scheduler that it again has control of the ++ * core. This callback complements ->cpu_release(), which is ++ * emitted in scx_next_task_picked(). 
++ */ ++ if (SCX_HAS_OP(cpu_acquire)) ++ SCX_CALL_OP(0, cpu_acquire, cpu_of(rq), NULL); ++ rq->scx.cpu_released = false; ++ } ++ ++ if (prev_on_scx) { ++ WARN_ON_ONCE(local && (prev->scx.flags & SCX_TASK_BAL_KEEP)); ++ update_curr_scx(rq); ++ ++ /* ++ * If @prev is runnable & has slice left, it has priority and ++ * fetching more just increases latency for the fetched tasks. ++ * Tell put_prev_task_scx() to put @prev on local_dsq. If the ++ * BPF scheduler wants to handle this explicitly, it should ++ * implement ->cpu_release(). ++ * ++ * See scx_ops_disable_workfn() for the explanation on the ++ * bypassing test. ++ * ++ * When balancing a remote CPU for core-sched, there won't be a ++ * following put_prev_task_scx() call and we don't own ++ * %SCX_TASK_BAL_KEEP. Instead, pick_task_scx() will test the ++ * same conditions later and pick @rq->curr accordingly. ++ */ ++ if ((prev->scx.flags & SCX_TASK_QUEUED) && ++ prev->scx.slice && !scx_ops_bypassing()) { ++ if (local) ++ prev->scx.flags |= SCX_TASK_BAL_KEEP; ++ goto has_tasks; ++ } ++ } ++ ++ /* if there already are tasks to run, nothing to do */ ++ if (rq->scx.local_dsq.nr) ++ goto has_tasks; ++ ++ if (consume_dispatch_q(rq, &scx_dsq_global)) ++ goto has_tasks; ++ ++ if (!SCX_HAS_OP(dispatch) || scx_ops_bypassing() || !scx_rq_online(rq)) ++ goto out; ++ ++ dspc->rq = rq; ++ ++ /* ++ * The dispatch loop. Because flush_dispatch_buf() may drop the rq lock, ++ * the local DSQ might still end up empty after a successful ++ * ops.dispatch(). If the local DSQ is empty even after ops.dispatch() ++ * produced some tasks, retry. The BPF scheduler may depend on this ++ * looping behavior to simplify its implementation. ++ */ ++ do { ++ dspc->nr_tasks = 0; ++ ++ SCX_CALL_OP(SCX_KF_DISPATCH, dispatch, cpu_of(rq), ++ prev_on_scx ?
prev : NULL); ++ ++ flush_dispatch_buf(rq); ++ ++ if (rq->scx.local_dsq.nr) ++ goto has_tasks; ++ if (consume_dispatch_q(rq, &scx_dsq_global)) ++ goto has_tasks; ++ ++ /* ++ * ops.dispatch() can trap us in this loop by repeatedly ++ * dispatching ineligible tasks. Break out once in a while to ++ * allow the watchdog to run. As IRQ can't be enabled in ++ * balance(), we want to complete this scheduling cycle and then ++ * start a new one. IOW, we want to call resched_curr() on the ++ * next, most likely idle, task, not the current one. Use ++ * scx_bpf_kick_cpu() for deferred kicking. ++ */ ++ if (unlikely(!--nr_loops)) { ++ scx_bpf_kick_cpu(cpu_of(rq), 0); ++ break; ++ } ++ } while (dspc->nr_tasks); ++ ++ goto out; ++ ++has_tasks: ++ has_tasks = true; ++out: ++ rq->scx.flags &= ~SCX_RQ_IN_BALANCE; ++ return has_tasks; ++} ++ ++#ifdef CONFIG_SMP ++static int balance_scx(struct rq *rq, struct task_struct *prev, ++ struct rq_flags *rf) ++{ ++ int ret; ++ ++ rq_unpin_lock(rq, rf); ++ ++ ret = balance_one(rq, prev, true); ++ ++#ifdef CONFIG_SCHED_SMT ++ /* ++ * When core-sched is enabled, this ops.balance() call will be followed ++ * by put_prev_scx() and pick_task_scx() on this CPU and pick_task_scx() ++ * on the SMT siblings. Balance the siblings too. ++ */ ++ if (sched_core_enabled(rq)) { ++ const struct cpumask *smt_mask = cpu_smt_mask(cpu_of(rq)); ++ int scpu; ++ ++ for_each_cpu_andnot(scpu, smt_mask, cpumask_of(cpu_of(rq))) { ++ struct rq *srq = cpu_rq(scpu); ++ struct task_struct *sprev = srq->curr; ++ ++ WARN_ON_ONCE(__rq_lockp(rq) != __rq_lockp(srq)); ++ update_rq_clock(srq); ++ balance_one(srq, sprev, false); ++ } ++ } ++#endif ++ rq_repin_lock(rq, rf); ++ ++ return ret; ++} ++#endif ++ ++static void set_next_task_scx(struct rq *rq, struct task_struct *p, bool first) ++{ ++ if (p->scx.flags & SCX_TASK_QUEUED) { ++ /* ++ * Core-sched might decide to execute @p before it is ++ * dispatched. Call ops_dequeue() to notify the BPF scheduler. 
++ */ ++ ops_dequeue(p, SCX_DEQ_CORE_SCHED_EXEC); ++ dispatch_dequeue(rq, p); ++ } ++ ++ p->se.exec_start = rq_clock_task(rq); ++ ++ /* see dequeue_task_scx() on why we skip when !QUEUED */ ++ if (SCX_HAS_OP(running) && (p->scx.flags & SCX_TASK_QUEUED)) ++ SCX_CALL_OP_TASK(SCX_KF_REST, running, p); ++ ++ clr_task_runnable(p, true); ++ ++ /* ++ * @p is getting newly scheduled or got kicked after someone updated its ++ * slice. Refresh whether tick can be stopped. See scx_can_stop_tick(). ++ */ ++ if ((p->scx.slice == SCX_SLICE_INF) != ++ (bool)(rq->scx.flags & SCX_RQ_CAN_STOP_TICK)) { ++ if (p->scx.slice == SCX_SLICE_INF) ++ rq->scx.flags |= SCX_RQ_CAN_STOP_TICK; ++ else ++ rq->scx.flags &= ~SCX_RQ_CAN_STOP_TICK; ++ ++ sched_update_tick_dependency(rq); ++ ++ /* ++ * For now, let's refresh the load_avgs just when transitioning ++ * in and out of nohz. In the future, we might want to add a ++ * mechanism which calls the following periodically on ++ * tick-stopped CPUs. ++ */ ++ update_other_load_avgs(rq); ++ } ++} ++ ++static void process_ddsp_deferred_locals(struct rq *rq) ++{ ++ struct task_struct *p, *tmp; ++ ++ lockdep_assert_rq_held(rq); ++ ++ /* ++ * Now that @rq can be unlocked, execute the deferred enqueueing of ++ * tasks directly dispatched to the local DSQs of other CPUs. See ++ * direct_dispatch(). ++ */ ++ list_for_each_entry_safe(p, tmp, &rq->scx.ddsp_deferred_locals, ++ scx.dsq_list.node) { ++ s32 ret; ++ ++ list_del_init(&p->scx.dsq_list.node); ++ ++ ret = dispatch_to_local_dsq(rq, p->scx.ddsp_dsq_id, p, ++ p->scx.ddsp_enq_flags); ++ WARN_ON_ONCE(ret == DTL_NOT_LOCAL); ++ } ++} ++ ++static void put_prev_task_scx(struct rq *rq, struct task_struct *p) ++{ ++#ifndef CONFIG_SMP ++ /* ++ * UP workaround. ++ * ++ * Because SCX may transfer tasks across CPUs during dispatch, dispatch ++ * is performed from its balance operation which isn't called in UP. ++ * Let's work around by calling it from the operations which come right ++ * after. ++ * ++ * 1. 
If the prev task is on SCX, pick_next_task() calls ++ * .put_prev_task() right after. As .put_prev_task() is also called ++ * from other places, we need to distinguish the calls which can be ++ * done by looking at the previous task's state - if still queued or ++ * dequeued with %SCX_DEQ_SLEEP, the caller must be pick_next_task(). ++ * This case is handled here. ++ * ++ * 2. If the prev task is not on SCX, the first following call into SCX ++ * will be .pick_next_task(), which is covered by calling ++ * balance_scx() from pick_next_task_scx(). ++ * ++ * Note that we can't merge the first case into the second as ++ * balance_scx() must be called before the previous SCX task goes ++ * through put_prev_task_scx(). ++ * ++ * @rq is pinned and can't be unlocked. As UP doesn't transfer tasks ++ * around, balance_one() doesn't need to. ++ */ ++ if (p->scx.flags & (SCX_TASK_QUEUED | SCX_TASK_DEQD_FOR_SLEEP)) ++ balance_one(rq, p, true); ++#endif ++ ++ update_curr_scx(rq); ++ ++ /* see dequeue_task_scx() on why we skip when !QUEUED */ ++ if (SCX_HAS_OP(stopping) && (p->scx.flags & SCX_TASK_QUEUED)) ++ SCX_CALL_OP_TASK(SCX_KF_REST, stopping, p, true); ++ ++ /* ++ * If we're being called from put_prev_task_balance(), balance_scx() may ++ * have decided that @p should keep running. ++ */ ++ if (p->scx.flags & SCX_TASK_BAL_KEEP) { ++ p->scx.flags &= ~SCX_TASK_BAL_KEEP; ++ set_task_runnable(rq, p); ++ dispatch_enqueue(&rq->scx.local_dsq, p, SCX_ENQ_HEAD); ++ return; ++ } ++ ++ if (p->scx.flags & SCX_TASK_QUEUED) { ++ set_task_runnable(rq, p); ++ ++ /* ++ * If @p has slice left and balance_scx() didn't tag it for ++ * keeping, @p is getting preempted by a higher priority ++ * scheduler class or core-sched forcing a different task. Leave ++ * it at the head of the local DSQ. 
++ */ ++ if (p->scx.slice && !scx_ops_bypassing()) { ++ dispatch_enqueue(&rq->scx.local_dsq, p, SCX_ENQ_HEAD); ++ return; ++ } ++ ++ /* ++ * If we're in the pick_next_task path, balance_scx() should ++ * have already populated the local DSQ if there are any other ++ * available tasks. If empty, tell ops.enqueue() that @p is the ++ * only one available for this cpu. ops.enqueue() should put it ++ * on the local DSQ so that the subsequent pick_next_task_scx() ++ * can find the task unless it wants to trigger a separate ++ * follow-up scheduling event. ++ */ ++ if (list_empty(&rq->scx.local_dsq.list)) ++ do_enqueue_task(rq, p, SCX_ENQ_LAST, -1); ++ else ++ do_enqueue_task(rq, p, 0, -1); ++ } ++} ++ ++static struct task_struct *first_local_task(struct rq *rq) ++{ ++ return list_first_entry_or_null(&rq->scx.local_dsq.list, ++ struct task_struct, scx.dsq_list.node); ++} ++ ++static struct task_struct *pick_next_task_scx(struct rq *rq) ++{ ++ struct task_struct *p; ++ ++#ifndef CONFIG_SMP ++ /* UP workaround - see the comment at the head of put_prev_task_scx() */ ++ if (unlikely(rq->curr->sched_class != &ext_sched_class)) ++ balance_one(rq, rq->curr, true); ++#endif ++ ++ p = first_local_task(rq); ++ if (!p) ++ return NULL; ++ ++ set_next_task_scx(rq, p, true); ++ ++ if (unlikely(!p->scx.slice)) { ++ if (!scx_ops_bypassing() && !scx_warned_zero_slice) { ++ printk_deferred(KERN_WARNING "sched_ext: %s[%d] has zero slice in pick_next_task_scx()\n", ++ p->comm, p->pid); ++ scx_warned_zero_slice = true; ++ } ++ p->scx.slice = SCX_SLICE_DFL; ++ } ++ ++ return p; ++} ++ ++#ifdef CONFIG_SCHED_CORE ++/** ++ * scx_prio_less - Task ordering for core-sched ++ * @a: task A ++ * @b: task B ++ * ++ * Core-sched is implemented as an additional scheduling layer on top of the ++ * usual sched_class'es and needs to find out the expected task ordering. For ++ * SCX, core-sched calls this function to interrogate the task ordering. 
++ * ++ * Unless overridden by ops.core_sched_before(), @p->scx.core_sched_at is used ++ * to implement the default task ordering. The older the timestamp, the higher ++ * priority the task - the global FIFO ordering matching the default scheduling ++ * behavior. ++ * ++ * When ops.core_sched_before() is enabled, @p->scx.core_sched_at is used to ++ * implement FIFO ordering within each local DSQ. See pick_task_scx(). ++ */ ++bool scx_prio_less(const struct task_struct *a, const struct task_struct *b, ++ bool in_fi) ++{ ++ /* ++ * The const qualifiers are dropped from task_struct pointers when ++ * calling ops.core_sched_before(). Accesses are controlled by the ++ * verifier. ++ */ ++ if (SCX_HAS_OP(core_sched_before) && !scx_ops_bypassing()) ++ return SCX_CALL_OP_2TASKS_RET(SCX_KF_REST, core_sched_before, ++ (struct task_struct *)a, ++ (struct task_struct *)b); ++ else ++ return time_after64(a->scx.core_sched_at, b->scx.core_sched_at); ++} ++ ++/** ++ * pick_task_scx - Pick a candidate task for core-sched ++ * @rq: rq to pick the candidate task from ++ * ++ * Core-sched calls this function on each SMT sibling to determine the next ++ * tasks to run on the SMT siblings. balance_one() has been called on all ++ * siblings and put_prev_task_scx() has been called only for the current CPU. ++ * ++ * As put_prev_task_scx() hasn't been called on remote CPUs, we can't just look ++ * at the first task in the local dsq. @rq->curr has to be considered explicitly ++ * to mimic %SCX_TASK_BAL_KEEP. ++ */ ++static struct task_struct *pick_task_scx(struct rq *rq) ++{ ++ struct task_struct *curr = rq->curr; ++ struct task_struct *first = first_local_task(rq); ++ ++ if (curr->scx.flags & SCX_TASK_QUEUED) { ++ /* is curr the only runnable task? */ ++ if (!first) ++ return curr; ++ ++ /* ++ * Does curr trump first?
We can always go by core_sched_at for ++ * this comparison as it represents global FIFO ordering when ++ * the default core-sched ordering is used and local-DSQ FIFO ++ * ordering otherwise. ++ * ++ * We can have a task with an earlier timestamp on the DSQ. For ++ * example, when a current task is preempted by a sibling ++ * picking a different cookie, the task would be requeued at the ++ * head of the local DSQ with an earlier timestamp than the ++ * core-sched picked next task. Besides, the BPF scheduler may ++ * dispatch any tasks to the local DSQ anytime. ++ */ ++ if (curr->scx.slice && time_before64(curr->scx.core_sched_at, ++ first->scx.core_sched_at)) ++ return curr; ++ } ++ ++ return first; /* this may be %NULL */ ++} ++#endif /* CONFIG_SCHED_CORE */ ++ ++static enum scx_cpu_preempt_reason ++preempt_reason_from_class(const struct sched_class *class) ++{ ++#ifdef CONFIG_SMP ++ if (class == &stop_sched_class) ++ return SCX_CPU_PREEMPT_STOP; ++#endif ++ if (class == &dl_sched_class) ++ return SCX_CPU_PREEMPT_DL; ++ if (class == &rt_sched_class) ++ return SCX_CPU_PREEMPT_RT; ++ return SCX_CPU_PREEMPT_UNKNOWN; ++} ++ ++static void switch_class_scx(struct rq *rq, struct task_struct *next) ++{ ++ const struct sched_class *next_class = next->sched_class; ++ ++ if (!scx_enabled()) ++ return; ++#ifdef CONFIG_SMP ++ /* ++ * Pairs with the smp_load_acquire() issued by a CPU in ++ * kick_cpus_irq_workfn() who is waiting for this CPU to perform a ++ * resched. ++ */ ++ smp_store_release(&rq->scx.pnt_seq, rq->scx.pnt_seq + 1); ++#endif ++ if (!static_branch_unlikely(&scx_ops_cpu_preempt)) ++ return; ++ ++ /* ++ * The callback is conceptually meant to convey that the CPU is no ++ * longer under the control of SCX. Therefore, don't invoke the callback ++ * if the next class is below SCX (in which case the BPF scheduler has ++ * actively decided not to schedule any tasks on the CPU). 
++ */ ++ if (sched_class_above(&ext_sched_class, next_class)) ++ return; ++ ++ /* ++ * At this point we know that SCX was preempted by a higher priority ++ * sched_class, so invoke the ->cpu_release() callback if we have not ++ * done so already. We only send the callback once between SCX being ++ * preempted, and it regaining control of the CPU. ++ * ++ * ->cpu_release() complements ->cpu_acquire(), which is emitted the ++ * next time that balance_scx() is invoked. ++ */ ++ if (!rq->scx.cpu_released) { ++ if (SCX_HAS_OP(cpu_release)) { ++ struct scx_cpu_release_args args = { ++ .reason = preempt_reason_from_class(next_class), ++ .task = next, ++ }; ++ ++ SCX_CALL_OP(SCX_KF_CPU_RELEASE, ++ cpu_release, cpu_of(rq), &args); ++ } ++ rq->scx.cpu_released = true; ++ } ++} ++ ++#ifdef CONFIG_SMP ++ ++static bool test_and_clear_cpu_idle(int cpu) ++{ ++#ifdef CONFIG_SCHED_SMT ++ /* ++ * SMT mask should be cleared whether we can claim @cpu or not. The SMT ++ * cluster is not wholly idle either way. This also prevents ++ * scx_pick_idle_cpu() from getting caught in an infinite loop. ++ */ ++ if (sched_smt_active()) { ++ const struct cpumask *smt = cpu_smt_mask(cpu); ++ ++ /* ++ * If offline, @cpu is not its own sibling and ++ * scx_pick_idle_cpu() can get caught in an infinite loop as ++ * @cpu is never cleared from idle_masks.smt. Ensure that @cpu ++ * is eventually cleared. 
++ */ ++ if (cpumask_intersects(smt, idle_masks.smt)) ++ cpumask_andnot(idle_masks.smt, idle_masks.smt, smt); ++ else if (cpumask_test_cpu(cpu, idle_masks.smt)) ++ __cpumask_clear_cpu(cpu, idle_masks.smt); ++ } ++#endif ++ return cpumask_test_and_clear_cpu(cpu, idle_masks.cpu); ++} ++ ++static s32 scx_pick_idle_cpu(const struct cpumask *cpus_allowed, u64 flags) ++{ ++ int cpu; ++ ++retry: ++ if (sched_smt_active()) { ++ cpu = cpumask_any_and_distribute(idle_masks.smt, cpus_allowed); ++ if (cpu < nr_cpu_ids) ++ goto found; ++ ++ if (flags & SCX_PICK_IDLE_CORE) ++ return -EBUSY; ++ } ++ ++ cpu = cpumask_any_and_distribute(idle_masks.cpu, cpus_allowed); ++ if (cpu >= nr_cpu_ids) ++ return -EBUSY; ++ ++found: ++ if (test_and_clear_cpu_idle(cpu)) ++ return cpu; ++ else ++ goto retry; ++} ++ ++static s32 scx_select_cpu_dfl(struct task_struct *p, s32 prev_cpu, ++ u64 wake_flags, bool *found) ++{ ++ s32 cpu; ++ ++ *found = false; ++ ++ if (!static_branch_likely(&scx_builtin_idle_enabled)) { ++ scx_ops_error("built-in idle tracking is disabled"); ++ return prev_cpu; ++ } ++ ++ /* ++ * If WAKE_SYNC, the waker's local DSQ is empty, and the system is ++ * under utilized, wake up @p to the local DSQ of the waker. Checking ++ * only for an empty local DSQ is insufficient as it could give the ++ * wakee an unfair advantage when the system is oversaturated. ++ * Checking only for the presence of idle CPUs is also insufficient as ++ * the local DSQ of the waker could have tasks piled up on it even if ++ * there is an idle core elsewhere on the system. 
++ */ ++ cpu = smp_processor_id(); ++ if ((wake_flags & SCX_WAKE_SYNC) && p->nr_cpus_allowed > 1 && ++ !cpumask_empty(idle_masks.cpu) && !(current->flags & PF_EXITING) && ++ cpu_rq(cpu)->scx.local_dsq.nr == 0) { ++ if (cpumask_test_cpu(cpu, p->cpus_ptr)) ++ goto cpu_found; ++ } ++ ++ if (p->nr_cpus_allowed == 1) { ++ if (test_and_clear_cpu_idle(prev_cpu)) { ++ cpu = prev_cpu; ++ goto cpu_found; ++ } else { ++ return prev_cpu; ++ } ++ } ++ ++ /* ++ * If CPU has SMT, any wholly idle CPU is likely a better pick than ++ * partially idle @prev_cpu. ++ */ ++ if (sched_smt_active()) { ++ if (cpumask_test_cpu(prev_cpu, idle_masks.smt) && ++ test_and_clear_cpu_idle(prev_cpu)) { ++ cpu = prev_cpu; ++ goto cpu_found; ++ } ++ ++ cpu = scx_pick_idle_cpu(p->cpus_ptr, SCX_PICK_IDLE_CORE); ++ if (cpu >= 0) ++ goto cpu_found; ++ } ++ ++ if (test_and_clear_cpu_idle(prev_cpu)) { ++ cpu = prev_cpu; ++ goto cpu_found; ++ } ++ ++ cpu = scx_pick_idle_cpu(p->cpus_ptr, 0); ++ if (cpu >= 0) ++ goto cpu_found; ++ ++ return prev_cpu; ++ ++cpu_found: ++ *found = true; ++ return cpu; ++} ++ ++static int select_task_rq_scx(struct task_struct *p, int prev_cpu, int wake_flags) ++{ ++ /* ++ * sched_exec() calls with %WF_EXEC when @p is about to exec(2) as it ++ * can be a good migration opportunity with low cache and memory ++ * footprint. Returning a CPU different than @prev_cpu triggers ++ * immediate rq migration. However, for SCX, as the current rq ++ * association doesn't dictate where the task is going to run, this ++ * doesn't fit well. If necessary, we can later add a dedicated method ++ * which can decide to preempt self to force it through the regular ++ * scheduling path. 
++ */ ++ if (unlikely(wake_flags & WF_EXEC)) ++ return prev_cpu; ++ ++ if (SCX_HAS_OP(select_cpu)) { ++ s32 cpu; ++ struct task_struct **ddsp_taskp; ++ ++ ddsp_taskp = this_cpu_ptr(&direct_dispatch_task); ++ WARN_ON_ONCE(*ddsp_taskp); ++ *ddsp_taskp = p; ++ ++ cpu = SCX_CALL_OP_TASK_RET(SCX_KF_ENQUEUE | SCX_KF_SELECT_CPU, ++ select_cpu, p, prev_cpu, wake_flags); ++ *ddsp_taskp = NULL; ++ if (ops_cpu_valid(cpu, "from ops.select_cpu()")) ++ return cpu; ++ else ++ return prev_cpu; ++ } else { ++ bool found; ++ s32 cpu; ++ ++ cpu = scx_select_cpu_dfl(p, prev_cpu, wake_flags, &found); ++ if (found) { ++ p->scx.slice = SCX_SLICE_DFL; ++ p->scx.ddsp_dsq_id = SCX_DSQ_LOCAL; ++ } ++ return cpu; ++ } ++} ++ ++static void task_woken_scx(struct rq *rq, struct task_struct *p) ++{ ++ run_deferred(rq); ++} ++ ++static void set_cpus_allowed_scx(struct task_struct *p, ++ struct affinity_context *ac) ++{ ++ set_cpus_allowed_common(p, ac); ++ ++ /* ++ * The effective cpumask is stored in @p->cpus_ptr which may temporarily ++ * differ from the configured one in @p->cpus_mask. Always tell the bpf ++ * scheduler the effective one. ++ * ++ * Fine-grained memory write control is enforced by BPF making the const ++ * designation pointless. Cast it away when calling the operation. ++ */ ++ if (SCX_HAS_OP(set_cpumask)) ++ SCX_CALL_OP_TASK(SCX_KF_REST, set_cpumask, p, ++ (struct cpumask *)p->cpus_ptr); ++} ++ ++static void reset_idle_masks(void) ++{ ++ /* ++ * Consider all online cpus idle. Should converge to the actual state ++ * quickly. 
++ */ ++ cpumask_copy(idle_masks.cpu, cpu_online_mask); ++ cpumask_copy(idle_masks.smt, cpu_online_mask); ++} ++ ++void __scx_update_idle(struct rq *rq, bool idle) ++{ ++ int cpu = cpu_of(rq); ++ ++ if (SCX_HAS_OP(update_idle)) { ++ SCX_CALL_OP(SCX_KF_REST, update_idle, cpu_of(rq), idle); ++ if (!static_branch_unlikely(&scx_builtin_idle_enabled)) ++ return; ++ } ++ ++ if (idle) ++ cpumask_set_cpu(cpu, idle_masks.cpu); ++ else ++ cpumask_clear_cpu(cpu, idle_masks.cpu); ++ ++#ifdef CONFIG_SCHED_SMT ++ if (sched_smt_active()) { ++ const struct cpumask *smt = cpu_smt_mask(cpu); ++ ++ if (idle) { ++ /* ++ * idle_masks.smt handling is racy but that's fine as ++ * it's only for optimization and self-correcting. ++ */ ++ for_each_cpu(cpu, smt) { ++ if (!cpumask_test_cpu(cpu, idle_masks.cpu)) ++ return; ++ } ++ cpumask_or(idle_masks.smt, idle_masks.smt, smt); ++ } else { ++ cpumask_andnot(idle_masks.smt, idle_masks.smt, smt); ++ } ++ } ++#endif ++} ++ ++static void handle_hotplug(struct rq *rq, bool online) ++{ ++ int cpu = cpu_of(rq); ++ ++ atomic_long_inc(&scx_hotplug_seq); ++ ++ if (online && SCX_HAS_OP(cpu_online)) ++ SCX_CALL_OP(SCX_KF_UNLOCKED, cpu_online, cpu); ++ else if (!online && SCX_HAS_OP(cpu_offline)) ++ SCX_CALL_OP(SCX_KF_UNLOCKED, cpu_offline, cpu); ++ else ++ scx_ops_exit(SCX_ECODE_ACT_RESTART | SCX_ECODE_RSN_HOTPLUG, ++ "cpu %d going %s, exiting scheduler", cpu, ++ online ? 
"online" : "offline"); ++} ++ ++void scx_rq_activate(struct rq *rq) ++{ ++ handle_hotplug(rq, true); ++} ++ ++void scx_rq_deactivate(struct rq *rq) ++{ ++ handle_hotplug(rq, false); ++} ++ ++static void rq_online_scx(struct rq *rq) ++{ ++ rq->scx.flags |= SCX_RQ_ONLINE; ++} ++ ++static void rq_offline_scx(struct rq *rq) ++{ ++ rq->scx.flags &= ~SCX_RQ_ONLINE; ++} ++ ++#else /* CONFIG_SMP */ ++ ++static bool test_and_clear_cpu_idle(int cpu) { return false; } ++static s32 scx_pick_idle_cpu(const struct cpumask *cpus_allowed, u64 flags) { return -EBUSY; } ++static void reset_idle_masks(void) {} ++ ++#endif /* CONFIG_SMP */ ++ ++static bool check_rq_for_timeouts(struct rq *rq) ++{ ++ struct task_struct *p; ++ struct rq_flags rf; ++ bool timed_out = false; ++ ++ rq_lock_irqsave(rq, &rf); ++ list_for_each_entry(p, &rq->scx.runnable_list, scx.runnable_node) { ++ unsigned long last_runnable = p->scx.runnable_at; ++ ++ if (unlikely(time_after(jiffies, ++ last_runnable + scx_watchdog_timeout))) { ++ u32 dur_ms = jiffies_to_msecs(jiffies - last_runnable); ++ ++ scx_ops_error_kind(SCX_EXIT_ERROR_STALL, ++ "%s[%d] failed to run for %u.%03us", ++ p->comm, p->pid, ++ dur_ms / 1000, dur_ms % 1000); ++ timed_out = true; ++ break; ++ } ++ } ++ rq_unlock_irqrestore(rq, &rf); ++ ++ return timed_out; ++} ++ ++static void scx_watchdog_workfn(struct work_struct *work) ++{ ++ int cpu; ++ ++ WRITE_ONCE(scx_watchdog_timestamp, jiffies); ++ ++ for_each_online_cpu(cpu) { ++ if (unlikely(check_rq_for_timeouts(cpu_rq(cpu)))) ++ break; ++ ++ cond_resched(); ++ } ++ queue_delayed_work(system_unbound_wq, to_delayed_work(work), ++ scx_watchdog_timeout / 2); ++} ++ ++void scx_tick(struct rq *rq) ++{ ++ unsigned long last_check; ++ ++ if (!scx_enabled()) ++ return; ++ ++ last_check = READ_ONCE(scx_watchdog_timestamp); ++ if (unlikely(time_after(jiffies, ++ last_check + READ_ONCE(scx_watchdog_timeout)))) { ++ u32 dur_ms = jiffies_to_msecs(jiffies - last_check); ++ ++ 
scx_ops_error_kind(SCX_EXIT_ERROR_STALL, ++ "watchdog failed to check in for %u.%03us", ++ dur_ms / 1000, dur_ms % 1000); ++ } ++ ++ update_other_load_avgs(rq); ++} ++ ++static void task_tick_scx(struct rq *rq, struct task_struct *curr, int queued) ++{ ++ update_curr_scx(rq); ++ ++ /* ++ * While disabling, always resched and refresh core-sched timestamp as ++ * we can't trust the slice management or ops.core_sched_before(). ++ */ ++ if (scx_ops_bypassing()) { ++ curr->scx.slice = 0; ++ touch_core_sched(rq, curr); ++ } else if (SCX_HAS_OP(tick)) { ++ SCX_CALL_OP_TASK(SCX_KF_REST, tick, curr); ++ } ++ ++ if (!curr->scx.slice) ++ resched_curr(rq); ++} ++ ++static enum scx_task_state scx_get_task_state(const struct task_struct *p) ++{ ++ return (p->scx.flags & SCX_TASK_STATE_MASK) >> SCX_TASK_STATE_SHIFT; ++} ++ ++static void scx_set_task_state(struct task_struct *p, enum scx_task_state state) ++{ ++ enum scx_task_state prev_state = scx_get_task_state(p); ++ bool warn = false; ++ ++ BUILD_BUG_ON(SCX_TASK_NR_STATES > (1 << SCX_TASK_STATE_BITS)); ++ ++ switch (state) { ++ case SCX_TASK_NONE: ++ break; ++ case SCX_TASK_INIT: ++ warn = prev_state != SCX_TASK_NONE; ++ break; ++ case SCX_TASK_READY: ++ warn = prev_state == SCX_TASK_NONE; ++ break; ++ case SCX_TASK_ENABLED: ++ warn = prev_state != SCX_TASK_READY; ++ break; ++ default: ++ WARN_ON_ONCE(true); ++ return; ++ } ++ ++ WARN_ONCE(warn, "sched_ext: Invalid task state transition %d -> %d for %s[%d]", ++ prev_state, state, p->comm, p->pid); ++ ++ p->scx.flags &= ~SCX_TASK_STATE_MASK; ++ p->scx.flags |= state << SCX_TASK_STATE_SHIFT; ++} ++ ++static int scx_ops_init_task(struct task_struct *p, struct task_group *tg, bool fork) ++{ ++ int ret; ++ ++ p->scx.disallow = false; ++ ++ if (SCX_HAS_OP(init_task)) { ++ struct scx_init_task_args args = { ++ .fork = fork, ++ }; ++ ++ ret = SCX_CALL_OP_RET(SCX_KF_UNLOCKED, init_task, p, &args); ++ if (unlikely(ret)) { ++ ret = ops_sanitize_err("init_task", ret); ++ return ret; ++ } ++ } ++ ++
scx_set_task_state(p, SCX_TASK_INIT); ++ ++ if (p->scx.disallow) { ++ struct rq *rq; ++ struct rq_flags rf; ++ ++ rq = task_rq_lock(p, &rf); ++ ++ /* ++ * We're either in fork or load path and @p->policy will be ++ * applied right after. Reverting @p->policy here and rejecting ++ * %SCHED_EXT transitions from scx_check_setscheduler() ++ * guarantees that if ops.init_task() sets @p->disallow, @p can ++ * never be in SCX. ++ */ ++ if (p->policy == SCHED_EXT) { ++ p->policy = SCHED_NORMAL; ++ atomic_long_inc(&scx_nr_rejected); ++ } ++ ++ task_rq_unlock(rq, p, &rf); ++ } ++ ++ p->scx.flags |= SCX_TASK_RESET_RUNNABLE_AT; ++ return 0; ++} ++ ++static void scx_ops_enable_task(struct task_struct *p) ++{ ++ u32 weight; ++ ++ lockdep_assert_rq_held(task_rq(p)); ++ ++ /* ++ * Set the weight before calling ops.enable() so that the scheduler ++ * doesn't see a stale value if they inspect the task struct. ++ */ ++ if (task_has_idle_policy(p)) ++ weight = WEIGHT_IDLEPRIO; ++ else ++ weight = sched_prio_to_weight[p->static_prio - MAX_RT_PRIO]; ++ ++ p->scx.weight = sched_weight_to_cgroup(weight); ++ ++ if (SCX_HAS_OP(enable)) ++ SCX_CALL_OP_TASK(SCX_KF_REST, enable, p); ++ scx_set_task_state(p, SCX_TASK_ENABLED); ++ ++ if (SCX_HAS_OP(set_weight)) ++ SCX_CALL_OP_TASK(SCX_KF_REST, set_weight, p, p->scx.weight); ++} ++ ++static void scx_ops_disable_task(struct task_struct *p) ++{ ++ lockdep_assert_rq_held(task_rq(p)); ++ WARN_ON_ONCE(scx_get_task_state(p) != SCX_TASK_ENABLED); ++ ++ if (SCX_HAS_OP(disable)) ++ SCX_CALL_OP(SCX_KF_REST, disable, p); ++ scx_set_task_state(p, SCX_TASK_READY); ++} ++ ++static void scx_ops_exit_task(struct task_struct *p) ++{ ++ struct scx_exit_task_args args = { ++ .cancelled = false, ++ }; ++ ++ lockdep_assert_rq_held(task_rq(p)); ++ ++ switch (scx_get_task_state(p)) { ++ case SCX_TASK_NONE: ++ return; ++ case SCX_TASK_INIT: ++ args.cancelled = true; ++ break; ++ case SCX_TASK_READY: ++ break; ++ case SCX_TASK_ENABLED: ++ scx_ops_disable_task(p); ++ break; ++
default: ++ WARN_ON_ONCE(true); ++ return; ++ } ++ ++ if (SCX_HAS_OP(exit_task)) ++ SCX_CALL_OP(SCX_KF_REST, exit_task, p, &args); ++ scx_set_task_state(p, SCX_TASK_NONE); ++} ++ ++void init_scx_entity(struct sched_ext_entity *scx) ++{ ++ /* ++ * init_idle() calls this function again after fork sequence is ++ * complete. Don't touch ->tasks_node as it's already linked. ++ */ ++ memset(scx, 0, offsetof(struct sched_ext_entity, tasks_node)); ++ ++ INIT_LIST_HEAD(&scx->dsq_list.node); ++ RB_CLEAR_NODE(&scx->dsq_priq); ++ scx->sticky_cpu = -1; ++ scx->holding_cpu = -1; ++ INIT_LIST_HEAD(&scx->runnable_node); ++ scx->runnable_at = jiffies; ++ scx->ddsp_dsq_id = SCX_DSQ_INVALID; ++ scx->slice = SCX_SLICE_DFL; ++} ++ ++void scx_pre_fork(struct task_struct *p) ++{ ++ /* ++ * BPF scheduler enable/disable paths want to be able to iterate and ++ * update all tasks which can become complex when racing forks. As ++ * enable/disable are very cold paths, let's use a percpu_rwsem to ++ * exclude forks. ++ */ ++ percpu_down_read(&scx_fork_rwsem); ++} ++ ++int scx_fork(struct task_struct *p) ++{ ++ percpu_rwsem_assert_held(&scx_fork_rwsem); ++ ++ if (scx_enabled()) ++ return scx_ops_init_task(p, task_group(p), true); ++ else ++ return 0; ++} ++ ++void scx_post_fork(struct task_struct *p) ++{ ++ if (scx_enabled()) { ++ scx_set_task_state(p, SCX_TASK_READY); ++ ++ /* ++ * Enable the task immediately if it's running on sched_ext. ++ * Otherwise, it'll be enabled in switching_to_scx() if and ++ * when it's ever configured to run with a SCHED_EXT policy. 
++ */ ++ if (p->sched_class == &ext_sched_class) { ++ struct rq_flags rf; ++ struct rq *rq; ++ ++ rq = task_rq_lock(p, &rf); ++ scx_ops_enable_task(p); ++ task_rq_unlock(rq, p, &rf); ++ } ++ } ++ ++ spin_lock_irq(&scx_tasks_lock); ++ list_add_tail(&p->scx.tasks_node, &scx_tasks); ++ spin_unlock_irq(&scx_tasks_lock); ++ ++ percpu_up_read(&scx_fork_rwsem); ++} ++ ++void scx_cancel_fork(struct task_struct *p) ++{ ++ if (scx_enabled()) { ++ struct rq *rq; ++ struct rq_flags rf; ++ ++ rq = task_rq_lock(p, &rf); ++ WARN_ON_ONCE(scx_get_task_state(p) >= SCX_TASK_READY); ++ scx_ops_exit_task(p); ++ task_rq_unlock(rq, p, &rf); ++ } ++ ++ percpu_up_read(&scx_fork_rwsem); ++} ++ ++void sched_ext_free(struct task_struct *p) ++{ ++ unsigned long flags; ++ ++ spin_lock_irqsave(&scx_tasks_lock, flags); ++ list_del_init(&p->scx.tasks_node); ++ spin_unlock_irqrestore(&scx_tasks_lock, flags); ++ ++ /* ++ * @p is off scx_tasks and wholly ours. scx_ops_enable()'s READY -> ++ * ENABLED transitions can't race us. Disable ops for @p. ++ */ ++ if (scx_get_task_state(p) != SCX_TASK_NONE) { ++ struct rq_flags rf; ++ struct rq *rq; ++ ++ rq = task_rq_lock(p, &rf); ++ scx_ops_exit_task(p); ++ task_rq_unlock(rq, p, &rf); ++ } ++} ++ ++static void reweight_task_scx(struct rq *rq, struct task_struct *p, ++ const struct load_weight *lw) ++{ ++ lockdep_assert_rq_held(task_rq(p)); ++ ++ p->scx.weight = sched_weight_to_cgroup(scale_load_down(lw->weight)); ++ if (SCX_HAS_OP(set_weight)) ++ SCX_CALL_OP_TASK(SCX_KF_REST, set_weight, p, p->scx.weight); ++} ++ ++static void prio_changed_scx(struct rq *rq, struct task_struct *p, int oldprio) ++{ ++} ++ ++static void switching_to_scx(struct rq *rq, struct task_struct *p) ++{ ++ scx_ops_enable_task(p); ++ ++ /* ++ * set_cpus_allowed_scx() is not called while @p is associated with a ++ * different scheduler class. Keep the BPF scheduler up-to-date. 
++ */ ++ if (SCX_HAS_OP(set_cpumask)) ++ SCX_CALL_OP_TASK(SCX_KF_REST, set_cpumask, p, ++ (struct cpumask *)p->cpus_ptr); ++} ++ ++static void switched_from_scx(struct rq *rq, struct task_struct *p) ++{ ++ scx_ops_disable_task(p); ++} ++ ++static void wakeup_preempt_scx(struct rq *rq, struct task_struct *p,int wake_flags) {} ++static void switched_to_scx(struct rq *rq, struct task_struct *p) {} ++ ++int scx_check_setscheduler(struct task_struct *p, int policy) ++{ ++ lockdep_assert_rq_held(task_rq(p)); ++ ++ /* if disallow, reject transitioning into SCX */ ++ if (scx_enabled() && READ_ONCE(p->scx.disallow) && ++ p->policy != policy && policy == SCHED_EXT) ++ return -EACCES; ++ ++ return 0; ++} ++ ++#ifdef CONFIG_NO_HZ_FULL ++bool scx_can_stop_tick(struct rq *rq) ++{ ++ struct task_struct *p = rq->curr; ++ ++ if (scx_ops_bypassing()) ++ return false; ++ ++ if (p->sched_class != &ext_sched_class) ++ return true; ++ ++ /* ++ * @rq can dispatch from different DSQs, so we can't tell whether it ++ * needs the tick or not by looking at nr_running. Allow stopping ticks ++ * iff the BPF scheduler indicated so. See set_next_task_scx(). ++ */ ++ return rq->scx.flags & SCX_RQ_CAN_STOP_TICK; ++} ++#endif ++ ++/* ++ * Omitted operations: ++ * ++ * - wakeup_preempt: NOOP as it isn't useful in the wakeup path because the task ++ * isn't tied to the CPU at that point. Preemption is implemented by resetting ++ * the victim task's slice to 0 and triggering reschedule on the target CPU. ++ * ++ * - migrate_task_rq: Unnecessary as task to cpu mapping is transient. ++ * ++ * - task_fork/dead: We need fork/dead notifications for all tasks regardless of ++ * their current sched_class. Call them directly from sched core instead. 
++ */ ++DEFINE_SCHED_CLASS(ext) = { ++ .enqueue_task = enqueue_task_scx, ++ .dequeue_task = dequeue_task_scx, ++ .yield_task = yield_task_scx, ++ .yield_to_task = yield_to_task_scx, ++ ++ .wakeup_preempt = wakeup_preempt_scx, ++ ++ .pick_next_task = pick_next_task_scx, ++ ++ .put_prev_task = put_prev_task_scx, ++ .set_next_task = set_next_task_scx, ++ ++ .switch_class = switch_class_scx, ++ ++#ifdef CONFIG_SMP ++ .balance = balance_scx, ++ .select_task_rq = select_task_rq_scx, ++ .task_woken = task_woken_scx, ++ .set_cpus_allowed = set_cpus_allowed_scx, ++ ++ .rq_online = rq_online_scx, ++ .rq_offline = rq_offline_scx, ++#endif ++ ++#ifdef CONFIG_SCHED_CORE ++ .pick_task = pick_task_scx, ++#endif ++ ++ .task_tick = task_tick_scx, ++ ++ .switching_to = switching_to_scx, ++ .switched_from = switched_from_scx, ++ .switched_to = switched_to_scx, ++ .reweight_task = reweight_task_scx, ++ .prio_changed = prio_changed_scx, ++ ++ .update_curr = update_curr_scx, ++ ++#ifdef CONFIG_UCLAMP_TASK ++ .uclamp_enabled = 1, ++#endif ++}; ++ ++static void init_dsq(struct scx_dispatch_q *dsq, u64 dsq_id) ++{ ++ memset(dsq, 0, sizeof(*dsq)); ++ ++ raw_spin_lock_init(&dsq->lock); ++ INIT_LIST_HEAD(&dsq->list); ++ dsq->id = dsq_id; ++} ++ ++static struct scx_dispatch_q *create_dsq(u64 dsq_id, int node) ++{ ++ struct scx_dispatch_q *dsq; ++ int ret; ++ ++ if (dsq_id & SCX_DSQ_FLAG_BUILTIN) ++ return ERR_PTR(-EINVAL); ++ ++ dsq = kmalloc_node(sizeof(*dsq), GFP_KERNEL, node); ++ if (!dsq) ++ return ERR_PTR(-ENOMEM); ++ ++ init_dsq(dsq, dsq_id); ++ ++ ret = rhashtable_insert_fast(&dsq_hash, &dsq->hash_node, ++ dsq_hash_params); ++ if (ret) { ++ kfree(dsq); ++ return ERR_PTR(ret); ++ } ++ return dsq; ++} ++ ++static void free_dsq_irq_workfn(struct irq_work *irq_work) ++{ ++ struct llist_node *to_free = llist_del_all(&dsqs_to_free); ++ struct scx_dispatch_q *dsq, *tmp_dsq; ++ ++ llist_for_each_entry_safe(dsq, tmp_dsq, to_free, free_node) ++ kfree_rcu(dsq, rcu); ++} ++ ++static 
DEFINE_IRQ_WORK(free_dsq_irq_work, free_dsq_irq_workfn); ++ ++static void destroy_dsq(u64 dsq_id) ++{ ++ struct scx_dispatch_q *dsq; ++ unsigned long flags; ++ ++ rcu_read_lock(); ++ ++ dsq = find_user_dsq(dsq_id); ++ if (!dsq) ++ goto out_unlock_rcu; ++ ++ raw_spin_lock_irqsave(&dsq->lock, flags); ++ ++ if (dsq->nr) { ++ scx_ops_error("attempting to destroy in-use dsq 0x%016llx (nr=%u)", ++ dsq->id, dsq->nr); ++ goto out_unlock_dsq; ++ } ++ ++ if (rhashtable_remove_fast(&dsq_hash, &dsq->hash_node, dsq_hash_params)) ++ goto out_unlock_dsq; ++ ++ /* ++ * Mark dead by invalidating ->id to prevent dispatch_enqueue() from ++ * queueing more tasks. As this function can be called from anywhere, ++ * freeing is bounced through an irq work to avoid nesting RCU ++ * operations inside scheduler locks. ++ */ ++ dsq->id = SCX_DSQ_INVALID; ++ llist_add(&dsq->free_node, &dsqs_to_free); ++ irq_work_queue(&free_dsq_irq_work); ++ ++out_unlock_dsq: ++ raw_spin_unlock_irqrestore(&dsq->lock, flags); ++out_unlock_rcu: ++ rcu_read_unlock(); ++} ++ ++ ++/******************************************************************************** ++ * Sysfs interface and ops enable/disable. 
++ */ ++ ++#define SCX_ATTR(_name) \ ++ static struct kobj_attribute scx_attr_##_name = { \ ++ .attr = { .name = __stringify(_name), .mode = 0444 }, \ ++ .show = scx_attr_##_name##_show, \ ++ } ++ ++static ssize_t scx_attr_state_show(struct kobject *kobj, ++ struct kobj_attribute *ka, char *buf) ++{ ++ return sysfs_emit(buf, "%s\n", ++ scx_ops_enable_state_str[scx_ops_enable_state()]); ++} ++SCX_ATTR(state); ++ ++static ssize_t scx_attr_switch_all_show(struct kobject *kobj, ++ struct kobj_attribute *ka, char *buf) ++{ ++ return sysfs_emit(buf, "%d\n", READ_ONCE(scx_switching_all)); ++} ++SCX_ATTR(switch_all); ++ ++static ssize_t scx_attr_nr_rejected_show(struct kobject *kobj, ++ struct kobj_attribute *ka, char *buf) ++{ ++ return sysfs_emit(buf, "%ld\n", atomic_long_read(&scx_nr_rejected)); ++} ++SCX_ATTR(nr_rejected); ++ ++static ssize_t scx_attr_hotplug_seq_show(struct kobject *kobj, ++ struct kobj_attribute *ka, char *buf) ++{ ++ return sysfs_emit(buf, "%ld\n", atomic_long_read(&scx_hotplug_seq)); ++} ++SCX_ATTR(hotplug_seq); ++ ++static struct attribute *scx_global_attrs[] = { ++ &scx_attr_state.attr, ++ &scx_attr_switch_all.attr, ++ &scx_attr_nr_rejected.attr, ++ &scx_attr_hotplug_seq.attr, ++ NULL, ++}; ++ ++static const struct attribute_group scx_global_attr_group = { ++ .attrs = scx_global_attrs, ++}; ++ ++static void scx_kobj_release(struct kobject *kobj) ++{ ++ kfree(kobj); ++} ++ ++static ssize_t scx_attr_ops_show(struct kobject *kobj, ++ struct kobj_attribute *ka, char *buf) ++{ ++ return sysfs_emit(buf, "%s\n", scx_ops.name); ++} ++SCX_ATTR(ops); ++ ++static struct attribute *scx_sched_attrs[] = { ++ &scx_attr_ops.attr, ++ NULL, ++}; ++ATTRIBUTE_GROUPS(scx_sched); ++ ++static const struct kobj_type scx_ktype = { ++ .release = scx_kobj_release, ++ .sysfs_ops = &kobj_sysfs_ops, ++ .default_groups = scx_sched_groups, ++}; ++ ++static int scx_uevent(const struct kobject *kobj, struct kobj_uevent_env *env) ++{ ++ return add_uevent_var(env, "SCXOPS=%s", 
scx_ops.name); ++} ++ ++static const struct kset_uevent_ops scx_uevent_ops = { ++ .uevent = scx_uevent, ++}; ++ ++/* ++ * Used by sched_fork() and __setscheduler_prio() to pick the matching ++ * sched_class. dl/rt are already handled. ++ */ ++bool task_should_scx(struct task_struct *p) ++{ ++ if (!scx_enabled() || ++ unlikely(scx_ops_enable_state() == SCX_OPS_DISABLING)) ++ return false; ++ if (READ_ONCE(scx_switching_all)) ++ return true; ++ return p->policy == SCHED_EXT; ++} ++ ++/** ++ * scx_ops_bypass - [Un]bypass scx_ops and guarantee forward progress ++ * ++ * Bypassing guarantees that all runnable tasks make forward progress without ++ * trusting the BPF scheduler. We can't grab any mutexes or rwsems as they might ++ * be held by tasks that the BPF scheduler is forgetting to run, which ++ * unfortunately also excludes toggling the static branches. ++ * ++ * Let's work around by overriding a couple ops and modifying behaviors based on ++ * the DISABLING state and then cycling the queued tasks through dequeue/enqueue ++ * to force global FIFO scheduling. ++ * ++ * a. ops.enqueue() is ignored and tasks are queued in simple global FIFO order. ++ * ++ * b. ops.dispatch() is ignored. ++ * ++ * c. balance_scx() never sets %SCX_TASK_BAL_KEEP as the slice value can't be ++ * trusted. Whenever a tick triggers, the running task is rotated to the tail ++ * of the queue with core_sched_at touched. ++ * ++ * d. pick_next_task() suppresses zero slice warning. ++ * ++ * e. scx_bpf_kick_cpu() is disabled to avoid irq_work malfunction during PM ++ * operations. ++ * ++ * f. scx_prio_less() reverts to the default core_sched_at order. 
++ */ ++static void scx_ops_bypass(bool bypass) ++{ ++ int depth, cpu; ++ ++ if (bypass) { ++ depth = atomic_inc_return(&scx_ops_bypass_depth); ++ WARN_ON_ONCE(depth <= 0); ++ if (depth != 1) ++ return; ++ } else { ++ depth = atomic_dec_return(&scx_ops_bypass_depth); ++ WARN_ON_ONCE(depth < 0); ++ if (depth != 0) ++ return; ++ } ++ ++ /* ++ * We need to guarantee that no tasks are on the BPF scheduler while ++ * bypassing. Either we see enabled or the enable path sees the ++ * increased bypass_depth before moving tasks to SCX. ++ */ ++ if (!scx_enabled()) ++ return; ++ ++ /* ++ * No task property is changing. We just need to make sure all currently ++ * queued tasks are re-queued according to the new scx_ops_bypassing() ++ * state. As an optimization, walk each rq's runnable_list instead of ++ * the scx_tasks list. ++ * ++ * This function can't trust the scheduler and thus can't use ++ * cpus_read_lock(). Walk all possible CPUs instead of online. ++ */ ++ for_each_possible_cpu(cpu) { ++ struct rq *rq = cpu_rq(cpu); ++ struct rq_flags rf; ++ struct task_struct *p, *n; ++ ++ rq_lock_irqsave(rq, &rf); ++ ++ /* ++ * The use of list_for_each_entry_safe_reverse() is required ++ * because each task is going to be removed from and added back ++ * to the runnable_list during iteration. Because they're added ++ * to the tail of the list, safe reverse iteration can still ++ * visit all nodes. 
++ */ ++ list_for_each_entry_safe_reverse(p, n, &rq->scx.runnable_list, ++ scx.runnable_node) { ++ struct sched_enq_and_set_ctx ctx; ++ ++ /* cycling deq/enq is enough, see the function comment */ ++ sched_deq_and_put_task(p, DEQUEUE_SAVE | DEQUEUE_MOVE, &ctx); ++ sched_enq_and_set_task(&ctx); ++ } ++ ++ rq_unlock_irqrestore(rq, &rf); ++ ++ /* kick to restore ticks */ ++ resched_cpu(cpu); ++ } ++} ++ ++static void free_exit_info(struct scx_exit_info *ei) ++{ ++ kfree(ei->dump); ++ kfree(ei->msg); ++ kfree(ei->bt); ++ kfree(ei); ++} ++ ++static struct scx_exit_info *alloc_exit_info(size_t exit_dump_len) ++{ ++ struct scx_exit_info *ei; ++ ++ ei = kzalloc(sizeof(*ei), GFP_KERNEL); ++ if (!ei) ++ return NULL; ++ ++ ei->bt = kcalloc(SCX_EXIT_BT_LEN, sizeof(ei->bt[0]), GFP_KERNEL); ++ ei->msg = kzalloc(SCX_EXIT_MSG_LEN, GFP_KERNEL); ++ ei->dump = kzalloc(exit_dump_len, GFP_KERNEL); ++ ++ if (!ei->bt || !ei->msg || !ei->dump) { ++ free_exit_info(ei); ++ return NULL; ++ } ++ ++ return ei; ++} ++ ++static const char *scx_exit_reason(enum scx_exit_kind kind) ++{ ++ switch (kind) { ++ case SCX_EXIT_UNREG: ++ return "Scheduler unregistered from user space"; ++ case SCX_EXIT_UNREG_BPF: ++ return "Scheduler unregistered from BPF"; ++ case SCX_EXIT_UNREG_KERN: ++ return "Scheduler unregistered from the main kernel"; ++ case SCX_EXIT_SYSRQ: ++ return "disabled by sysrq-S"; ++ case SCX_EXIT_ERROR: ++ return "runtime error"; ++ case SCX_EXIT_ERROR_BPF: ++ return "scx_bpf_error"; ++ case SCX_EXIT_ERROR_STALL: ++ return "runnable task stall"; ++ default: ++ return ""; ++ } ++} ++ ++static void scx_ops_disable_workfn(struct kthread_work *work) ++{ ++ struct scx_exit_info *ei = scx_exit_info; ++ struct scx_task_iter sti; ++ struct task_struct *p; ++ struct rhashtable_iter rht_iter; ++ struct scx_dispatch_q *dsq; ++ int i, kind; ++ ++ kind = atomic_read(&scx_exit_kind); ++ while (true) { ++ /* ++ * NONE indicates that a new scx_ops has been registered since ++ * disable was scheduled - 
don't kill the new ops. DONE ++ * indicates that the ops has already been disabled. ++ */ ++ if (kind == SCX_EXIT_NONE || kind == SCX_EXIT_DONE) ++ return; ++ if (atomic_try_cmpxchg(&scx_exit_kind, &kind, SCX_EXIT_DONE)) ++ break; ++ } ++ ei->kind = kind; ++ ei->reason = scx_exit_reason(ei->kind); ++ ++ /* guarantee forward progress by bypassing scx_ops */ ++ scx_ops_bypass(true); ++ ++ switch (scx_ops_set_enable_state(SCX_OPS_DISABLING)) { ++ case SCX_OPS_DISABLING: ++ WARN_ONCE(true, "sched_ext: duplicate disabling instance?"); ++ break; ++ case SCX_OPS_DISABLED: ++ pr_warn("sched_ext: ops error detected without ops (%s)\n", ++ scx_exit_info->msg); ++ WARN_ON_ONCE(scx_ops_set_enable_state(SCX_OPS_DISABLED) != ++ SCX_OPS_DISABLING); ++ goto done; ++ default: ++ break; ++ } ++ ++ /* ++ * Here, every runnable task is guaranteed to make forward progress and ++ * we can safely use blocking synchronization constructs. Actually ++ * disable ops. ++ */ ++ mutex_lock(&scx_ops_enable_mutex); ++ ++ static_branch_disable(&__scx_switched_all); ++ WRITE_ONCE(scx_switching_all, false); ++ ++ /* ++ * Avoid racing against fork. See scx_ops_enable() for explanation on ++ * the locking order. ++ */ ++ percpu_down_write(&scx_fork_rwsem); ++ cpus_read_lock(); ++ ++ spin_lock_irq(&scx_tasks_lock); ++ scx_task_iter_init(&sti); ++ /* ++ * Invoke scx_ops_exit_task() on all non-idle tasks, including ++ * TASK_DEAD tasks. Because dead tasks may have a nonzero refcount, ++ * we may not have invoked sched_ext_free() on them by the time a ++ * scheduler is disabled. We must therefore exit the task here, or we'd ++ * fail to invoke ops.exit_task(), as the scheduler will have been ++ * unloaded by the time the task is subsequently exited on the ++ * sched_ext_free() path. 
++ */ ++ while ((p = scx_task_iter_next_locked(&sti, true))) { ++ const struct sched_class *old_class = p->sched_class; ++ struct sched_enq_and_set_ctx ctx; ++ ++ if (READ_ONCE(p->__state) != TASK_DEAD) { ++ sched_deq_and_put_task(p, DEQUEUE_SAVE | DEQUEUE_MOVE, ++ &ctx); ++ ++ p->scx.slice = min_t(u64, p->scx.slice, SCX_SLICE_DFL); ++ __setscheduler_prio(p, p->prio); ++ check_class_changing(task_rq(p), p, old_class); ++ ++ sched_enq_and_set_task(&ctx); ++ ++ check_class_changed(task_rq(p), p, old_class, p->prio); ++ } ++ scx_ops_exit_task(p); ++ } ++ scx_task_iter_exit(&sti); ++ spin_unlock_irq(&scx_tasks_lock); ++ ++ /* no task is on scx, turn off all the switches and flush in-progress calls */ ++ static_branch_disable_cpuslocked(&__scx_ops_enabled); ++ for (i = SCX_OPI_BEGIN; i < SCX_OPI_END; i++) ++ static_branch_disable_cpuslocked(&scx_has_op[i]); ++ static_branch_disable_cpuslocked(&scx_ops_enq_last); ++ static_branch_disable_cpuslocked(&scx_ops_enq_exiting); ++ static_branch_disable_cpuslocked(&scx_ops_cpu_preempt); ++ static_branch_disable_cpuslocked(&scx_builtin_idle_enabled); ++ synchronize_rcu(); ++ ++ cpus_read_unlock(); ++ percpu_up_write(&scx_fork_rwsem); ++ ++ if (ei->kind >= SCX_EXIT_ERROR) { ++ printk(KERN_ERR "sched_ext: BPF scheduler \"%s\" errored, disabling\n", scx_ops.name); ++ ++ if (ei->msg[0] == '\0') ++ printk(KERN_ERR "sched_ext: %s\n", ei->reason); ++ else ++ printk(KERN_ERR "sched_ext: %s (%s)\n", ei->reason, ei->msg); ++ ++ stack_trace_print(ei->bt, ei->bt_len, 2); ++ } ++ ++ if (scx_ops.exit) ++ SCX_CALL_OP(SCX_KF_UNLOCKED, exit, ei); ++ ++ cancel_delayed_work_sync(&scx_watchdog_work); ++ ++ /* ++ * Delete the kobject from the hierarchy eagerly in addition to just ++ * dropping a reference. Otherwise, if the object is deleted ++ * asynchronously, sysfs could observe an object of the same name still ++ * in the hierarchy when another scheduler is loaded. 
++ */ ++ kobject_del(scx_root_kobj); ++ kobject_put(scx_root_kobj); ++ scx_root_kobj = NULL; ++ ++ memset(&scx_ops, 0, sizeof(scx_ops)); ++ ++ rhashtable_walk_enter(&dsq_hash, &rht_iter); ++ do { ++ rhashtable_walk_start(&rht_iter); ++ ++ while ((dsq = rhashtable_walk_next(&rht_iter)) && !IS_ERR(dsq)) ++ destroy_dsq(dsq->id); ++ ++ rhashtable_walk_stop(&rht_iter); ++ } while (dsq == ERR_PTR(-EAGAIN)); ++ rhashtable_walk_exit(&rht_iter); ++ ++ free_percpu(scx_dsp_ctx); ++ scx_dsp_ctx = NULL; ++ scx_dsp_max_batch = 0; ++ ++ free_exit_info(scx_exit_info); ++ scx_exit_info = NULL; ++ ++ mutex_unlock(&scx_ops_enable_mutex); ++ ++ WARN_ON_ONCE(scx_ops_set_enable_state(SCX_OPS_DISABLED) != ++ SCX_OPS_DISABLING); ++done: ++ scx_ops_bypass(false); ++} ++ ++static DEFINE_KTHREAD_WORK(scx_ops_disable_work, scx_ops_disable_workfn); ++ ++static void schedule_scx_ops_disable_work(void) ++{ ++ struct kthread_worker *helper = READ_ONCE(scx_ops_helper); ++ ++ /* ++ * We may be called spuriously before the first bpf_sched_ext_reg(). If ++ * scx_ops_helper isn't set up yet, there's nothing to do. ++ */ ++ if (helper) ++ kthread_queue_work(helper, &scx_ops_disable_work); ++} ++ ++static void scx_ops_disable(enum scx_exit_kind kind) ++{ ++ int none = SCX_EXIT_NONE; ++ ++ if (WARN_ON_ONCE(kind == SCX_EXIT_NONE || kind == SCX_EXIT_DONE)) ++ kind = SCX_EXIT_ERROR; ++ ++ atomic_try_cmpxchg(&scx_exit_kind, &none, kind); ++ ++ schedule_scx_ops_disable_work(); ++} ++ ++static void dump_newline(struct seq_buf *s) ++{ ++ trace_sched_ext_dump(""); ++ ++ /* @s may be zero sized and seq_buf triggers WARN if so */ ++ if (s->size) ++ seq_buf_putc(s, '\n'); ++} ++ ++static __printf(2, 3) void dump_line(struct seq_buf *s, const char *fmt, ...) 
++{ ++ va_list args; ++ ++#ifdef CONFIG_TRACEPOINTS ++ if (trace_sched_ext_dump_enabled()) { ++ /* protected by scx_dump_state()::dump_lock */ ++ static char line_buf[SCX_EXIT_MSG_LEN]; ++ ++ va_start(args, fmt); ++ vscnprintf(line_buf, sizeof(line_buf), fmt, args); ++ va_end(args); ++ ++ trace_sched_ext_dump(line_buf); ++ } ++#endif ++ /* @s may be zero sized and seq_buf triggers WARN if so */ ++ if (s->size) { ++ va_start(args, fmt); ++ seq_buf_vprintf(s, fmt, args); ++ va_end(args); ++ ++ seq_buf_putc(s, '\n'); ++ } ++} ++ ++static void dump_stack_trace(struct seq_buf *s, const char *prefix, ++ const unsigned long *bt, unsigned int len) ++{ ++ unsigned int i; ++ ++ for (i = 0; i < len; i++) ++ dump_line(s, "%s%pS", prefix, (void *)bt[i]); ++} ++ ++static void ops_dump_init(struct seq_buf *s, const char *prefix) ++{ ++ struct scx_dump_data *dd = &scx_dump_data; ++ ++ lockdep_assert_irqs_disabled(); ++ ++ dd->cpu = smp_processor_id(); /* allow scx_bpf_dump() */ ++ dd->first = true; ++ dd->cursor = 0; ++ dd->s = s; ++ dd->prefix = prefix; ++} ++ ++static void ops_dump_flush(void) ++{ ++ struct scx_dump_data *dd = &scx_dump_data; ++ char *line = dd->buf.line; ++ ++ if (!dd->cursor) ++ return; ++ ++ /* ++ * There's something to flush and this is the first line. Insert a blank ++ * line to distinguish ops dump. ++ */ ++ if (dd->first) { ++ dump_newline(dd->s); ++ dd->first = false; ++ } ++ ++ /* ++ * There may be multiple lines in $line. Scan and emit each line ++ * separately. ++ */ ++ while (true) { ++ char *end = line; ++ char c; ++ ++ while (*end != '\n' && *end != '\0') ++ end++; ++ ++ /* ++ * If $line overflowed, it may not have newline at the end. ++ * Always emit with a newline. 
++ */ ++ c = *end; ++ *end = '\0'; ++ dump_line(dd->s, "%s%s", dd->prefix, line); ++ if (c == '\0') ++ break; ++ ++ /* move to the next line */ ++ end++; ++ if (*end == '\0') ++ break; ++ line = end; ++ } ++ ++ dd->cursor = 0; ++} ++ ++static void ops_dump_exit(void) ++{ ++ ops_dump_flush(); ++ scx_dump_data.cpu = -1; ++} ++ ++static void scx_dump_task(struct seq_buf *s, struct scx_dump_ctx *dctx, ++ struct task_struct *p, char marker) ++{ ++ static unsigned long bt[SCX_EXIT_BT_LEN]; ++ char dsq_id_buf[19] = "(n/a)"; ++ unsigned long ops_state = atomic_long_read(&p->scx.ops_state); ++ unsigned int bt_len; ++ ++ if (p->scx.dsq) ++ scnprintf(dsq_id_buf, sizeof(dsq_id_buf), "0x%llx", ++ (unsigned long long)p->scx.dsq->id); ++ ++ dump_newline(s); ++ dump_line(s, " %c%c %s[%d] %+ldms", ++ marker, task_state_to_char(p), p->comm, p->pid, ++ jiffies_delta_msecs(p->scx.runnable_at, dctx->at_jiffies)); ++ dump_line(s, " scx_state/flags=%u/0x%x dsq_flags=0x%x ops_state/qseq=%lu/%lu", ++ scx_get_task_state(p), p->scx.flags & ~SCX_TASK_STATE_MASK, ++ p->scx.dsq_flags, ops_state & SCX_OPSS_STATE_MASK, ++ ops_state >> SCX_OPSS_QSEQ_SHIFT); ++ dump_line(s, " sticky/holding_cpu=%d/%d dsq_id=%s dsq_vtime=%llu", ++ p->scx.sticky_cpu, p->scx.holding_cpu, dsq_id_buf, ++ p->scx.dsq_vtime); ++ dump_line(s, " cpus=%*pb", cpumask_pr_args(p->cpus_ptr)); ++ ++ if (SCX_HAS_OP(dump_task)) { ++ ops_dump_init(s, " "); ++ SCX_CALL_OP(SCX_KF_REST, dump_task, dctx, p); ++ ops_dump_exit(); ++ } ++ ++ bt_len = stack_trace_save_tsk(p, bt, SCX_EXIT_BT_LEN, 1); ++ if (bt_len) { ++ dump_newline(s); ++ dump_stack_trace(s, " ", bt, bt_len); ++ } ++} ++ ++static void scx_dump_state(struct scx_exit_info *ei, size_t dump_len) ++{ ++ static DEFINE_SPINLOCK(dump_lock); ++ static const char trunc_marker[] = "\n\n~~~~ TRUNCATED ~~~~\n"; ++ struct scx_dump_ctx dctx = { ++ .kind = ei->kind, ++ .exit_code = ei->exit_code, ++ .reason = ei->reason, ++ .at_ns = ktime_get_ns(), ++ .at_jiffies = jiffies, ++ }; ++ struct 
seq_buf s; ++ unsigned long flags; ++ char *buf; ++ int cpu; ++ ++ spin_lock_irqsave(&dump_lock, flags); ++ ++ seq_buf_init(&s, ei->dump, dump_len); ++ ++ if (ei->kind == SCX_EXIT_NONE) { ++ dump_line(&s, "Debug dump triggered by %s", ei->reason); ++ } else { ++ dump_line(&s, "%s[%d] triggered exit kind %d:", ++ current->comm, current->pid, ei->kind); ++ dump_line(&s, " %s (%s)", ei->reason, ei->msg); ++ dump_newline(&s); ++ dump_line(&s, "Backtrace:"); ++ dump_stack_trace(&s, " ", ei->bt, ei->bt_len); ++ } ++ ++ if (SCX_HAS_OP(dump)) { ++ ops_dump_init(&s, ""); ++ SCX_CALL_OP(SCX_KF_UNLOCKED, dump, &dctx); ++ ops_dump_exit(); ++ } ++ ++ dump_newline(&s); ++ dump_line(&s, "CPU states"); ++ dump_line(&s, "----------"); ++ ++ for_each_possible_cpu(cpu) { ++ struct rq *rq = cpu_rq(cpu); ++ struct rq_flags rf; ++ struct task_struct *p; ++ struct seq_buf ns; ++ size_t avail, used; ++ bool idle; ++ ++ rq_lock(rq, &rf); ++ ++ idle = list_empty(&rq->scx.runnable_list) && ++ rq->curr->sched_class == &idle_sched_class; ++ ++ if (idle && !SCX_HAS_OP(dump_cpu)) ++ goto next; ++ ++ /* ++ * We don't yet know whether ops.dump_cpu() will produce output ++ * and we may want to skip the default CPU dump if it doesn't. ++ * Use a nested seq_buf to generate the standard dump so that we ++ * can decide whether to commit later. 
++ */ ++ avail = seq_buf_get_buf(&s, &buf); ++ seq_buf_init(&ns, buf, avail); ++ ++ dump_newline(&ns); ++ dump_line(&ns, "CPU %-4d: nr_run=%u flags=0x%x cpu_rel=%d ops_qseq=%lu pnt_seq=%lu", ++ cpu, rq->scx.nr_running, rq->scx.flags, ++ rq->scx.cpu_released, rq->scx.ops_qseq, ++ rq->scx.pnt_seq); ++ dump_line(&ns, " curr=%s[%d] class=%ps", ++ rq->curr->comm, rq->curr->pid, ++ rq->curr->sched_class); ++ if (!cpumask_empty(rq->scx.cpus_to_kick)) ++ dump_line(&ns, " cpus_to_kick : %*pb", ++ cpumask_pr_args(rq->scx.cpus_to_kick)); ++ if (!cpumask_empty(rq->scx.cpus_to_kick_if_idle)) ++ dump_line(&ns, " idle_to_kick : %*pb", ++ cpumask_pr_args(rq->scx.cpus_to_kick_if_idle)); ++ if (!cpumask_empty(rq->scx.cpus_to_preempt)) ++ dump_line(&ns, " cpus_to_preempt: %*pb", ++ cpumask_pr_args(rq->scx.cpus_to_preempt)); ++ if (!cpumask_empty(rq->scx.cpus_to_wait)) ++ dump_line(&ns, " cpus_to_wait : %*pb", ++ cpumask_pr_args(rq->scx.cpus_to_wait)); ++ ++ used = seq_buf_used(&ns); ++ if (SCX_HAS_OP(dump_cpu)) { ++ ops_dump_init(&ns, " "); ++ SCX_CALL_OP(SCX_KF_REST, dump_cpu, &dctx, cpu, idle); ++ ops_dump_exit(); ++ } ++ ++ /* ++ * If idle && nothing generated by ops.dump_cpu(), there's ++ * nothing interesting. Skip. ++ */ ++ if (idle && used == seq_buf_used(&ns)) ++ goto next; ++ ++ /* ++ * $s may already have overflowed when $ns was created. If so, ++ * calling commit on it will trigger BUG. 
++ */ ++ if (avail) { ++ seq_buf_commit(&s, seq_buf_used(&ns)); ++ if (seq_buf_has_overflowed(&ns)) ++ seq_buf_set_overflow(&s); ++ } ++ ++ if (rq->curr->sched_class == &ext_sched_class) ++ scx_dump_task(&s, &dctx, rq->curr, '*'); ++ ++ list_for_each_entry(p, &rq->scx.runnable_list, scx.runnable_node) ++ scx_dump_task(&s, &dctx, p, ' '); ++ next: ++ rq_unlock(rq, &rf); ++ } ++ ++ if (seq_buf_has_overflowed(&s) && dump_len >= sizeof(trunc_marker)) ++ memcpy(ei->dump + dump_len - sizeof(trunc_marker), ++ trunc_marker, sizeof(trunc_marker)); ++ ++ spin_unlock_irqrestore(&dump_lock, flags); ++} ++ ++static void scx_ops_error_irq_workfn(struct irq_work *irq_work) ++{ ++ struct scx_exit_info *ei = scx_exit_info; ++ ++ if (ei->kind >= SCX_EXIT_ERROR) ++ scx_dump_state(ei, scx_ops.exit_dump_len); ++ ++ schedule_scx_ops_disable_work(); ++} ++ ++static DEFINE_IRQ_WORK(scx_ops_error_irq_work, scx_ops_error_irq_workfn); ++ ++static __printf(3, 4) void scx_ops_exit_kind(enum scx_exit_kind kind, ++ s64 exit_code, ++ const char *fmt, ...) ++{ ++ struct scx_exit_info *ei = scx_exit_info; ++ int none = SCX_EXIT_NONE; ++ va_list args; ++ ++ if (!atomic_try_cmpxchg(&scx_exit_kind, &none, kind)) ++ return; ++ ++ ei->exit_code = exit_code; ++ ++ if (kind >= SCX_EXIT_ERROR) ++ ei->bt_len = stack_trace_save(ei->bt, SCX_EXIT_BT_LEN, 1); ++ ++ va_start(args, fmt); ++ vscnprintf(ei->msg, SCX_EXIT_MSG_LEN, fmt, args); ++ va_end(args); ++ ++ /* ++ * Set ei->kind and ->reason for scx_dump_state(). They'll be set again ++ * in scx_ops_disable_workfn(). 
++ */ ++ ei->kind = kind; ++ ei->reason = scx_exit_reason(ei->kind); ++ ++ irq_work_queue(&scx_ops_error_irq_work); ++} ++ ++static struct kthread_worker *scx_create_rt_helper(const char *name) ++{ ++ struct kthread_worker *helper; ++ ++ helper = kthread_create_worker(0, name); ++ if (!IS_ERR_OR_NULL(helper)) ++ sched_set_fifo(helper->task); ++ return IS_ERR(helper) ? NULL : helper; /* callers expect NULL on failure */ ++} ++ ++static void check_hotplug_seq(const struct sched_ext_ops *ops) ++{ ++ unsigned long long global_hotplug_seq; ++ ++ /* ++ * If a hotplug event has occurred between when a scheduler was ++ * initialized, and when we were able to attach, exit and notify user ++ * space about it. ++ */ ++ if (ops->hotplug_seq) { ++ global_hotplug_seq = atomic_long_read(&scx_hotplug_seq); ++ if (ops->hotplug_seq != global_hotplug_seq) { ++ scx_ops_exit(SCX_ECODE_ACT_RESTART | SCX_ECODE_RSN_HOTPLUG, ++ "expected hotplug seq %llu did not match actual %llu", ++ ops->hotplug_seq, global_hotplug_seq); ++ } ++ } ++} ++ ++static int validate_ops(const struct sched_ext_ops *ops) ++{ ++ /* ++ * It doesn't make sense to specify the SCX_OPS_ENQ_LAST flag if the ++ * ops.enqueue() callback isn't implemented.
++ */ ++ if ((ops->flags & SCX_OPS_ENQ_LAST) && !ops->enqueue) { ++ scx_ops_error("SCX_OPS_ENQ_LAST requires ops.enqueue() to be implemented"); ++ return -EINVAL; ++ } ++ ++ return 0; ++} ++ ++static int scx_ops_enable(struct sched_ext_ops *ops) ++{ ++ struct scx_task_iter sti; ++ struct task_struct *p; ++ unsigned long timeout; ++ int i, cpu, ret; ++ ++ if (!cpumask_equal(housekeeping_cpumask(HK_TYPE_DOMAIN), ++ cpu_possible_mask)) { ++ pr_err("sched_ext: Not compatible with \"isolcpus=\" domain isolation\n"); ++ return -EINVAL; ++ } ++ ++ mutex_lock(&scx_ops_enable_mutex); ++ ++ if (!scx_ops_helper) { ++ WRITE_ONCE(scx_ops_helper, ++ scx_create_rt_helper("sched_ext_ops_helper")); ++ if (!scx_ops_helper) { ++ ret = -ENOMEM; ++ goto err_unlock; ++ } ++ } ++ ++ if (scx_ops_enable_state() != SCX_OPS_DISABLED) { ++ ret = -EBUSY; ++ goto err_unlock; ++ } ++ ++ scx_root_kobj = kzalloc(sizeof(*scx_root_kobj), GFP_KERNEL); ++ if (!scx_root_kobj) { ++ ret = -ENOMEM; ++ goto err_unlock; ++ } ++ ++ scx_root_kobj->kset = scx_kset; ++ ret = kobject_init_and_add(scx_root_kobj, &scx_ktype, NULL, "root"); ++ if (ret < 0) ++ goto err; ++ ++ scx_exit_info = alloc_exit_info(ops->exit_dump_len); ++ if (!scx_exit_info) { ++ ret = -ENOMEM; ++ goto err_del; ++ } ++ ++ /* ++ * Set scx_ops, transition to PREPPING and clear exit info to arm the ++ * disable path. Failure triggers full disabling from here on. ++ */ ++ scx_ops = *ops; ++ ++ WARN_ON_ONCE(scx_ops_set_enable_state(SCX_OPS_PREPPING) != ++ SCX_OPS_DISABLED); ++ ++ atomic_set(&scx_exit_kind, SCX_EXIT_NONE); ++ scx_warned_zero_slice = false; ++ ++ atomic_long_set(&scx_nr_rejected, 0); ++ ++ for_each_possible_cpu(cpu) ++ cpu_rq(cpu)->scx.cpuperf_target = SCX_CPUPERF_ONE; ++ ++ /* ++ * Keep CPUs stable during enable so that the BPF scheduler can track ++ * online CPUs by watching ->on/offline_cpu() after ->init().
++ */ ++ cpus_read_lock(); ++ ++ if (scx_ops.init) { ++ ret = SCX_CALL_OP_RET(SCX_KF_UNLOCKED, init); ++ if (ret) { ++ ret = ops_sanitize_err("init", ret); ++ goto err_disable_unlock_cpus; ++ } ++ } ++ ++ for (i = SCX_OPI_CPU_HOTPLUG_BEGIN; i < SCX_OPI_CPU_HOTPLUG_END; i++) ++ if (((void (**)(void))ops)[i]) ++ static_branch_enable_cpuslocked(&scx_has_op[i]); ++ ++ cpus_read_unlock(); ++ ++ ret = validate_ops(ops); ++ if (ret) ++ goto err_disable; ++ ++ WARN_ON_ONCE(scx_dsp_ctx); ++ scx_dsp_max_batch = ops->dispatch_max_batch ?: SCX_DSP_DFL_MAX_BATCH; ++ scx_dsp_ctx = __alloc_percpu(struct_size_t(struct scx_dsp_ctx, buf, ++ scx_dsp_max_batch), ++ __alignof__(struct scx_dsp_ctx)); ++ if (!scx_dsp_ctx) { ++ ret = -ENOMEM; ++ goto err_disable; ++ } ++ ++ if (ops->timeout_ms) ++ timeout = msecs_to_jiffies(ops->timeout_ms); ++ else ++ timeout = SCX_WATCHDOG_MAX_TIMEOUT; ++ ++ WRITE_ONCE(scx_watchdog_timeout, timeout); ++ WRITE_ONCE(scx_watchdog_timestamp, jiffies); ++ queue_delayed_work(system_unbound_wq, &scx_watchdog_work, ++ scx_watchdog_timeout / 2); ++ ++ /* ++ * Lock out forks before opening the floodgate so that they don't wander ++ * into the operations prematurely. ++ * ++ * We don't need to keep the CPUs stable but grab cpus_read_lock() to ++ * ease future locking changes for cgroup support.
++ * ++ * Note that cpu_hotplug_lock must nest inside scx_fork_rwsem due to the ++ * following dependency chain: ++ * ++ * scx_fork_rwsem --> pernet_ops_rwsem --> cpu_hotplug_lock ++ */ ++ percpu_down_write(&scx_fork_rwsem); ++ cpus_read_lock(); ++ ++ check_hotplug_seq(ops); ++ ++ for (i = SCX_OPI_NORMAL_BEGIN; i < SCX_OPI_NORMAL_END; i++) ++ if (((void (**)(void))ops)[i]) ++ static_branch_enable_cpuslocked(&scx_has_op[i]); ++ ++ if (ops->flags & SCX_OPS_ENQ_LAST) ++ static_branch_enable_cpuslocked(&scx_ops_enq_last); ++ ++ if (ops->flags & SCX_OPS_ENQ_EXITING) ++ static_branch_enable_cpuslocked(&scx_ops_enq_exiting); ++ if (scx_ops.cpu_acquire || scx_ops.cpu_release) ++ static_branch_enable_cpuslocked(&scx_ops_cpu_preempt); ++ ++ if (!ops->update_idle || (ops->flags & SCX_OPS_KEEP_BUILTIN_IDLE)) { ++ reset_idle_masks(); ++ static_branch_enable_cpuslocked(&scx_builtin_idle_enabled); ++ } else { ++ static_branch_disable_cpuslocked(&scx_builtin_idle_enabled); ++ } ++ ++ static_branch_enable_cpuslocked(&__scx_ops_enabled); ++ ++ /* ++ * Enable ops for every task. Fork is excluded by scx_fork_rwsem ++ * preventing new tasks from being added. No need to exclude tasks ++ * leaving as sched_ext_free() can handle both prepped and enabled ++ * tasks. Prep all tasks first and then enable them with preemption ++ * disabled. 
++ */ ++ spin_lock_irq(&scx_tasks_lock); ++ ++ scx_task_iter_init(&sti); ++ while ((p = scx_task_iter_next_locked(&sti, false))) { ++ get_task_struct(p); ++ scx_task_iter_rq_unlock(&sti); ++ spin_unlock_irq(&scx_tasks_lock); ++ ++ ret = scx_ops_init_task(p, task_group(p), false); ++ if (ret) { ++ put_task_struct(p); ++ spin_lock_irq(&scx_tasks_lock); ++ scx_task_iter_exit(&sti); ++ spin_unlock_irq(&scx_tasks_lock); ++ pr_err("sched_ext: ops.init_task() failed (%d) for %s[%d] while loading\n", ++ ret, p->comm, p->pid); ++ goto err_disable_unlock_all; ++ } ++ ++ put_task_struct(p); ++ spin_lock_irq(&scx_tasks_lock); ++ } ++ scx_task_iter_exit(&sti); ++ ++ /* ++ * All tasks are prepped but are still ops-disabled. Ensure that ++ * %current can't be scheduled out and switch everyone. ++ * preempt_disable() is necessary because we can't guarantee that ++ * %current won't be starved if scheduled out while switching. ++ */ ++ preempt_disable(); ++ ++ /* ++ * From here on, the disable path must assume that tasks have ops ++ * enabled and need to be recovered. ++ * ++ * Transition to ENABLING fails iff the BPF scheduler has already ++ * triggered scx_bpf_error(). Returning an error code here would lose ++ * the recorded error information. Exit indicating success so that the ++ * error is notified through ops.exit() with all the details. ++ */ ++ if (!scx_ops_tryset_enable_state(SCX_OPS_ENABLING, SCX_OPS_PREPPING)) { ++ preempt_enable(); ++ spin_unlock_irq(&scx_tasks_lock); ++ WARN_ON_ONCE(atomic_read(&scx_exit_kind) == SCX_EXIT_NONE); ++ ret = 0; ++ goto err_disable_unlock_all; ++ } ++ ++ /* ++ * We're fully committed and can't fail. The PREPPED -> ENABLED ++ * transitions here are synchronized against sched_ext_free() through ++ * scx_tasks_lock. 
++ */ ++ WRITE_ONCE(scx_switching_all, !(ops->flags & SCX_OPS_SWITCH_PARTIAL)); ++ ++ scx_task_iter_init(&sti); ++ while ((p = scx_task_iter_next_locked(&sti, false))) { ++ const struct sched_class *old_class = p->sched_class; ++ struct sched_enq_and_set_ctx ctx; ++ ++ sched_deq_and_put_task(p, DEQUEUE_SAVE | DEQUEUE_MOVE, &ctx); ++ ++ scx_set_task_state(p, SCX_TASK_READY); ++ __setscheduler_prio(p, p->prio); ++ check_class_changing(task_rq(p), p, old_class); ++ ++ sched_enq_and_set_task(&ctx); ++ ++ check_class_changed(task_rq(p), p, old_class, p->prio); ++ } ++ scx_task_iter_exit(&sti); ++ ++ spin_unlock_irq(&scx_tasks_lock); ++ preempt_enable(); ++ cpus_read_unlock(); ++ percpu_up_write(&scx_fork_rwsem); ++ ++ /* see above ENABLING transition for the explanation on exiting with 0 */ ++ if (!scx_ops_tryset_enable_state(SCX_OPS_ENABLED, SCX_OPS_ENABLING)) { ++ WARN_ON_ONCE(atomic_read(&scx_exit_kind) == SCX_EXIT_NONE); ++ ret = 0; ++ goto err_disable; ++ } ++ ++ if (!(ops->flags & SCX_OPS_SWITCH_PARTIAL)) ++ static_branch_enable(&__scx_switched_all); ++ ++ kobject_uevent(scx_root_kobj, KOBJ_ADD); ++ mutex_unlock(&scx_ops_enable_mutex); ++ ++ return 0; ++ ++err_del: ++ kobject_del(scx_root_kobj); ++err: ++ kobject_put(scx_root_kobj); ++ scx_root_kobj = NULL; ++ if (scx_exit_info) { ++ free_exit_info(scx_exit_info); ++ scx_exit_info = NULL; ++ } ++err_unlock: ++ mutex_unlock(&scx_ops_enable_mutex); ++ return ret; ++ ++err_disable_unlock_all: ++ percpu_up_write(&scx_fork_rwsem); ++err_disable_unlock_cpus: ++ cpus_read_unlock(); ++err_disable: ++ mutex_unlock(&scx_ops_enable_mutex); ++ /* must be fully disabled before returning */ ++ scx_ops_disable(SCX_EXIT_ERROR); ++ kthread_flush_work(&scx_ops_disable_work); ++ return ret; ++} ++ ++ ++/******************************************************************************** ++ * bpf_struct_ops plumbing. 
++ */ ++#include ++#include ++#include ++ ++extern struct btf *btf_vmlinux; ++static const struct btf_type *task_struct_type; ++static u32 task_struct_type_id; ++ ++static bool set_arg_maybe_null(const char *op, int arg_n, int off, int size, ++ enum bpf_access_type type, ++ const struct bpf_prog *prog, ++ struct bpf_insn_access_aux *info) ++{ ++ struct btf *btf = bpf_get_btf_vmlinux(); ++ const struct bpf_struct_ops_desc *st_ops_desc; ++ const struct btf_member *member; ++ const struct btf_type *t; ++ u32 btf_id, member_idx; ++ const char *mname; ++ ++ /* struct_ops op args are all sequential, 64-bit numbers */ ++ if (off != arg_n * sizeof(__u64)) ++ return false; ++ ++ /* btf_id should be the type id of struct sched_ext_ops */ ++ btf_id = prog->aux->attach_btf_id; ++ st_ops_desc = bpf_struct_ops_find(btf, btf_id); ++ if (!st_ops_desc) ++ return false; ++ ++ /* BTF type of struct sched_ext_ops */ ++ t = st_ops_desc->type; ++ ++ member_idx = prog->expected_attach_type; ++ if (member_idx >= btf_type_vlen(t)) ++ return false; ++ ++ /* ++ * Get the member name of this struct_ops program, which corresponds to ++ * a field in struct sched_ext_ops. For example, the member name of the ++ * dispatch struct_ops program (callback) is "dispatch". ++ */ ++ member = &btf_type_member(t)[member_idx]; ++ mname = btf_name_by_offset(btf_vmlinux, member->name_off); ++ ++ if (!strcmp(mname, op)) { ++ /* ++ * The value is a pointer to a type (struct task_struct) given ++ * by a BTF ID (PTR_TO_BTF_ID). It is trusted (PTR_TRUSTED), ++ * however, can be a NULL (PTR_MAYBE_NULL). The BPF program ++ * should check the pointer to make sure it is not NULL before ++ * using it, or the verifier will reject the program. ++ * ++ * Longer term, this is something that should be addressed by ++ * BTF, and be fully contained within the verifier. 
++ */ ++ info->reg_type = PTR_MAYBE_NULL | PTR_TO_BTF_ID | PTR_TRUSTED; ++ info->btf = btf_vmlinux; ++ info->btf_id = task_struct_type_id; ++ ++ return true; ++ } ++ ++ return false; ++} ++ ++static bool bpf_scx_is_valid_access(int off, int size, ++ enum bpf_access_type type, ++ const struct bpf_prog *prog, ++ struct bpf_insn_access_aux *info) ++{ ++ if (type != BPF_READ) ++ return false; ++ if (set_arg_maybe_null("dispatch", 1, off, size, type, prog, info) || ++ set_arg_maybe_null("yield", 1, off, size, type, prog, info)) ++ return true; ++ if (off < 0 || off >= sizeof(__u64) * MAX_BPF_FUNC_ARGS) ++ return false; ++ if (off % size != 0) ++ return false; ++ ++ return btf_ctx_access(off, size, type, prog, info); ++} ++ ++static int bpf_scx_btf_struct_access(struct bpf_verifier_log *log, ++ const struct bpf_reg_state *reg, int off, ++ int size) ++{ ++ const struct btf_type *t; ++ ++ t = btf_type_by_id(reg->btf, reg->btf_id); ++ if (t == task_struct_type) { ++ if (off >= offsetof(struct task_struct, scx.slice) && ++ off + size <= offsetofend(struct task_struct, scx.slice)) ++ return SCALAR_VALUE; ++ if (off >= offsetof(struct task_struct, scx.dsq_vtime) && ++ off + size <= offsetofend(struct task_struct, scx.dsq_vtime)) ++ return SCALAR_VALUE; ++ if (off >= offsetof(struct task_struct, scx.disallow) && ++ off + size <= offsetofend(struct task_struct, scx.disallow)) ++ return SCALAR_VALUE; ++ } ++ ++ return -EACCES; ++} ++ ++static const struct bpf_func_proto * ++bpf_scx_get_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) ++{ ++ switch (func_id) { ++ case BPF_FUNC_task_storage_get: ++ return &bpf_task_storage_get_proto; ++ case BPF_FUNC_task_storage_delete: ++ return &bpf_task_storage_delete_proto; ++ default: ++ return bpf_base_func_proto(func_id, prog); ++ } ++} ++ ++static const struct bpf_verifier_ops bpf_scx_verifier_ops = { ++ .get_func_proto = bpf_scx_get_func_proto, ++ .is_valid_access = bpf_scx_is_valid_access, ++ .btf_struct_access = 
bpf_scx_btf_struct_access, ++}; ++ ++static int bpf_scx_init_member(const struct btf_type *t, ++ const struct btf_member *member, ++ void *kdata, const void *udata) ++{ ++ const struct sched_ext_ops *uops = udata; ++ struct sched_ext_ops *ops = kdata; ++ u32 moff = __btf_member_bit_offset(t, member) / 8; ++ int ret; ++ ++ switch (moff) { ++ case offsetof(struct sched_ext_ops, dispatch_max_batch): ++ if (*(u32 *)(udata + moff) > INT_MAX) ++ return -E2BIG; ++ ops->dispatch_max_batch = *(u32 *)(udata + moff); ++ return 1; ++ case offsetof(struct sched_ext_ops, flags): ++ if (*(u64 *)(udata + moff) & ~SCX_OPS_ALL_FLAGS) ++ return -EINVAL; ++ ops->flags = *(u64 *)(udata + moff); ++ return 1; ++ case offsetof(struct sched_ext_ops, name): ++ ret = bpf_obj_name_cpy(ops->name, uops->name, ++ sizeof(ops->name)); ++ if (ret < 0) ++ return ret; ++ if (ret == 0) ++ return -EINVAL; ++ return 1; ++ case offsetof(struct sched_ext_ops, timeout_ms): ++ if (msecs_to_jiffies(*(u32 *)(udata + moff)) > ++ SCX_WATCHDOG_MAX_TIMEOUT) ++ return -E2BIG; ++ ops->timeout_ms = *(u32 *)(udata + moff); ++ return 1; ++ case offsetof(struct sched_ext_ops, exit_dump_len): ++ ops->exit_dump_len = ++ *(u32 *)(udata + moff) ?: SCX_EXIT_DUMP_DFL_LEN; ++ return 1; ++ case offsetof(struct sched_ext_ops, hotplug_seq): ++ ops->hotplug_seq = *(u64 *)(udata + moff); ++ return 1; ++ } ++ ++ return 0; ++} ++ ++static int bpf_scx_check_member(const struct btf_type *t, ++ const struct btf_member *member, ++ const struct bpf_prog *prog) ++{ ++ u32 moff = __btf_member_bit_offset(t, member) / 8; ++ ++ switch (moff) { ++ case offsetof(struct sched_ext_ops, init_task): ++ case offsetof(struct sched_ext_ops, cpu_online): ++ case offsetof(struct sched_ext_ops, cpu_offline): ++ case offsetof(struct sched_ext_ops, init): ++ case offsetof(struct sched_ext_ops, exit): ++ break; ++ default: ++ if (prog->sleepable) ++ return -EINVAL; ++ } ++ ++ return 0; ++} ++ ++static int bpf_scx_reg(void *kdata) ++{ ++ return 
scx_ops_enable(kdata); ++} ++ ++static void bpf_scx_unreg(void *kdata) ++{ ++ scx_ops_disable(SCX_EXIT_UNREG); ++ kthread_flush_work(&scx_ops_disable_work); ++} ++ ++static int bpf_scx_init(struct btf *btf) ++{ ++ s32 type_id; ++ ++ type_id = btf_find_by_name_kind(btf, "task_struct", BTF_KIND_STRUCT); ++ if (type_id < 0) ++ return -EINVAL; ++ task_struct_type = btf_type_by_id(btf, type_id); ++ task_struct_type_id = type_id; ++ ++ return 0; ++} ++ ++static int bpf_scx_update(void *kdata, void *old_kdata) ++{ ++ /* ++ * sched_ext does not support updating the actively-loaded BPF ++ * scheduler, as registering a BPF scheduler can always fail if the ++ * scheduler returns an error code for e.g. ops.init(), ops.init_task(), ++ * etc. Similarly, we can always race with unregistration happening ++ * elsewhere, such as with sysrq. ++ */ ++ return -EOPNOTSUPP; ++} ++ ++static int bpf_scx_validate(void *kdata) ++{ ++ return 0; ++} ++ ++static s32 select_cpu_stub(struct task_struct *p, s32 prev_cpu, u64 wake_flags) { return -EINVAL; } ++static void enqueue_stub(struct task_struct *p, u64 enq_flags) {} ++static void dequeue_stub(struct task_struct *p, u64 enq_flags) {} ++static void dispatch_stub(s32 prev_cpu, struct task_struct *p) {} ++static void runnable_stub(struct task_struct *p, u64 enq_flags) {} ++static void running_stub(struct task_struct *p) {} ++static void stopping_stub(struct task_struct *p, bool runnable) {} ++static void quiescent_stub(struct task_struct *p, u64 deq_flags) {} ++static bool yield_stub(struct task_struct *from, struct task_struct *to) { return false; } ++static bool core_sched_before_stub(struct task_struct *a, struct task_struct *b) { return false; } ++static void set_weight_stub(struct task_struct *p, u32 weight) {} ++static void set_cpumask_stub(struct task_struct *p, const struct cpumask *mask) {} ++static void update_idle_stub(s32 cpu, bool idle) {} ++static void cpu_acquire_stub(s32 cpu, struct scx_cpu_acquire_args *args) {} ++static void 
cpu_release_stub(s32 cpu, struct scx_cpu_release_args *args) {} ++static s32 init_task_stub(struct task_struct *p, struct scx_init_task_args *args) { return -EINVAL; } ++static void exit_task_stub(struct task_struct *p, struct scx_exit_task_args *args) {} ++static void enable_stub(struct task_struct *p) {} ++static void disable_stub(struct task_struct *p) {} ++static void cpu_online_stub(s32 cpu) {} ++static void cpu_offline_stub(s32 cpu) {} ++static s32 init_stub(void) { return -EINVAL; } ++static void exit_stub(struct scx_exit_info *info) {} ++ ++static struct sched_ext_ops __bpf_ops_sched_ext_ops = { ++ .select_cpu = select_cpu_stub, ++ .enqueue = enqueue_stub, ++ .dequeue = dequeue_stub, ++ .dispatch = dispatch_stub, ++ .runnable = runnable_stub, ++ .running = running_stub, ++ .stopping = stopping_stub, ++ .quiescent = quiescent_stub, ++ .yield = yield_stub, ++ .core_sched_before = core_sched_before_stub, ++ .set_weight = set_weight_stub, ++ .set_cpumask = set_cpumask_stub, ++ .update_idle = update_idle_stub, ++ .cpu_acquire = cpu_acquire_stub, ++ .cpu_release = cpu_release_stub, ++ .init_task = init_task_stub, ++ .exit_task = exit_task_stub, ++ .enable = enable_stub, ++ .disable = disable_stub, ++ .cpu_online = cpu_online_stub, ++ .cpu_offline = cpu_offline_stub, ++ .init = init_stub, ++ .exit = exit_stub, ++}; ++ ++static struct bpf_struct_ops bpf_sched_ext_ops = { ++ .verifier_ops = &bpf_scx_verifier_ops, ++ .reg = bpf_scx_reg, ++ .unreg = bpf_scx_unreg, ++ .check_member = bpf_scx_check_member, ++ .init_member = bpf_scx_init_member, ++ .init = bpf_scx_init, ++ .update = bpf_scx_update, ++ .validate = bpf_scx_validate, ++ .name = "sched_ext_ops", ++ .owner = THIS_MODULE, ++ .cfi_stubs = &__bpf_ops_sched_ext_ops ++}; ++ ++ ++/******************************************************************************** ++ * System integration and init. 
++ */ ++ ++static void sysrq_handle_sched_ext_reset(u8 key) ++{ ++ if (scx_ops_helper) ++ scx_ops_disable(SCX_EXIT_SYSRQ); ++ else ++ pr_info("sched_ext: BPF scheduler not yet used\n"); ++} ++ ++static const struct sysrq_key_op sysrq_sched_ext_reset_op = { ++ .handler = sysrq_handle_sched_ext_reset, ++ .help_msg = "reset-sched-ext(S)", ++ .action_msg = "Disable sched_ext and revert all tasks to CFS", ++ .enable_mask = SYSRQ_ENABLE_RTNICE, ++}; ++ ++static void sysrq_handle_sched_ext_dump(u8 key) ++{ ++ struct scx_exit_info ei = { .kind = SCX_EXIT_NONE, .reason = "SysRq-D" }; ++ ++ if (scx_enabled()) ++ scx_dump_state(&ei, 0); ++} ++ ++static const struct sysrq_key_op sysrq_sched_ext_dump_op = { ++ .handler = sysrq_handle_sched_ext_dump, ++ .help_msg = "dump-sched-ext(D)", ++ .action_msg = "Trigger sched_ext debug dump", ++ .enable_mask = SYSRQ_ENABLE_RTNICE, ++}; ++ ++static bool can_skip_idle_kick(struct rq *rq) ++{ ++ lockdep_assert_rq_held(rq); ++ ++ /* ++ * We can skip idle kicking if @rq is going to go through at least one ++ * full SCX scheduling cycle before going idle. Just checking whether ++ * curr is not idle is insufficient because we could be racing ++ * balance_one() trying to pull the next task from a remote rq, which ++ * may fail, and @rq may become idle afterwards. ++ * ++ * The race window is small and we don't and can't guarantee that @rq is ++ * only kicked while idle anyway. Skip only when sure. ++ */ ++ return !is_idle_task(rq->curr) && !(rq->scx.flags & SCX_RQ_IN_BALANCE); ++} ++ ++static bool kick_one_cpu(s32 cpu, struct rq *this_rq, unsigned long *pseqs) ++{ ++ struct rq *rq = cpu_rq(cpu); ++ struct scx_rq *this_scx = &this_rq->scx; ++ bool should_wait = false; ++ unsigned long flags; ++ ++ raw_spin_rq_lock_irqsave(rq, flags); ++ ++ /* ++ * During CPU hotplug, a CPU may depend on kicking itself to make ++ * forward progress. Allow kicking self regardless of online state. 
++ */ ++ if (cpu_online(cpu) || cpu == cpu_of(this_rq)) { ++ if (cpumask_test_cpu(cpu, this_scx->cpus_to_preempt)) { ++ if (rq->curr->sched_class == &ext_sched_class) ++ rq->curr->scx.slice = 0; ++ cpumask_clear_cpu(cpu, this_scx->cpus_to_preempt); ++ } ++ ++ if (cpumask_test_cpu(cpu, this_scx->cpus_to_wait)) { ++ pseqs[cpu] = rq->scx.pnt_seq; ++ should_wait = true; ++ } ++ ++ resched_curr(rq); ++ } else { ++ cpumask_clear_cpu(cpu, this_scx->cpus_to_preempt); ++ cpumask_clear_cpu(cpu, this_scx->cpus_to_wait); ++ } ++ ++ raw_spin_rq_unlock_irqrestore(rq, flags); ++ ++ return should_wait; ++} ++ ++static void kick_one_cpu_if_idle(s32 cpu, struct rq *this_rq) ++{ ++ struct rq *rq = cpu_rq(cpu); ++ unsigned long flags; ++ ++ raw_spin_rq_lock_irqsave(rq, flags); ++ ++ if (!can_skip_idle_kick(rq) && ++ (cpu_online(cpu) || cpu == cpu_of(this_rq))) ++ resched_curr(rq); ++ ++ raw_spin_rq_unlock_irqrestore(rq, flags); ++} ++ ++static void kick_cpus_irq_workfn(struct irq_work *irq_work) ++{ ++ struct rq *this_rq = this_rq(); ++ struct scx_rq *this_scx = &this_rq->scx; ++ unsigned long *pseqs = this_cpu_ptr(scx_kick_cpus_pnt_seqs); ++ bool should_wait = false; ++ s32 cpu; ++ ++ for_each_cpu(cpu, this_scx->cpus_to_kick) { ++ should_wait |= kick_one_cpu(cpu, this_rq, pseqs); ++ cpumask_clear_cpu(cpu, this_scx->cpus_to_kick); ++ cpumask_clear_cpu(cpu, this_scx->cpus_to_kick_if_idle); ++ } ++ ++ for_each_cpu(cpu, this_scx->cpus_to_kick_if_idle) { ++ kick_one_cpu_if_idle(cpu, this_rq); ++ cpumask_clear_cpu(cpu, this_scx->cpus_to_kick_if_idle); ++ } ++ ++ if (!should_wait) ++ return; ++ ++ for_each_cpu(cpu, this_scx->cpus_to_wait) { ++ unsigned long *wait_pnt_seq = &cpu_rq(cpu)->scx.pnt_seq; ++ ++ if (cpu != cpu_of(this_rq)) { ++ /* ++ * Pairs with smp_store_release() issued by this CPU in ++ * scx_next_task_picked() on the resched path. 
++ * ++ * We busy-wait here to guarantee that no other task can ++ * be scheduled on our core before the target CPU has ++ * entered the resched path. ++ */ ++ while (smp_load_acquire(wait_pnt_seq) == pseqs[cpu]) ++ cpu_relax(); ++ } ++ ++ cpumask_clear_cpu(cpu, this_scx->cpus_to_wait); ++ } ++} ++ ++/** ++ * print_scx_info - print out sched_ext scheduler state ++ * @log_lvl: the log level to use when printing ++ * @p: target task ++ * ++ * If a sched_ext scheduler is enabled, print the name and state of the ++ * scheduler. If @p is on sched_ext, print further information about the task. ++ * ++ * This function can be safely called on any task as long as the task_struct ++ * itself is accessible. While safe, this function isn't synchronized and may ++ * print out mixups or garbages of limited length. ++ */ ++void print_scx_info(const char *log_lvl, struct task_struct *p) ++{ ++ enum scx_ops_enable_state state = scx_ops_enable_state(); ++ const char *all = READ_ONCE(scx_switching_all) ? "+all" : ""; ++ char runnable_at_buf[22] = "?"; ++ struct sched_class *class; ++ unsigned long runnable_at; ++ ++ if (state == SCX_OPS_DISABLED) ++ return; ++ ++ /* ++ * Carefully check if the task was running on sched_ext, and then ++ * carefully copy the time it's been runnable, and its state. 
++ */ ++ if (copy_from_kernel_nofault(&class, &p->sched_class, sizeof(class)) || ++ class != &ext_sched_class) { ++ printk("%sSched_ext: %s (%s%s)", log_lvl, scx_ops.name, ++ scx_ops_enable_state_str[state], all); ++ return; ++ } ++ ++ if (!copy_from_kernel_nofault(&runnable_at, &p->scx.runnable_at, ++ sizeof(runnable_at))) ++ scnprintf(runnable_at_buf, sizeof(runnable_at_buf), "%+ldms", ++ jiffies_delta_msecs(runnable_at, jiffies)); ++ ++ /* print everything onto one line to conserve console space */ ++ printk("%sSched_ext: %s (%s%s), task: runnable_at=%s", ++ log_lvl, scx_ops.name, scx_ops_enable_state_str[state], all, ++ runnable_at_buf); ++} ++ ++static int scx_pm_handler(struct notifier_block *nb, unsigned long event, void *ptr) ++{ ++ /* ++ * SCX schedulers often have userspace components which are sometimes ++ * involved in critical scheduling paths. PM operations involve freezing ++ * userspace which can lead to scheduling misbehaviors including stalls. ++ * Let's bypass while PM operations are in progress. ++ */ ++ switch (event) { ++ case PM_HIBERNATION_PREPARE: ++ case PM_SUSPEND_PREPARE: ++ case PM_RESTORE_PREPARE: ++ scx_ops_bypass(true); ++ break; ++ case PM_POST_HIBERNATION: ++ case PM_POST_SUSPEND: ++ case PM_POST_RESTORE: ++ scx_ops_bypass(false); ++ break; ++ } ++ ++ return NOTIFY_OK; ++} ++ ++static struct notifier_block scx_pm_notifier = { ++ .notifier_call = scx_pm_handler, ++}; ++ ++void __init init_sched_ext_class(void) ++{ ++ s32 cpu, v; ++ ++ /* ++ * The following is to prevent the compiler from optimizing out the enum ++ * definitions so that BPF scheduler implementations can use them ++ * through the generated vmlinux.h. 
++ */ ++ WRITE_ONCE(v, SCX_ENQ_WAKEUP | SCX_DEQ_SLEEP | SCX_KICK_PREEMPT); ++ ++ BUG_ON(rhashtable_init(&dsq_hash, &dsq_hash_params)); ++ init_dsq(&scx_dsq_global, SCX_DSQ_GLOBAL); ++#ifdef CONFIG_SMP ++ BUG_ON(!alloc_cpumask_var(&idle_masks.cpu, GFP_KERNEL)); ++ BUG_ON(!alloc_cpumask_var(&idle_masks.smt, GFP_KERNEL)); ++#endif ++ scx_kick_cpus_pnt_seqs = ++ __alloc_percpu(sizeof(scx_kick_cpus_pnt_seqs[0]) * nr_cpu_ids, ++ __alignof__(scx_kick_cpus_pnt_seqs[0])); ++ BUG_ON(!scx_kick_cpus_pnt_seqs); ++ ++ for_each_possible_cpu(cpu) { ++ struct rq *rq = cpu_rq(cpu); ++ ++ init_dsq(&rq->scx.local_dsq, SCX_DSQ_LOCAL); ++ INIT_LIST_HEAD(&rq->scx.runnable_list); ++ INIT_LIST_HEAD(&rq->scx.ddsp_deferred_locals); ++ ++ BUG_ON(!zalloc_cpumask_var(&rq->scx.cpus_to_kick, GFP_KERNEL)); ++ BUG_ON(!zalloc_cpumask_var(&rq->scx.cpus_to_kick_if_idle, GFP_KERNEL)); ++ BUG_ON(!zalloc_cpumask_var(&rq->scx.cpus_to_preempt, GFP_KERNEL)); ++ BUG_ON(!zalloc_cpumask_var(&rq->scx.cpus_to_wait, GFP_KERNEL)); ++ init_irq_work(&rq->scx.deferred_irq_work, deferred_irq_workfn); ++ init_irq_work(&rq->scx.kick_cpus_irq_work, kick_cpus_irq_workfn); ++ ++ if (cpu_online(cpu)) ++ cpu_rq(cpu)->scx.flags |= SCX_RQ_ONLINE; ++ } ++ ++ register_sysrq_key('S', &sysrq_sched_ext_reset_op); ++ register_sysrq_key('D', &sysrq_sched_ext_dump_op); ++ INIT_DELAYED_WORK(&scx_watchdog_work, scx_watchdog_workfn); ++} ++ ++ ++/******************************************************************************** ++ * Helpers that can be called from the BPF scheduler. ++ */ ++#include ++ ++__bpf_kfunc_start_defs(); ++ ++/** ++ * scx_bpf_create_dsq - Create a custom DSQ ++ * @dsq_id: DSQ to create ++ * @node: NUMA node to allocate from ++ * ++ * Create a custom DSQ identified by @dsq_id. Can be called from any sleepable ++ * scx callback, and any BPF_PROG_TYPE_SYSCALL prog. 
++ */ ++__bpf_kfunc s32 scx_bpf_create_dsq(u64 dsq_id, s32 node) ++{ ++ if (unlikely(node >= (int)nr_node_ids || ++ (node < 0 && node != NUMA_NO_NODE))) ++ return -EINVAL; ++ return PTR_ERR_OR_ZERO(create_dsq(dsq_id, node)); ++} ++ ++__bpf_kfunc_end_defs(); ++ ++BTF_KFUNCS_START(scx_kfunc_ids_sleepable) ++BTF_ID_FLAGS(func, scx_bpf_create_dsq, KF_SLEEPABLE) ++BTF_KFUNCS_END(scx_kfunc_ids_sleepable) ++ ++static const struct btf_kfunc_id_set scx_kfunc_set_sleepable = { ++ .owner = THIS_MODULE, ++ .set = &scx_kfunc_ids_sleepable, ++}; ++ ++__bpf_kfunc_start_defs(); ++ ++/** ++ * scx_bpf_select_cpu_dfl - The default implementation of ops.select_cpu() ++ * @p: task_struct to select a CPU for ++ * @prev_cpu: CPU @p was on previously ++ * @wake_flags: %SCX_WAKE_* flags ++ * @is_idle: out parameter indicating whether the returned CPU is idle ++ * ++ * Can only be called from ops.select_cpu() if the built-in CPU selection is ++ * enabled - ops.update_idle() is missing or %SCX_OPS_KEEP_BUILTIN_IDLE is set. ++ * @p, @prev_cpu and @wake_flags match ops.select_cpu(). ++ * ++ * Returns the picked CPU with *@is_idle indicating whether the picked CPU is ++ * currently idle and thus a good candidate for direct dispatching. 
++ */ ++__bpf_kfunc s32 scx_bpf_select_cpu_dfl(struct task_struct *p, s32 prev_cpu, ++ u64 wake_flags, bool *is_idle) ++{ ++ if (!scx_kf_allowed(SCX_KF_SELECT_CPU)) { ++ *is_idle = false; ++ return prev_cpu; ++ } ++#ifdef CONFIG_SMP ++ return scx_select_cpu_dfl(p, prev_cpu, wake_flags, is_idle); ++#else ++ *is_idle = false; ++ return prev_cpu; ++#endif ++} ++ ++__bpf_kfunc_end_defs(); ++ ++BTF_KFUNCS_START(scx_kfunc_ids_select_cpu) ++BTF_ID_FLAGS(func, scx_bpf_select_cpu_dfl, KF_RCU) ++BTF_KFUNCS_END(scx_kfunc_ids_select_cpu) ++ ++static const struct btf_kfunc_id_set scx_kfunc_set_select_cpu = { ++ .owner = THIS_MODULE, ++ .set = &scx_kfunc_ids_select_cpu, ++}; ++ ++static bool scx_dispatch_preamble(struct task_struct *p, u64 enq_flags) ++{ ++ if (!scx_kf_allowed(SCX_KF_ENQUEUE | SCX_KF_DISPATCH)) ++ return false; ++ ++ lockdep_assert_irqs_disabled(); ++ ++ if (unlikely(!p)) { ++ scx_ops_error("called with NULL task"); ++ return false; ++ } ++ ++ if (unlikely(enq_flags & __SCX_ENQ_INTERNAL_MASK)) { ++ scx_ops_error("invalid enq_flags 0x%llx", enq_flags); ++ return false; ++ } ++ ++ return true; ++} ++ ++static void scx_dispatch_commit(struct task_struct *p, u64 dsq_id, u64 enq_flags) ++{ ++ struct scx_dsp_ctx *dspc = this_cpu_ptr(scx_dsp_ctx); ++ struct task_struct *ddsp_task; ++ ++ ddsp_task = __this_cpu_read(direct_dispatch_task); ++ if (ddsp_task) { ++ mark_direct_dispatch(ddsp_task, p, dsq_id, enq_flags); ++ return; ++ } ++ ++ if (unlikely(dspc->cursor >= scx_dsp_max_batch)) { ++ scx_ops_error("dispatch buffer overflow"); ++ return; ++ } ++ ++ dspc->buf[dspc->cursor++] = (struct scx_dsp_buf_ent){ ++ .task = p, ++ .qseq = atomic_long_read(&p->scx.ops_state) & SCX_OPSS_QSEQ_MASK, ++ .dsq_id = dsq_id, ++ .enq_flags = enq_flags, ++ }; ++} ++ ++__bpf_kfunc_start_defs(); ++ ++/** ++ * scx_bpf_dispatch - Dispatch a task into the FIFO queue of a DSQ ++ * @p: task_struct to dispatch ++ * @dsq_id: DSQ to dispatch to ++ * @slice: duration @p can run for in nsecs ++ * 
@enq_flags: SCX_ENQ_* ++ * ++ * Dispatch @p into the FIFO queue of the DSQ identified by @dsq_id. It is safe ++ * to call this function spuriously. Can be called from ops.enqueue(), ++ * ops.select_cpu(), and ops.dispatch(). ++ * ++ * When called from ops.select_cpu() or ops.enqueue(), it's for direct dispatch ++ * and @p must match the task being enqueued. Also, %SCX_DSQ_LOCAL_ON can't be ++ * used to target the local DSQ of a CPU other than the enqueueing one. Use ++ * ops.select_cpu() to be on the target CPU in the first place. ++ * ++ * When called from ops.select_cpu(), @enq_flags and @dsq_id are stored, and @p ++ * will be directly dispatched to the corresponding dispatch queue after ++ * ops.select_cpu() returns. If @p is dispatched to SCX_DSQ_LOCAL, it will be ++ * dispatched to the local DSQ of the CPU returned by ops.select_cpu(). ++ * @enq_flags are OR'd with the enqueue flags on the enqueue path before the ++ * task is dispatched. ++ * ++ * When called from ops.dispatch(), there are no restrictions on @p or @dsq_id ++ * and this function can be called up to ops.dispatch_max_batch times to dispatch ++ * multiple tasks. scx_bpf_dispatch_nr_slots() returns the number of the ++ * remaining slots. scx_bpf_consume() flushes the batch and resets the counter. ++ * ++ * This function doesn't have any locking restrictions and may be called under ++ * BPF locks (in the future when BPF introduces more flexible locking). ++ * ++ * @p is allowed to run for @slice. The scheduling path is triggered on slice ++ * exhaustion. If zero, the current residual slice is maintained. If ++ * %SCX_SLICE_INF, @p never expires and the BPF scheduler must kick the CPU with ++ * scx_bpf_kick_cpu() to trigger scheduling. 
++ */ ++__bpf_kfunc void scx_bpf_dispatch(struct task_struct *p, u64 dsq_id, u64 slice, ++ u64 enq_flags) ++{ ++ if (!scx_dispatch_preamble(p, enq_flags)) ++ return; ++ ++ if (slice) ++ p->scx.slice = slice; ++ else ++ p->scx.slice = p->scx.slice ?: 1; ++ ++ scx_dispatch_commit(p, dsq_id, enq_flags); ++} ++ ++/** ++ * scx_bpf_dispatch_vtime - Dispatch a task into the vtime priority queue of a DSQ ++ * @p: task_struct to dispatch ++ * @dsq_id: DSQ to dispatch to ++ * @slice: duration @p can run for in nsecs ++ * @vtime: @p's ordering inside the vtime-sorted queue of the target DSQ ++ * @enq_flags: SCX_ENQ_* ++ * ++ * Dispatch @p into the vtime priority queue of the DSQ identified by @dsq_id. ++ * Tasks queued into the priority queue are ordered by @vtime and always ++ * consumed after the tasks in the FIFO queue. All other aspects are identical ++ * to scx_bpf_dispatch(). ++ * ++ * @vtime ordering is according to time_before64() which considers wrapping. A ++ * numerically larger vtime may indicate an earlier position in the ordering and ++ * vice-versa. ++ */ ++__bpf_kfunc void scx_bpf_dispatch_vtime(struct task_struct *p, u64 dsq_id, ++ u64 slice, u64 vtime, u64 enq_flags) ++{ ++ if (!scx_dispatch_preamble(p, enq_flags)) ++ return; ++ ++ if (slice) ++ p->scx.slice = slice; ++ else ++ p->scx.slice = p->scx.slice ?: 1; ++ ++ p->scx.dsq_vtime = vtime; ++ ++ scx_dispatch_commit(p, dsq_id, enq_flags | SCX_ENQ_DSQ_PRIQ); ++} ++ ++__bpf_kfunc_end_defs(); ++ ++BTF_KFUNCS_START(scx_kfunc_ids_enqueue_dispatch) ++BTF_ID_FLAGS(func, scx_bpf_dispatch, KF_RCU) ++BTF_ID_FLAGS(func, scx_bpf_dispatch_vtime, KF_RCU) ++BTF_KFUNCS_END(scx_kfunc_ids_enqueue_dispatch) ++ ++static const struct btf_kfunc_id_set scx_kfunc_set_enqueue_dispatch = { ++ .owner = THIS_MODULE, ++ .set = &scx_kfunc_ids_enqueue_dispatch, ++}; ++ ++__bpf_kfunc_start_defs(); ++ ++/** ++ * scx_bpf_dispatch_nr_slots - Return the number of remaining dispatch slots ++ * ++ * Can only be called from ops.dispatch(). 
++ */ ++__bpf_kfunc u32 scx_bpf_dispatch_nr_slots(void) ++{ ++ if (!scx_kf_allowed(SCX_KF_DISPATCH)) ++ return 0; ++ ++ return scx_dsp_max_batch - __this_cpu_read(scx_dsp_ctx->cursor); ++} ++ ++/** ++ * scx_bpf_dispatch_cancel - Cancel the latest dispatch ++ * ++ * Cancel the latest dispatch. Can be called multiple times to cancel further ++ * dispatches. Can only be called from ops.dispatch(). ++ */ ++__bpf_kfunc void scx_bpf_dispatch_cancel(void) ++{ ++ struct scx_dsp_ctx *dspc = this_cpu_ptr(scx_dsp_ctx); ++ ++ if (!scx_kf_allowed(SCX_KF_DISPATCH)) ++ return; ++ ++ if (dspc->cursor > 0) ++ dspc->cursor--; ++ else ++ scx_ops_error("dispatch buffer underflow"); ++} ++ ++/** ++ * scx_bpf_consume - Transfer a task from a DSQ to the current CPU's local DSQ ++ * @dsq_id: DSQ to consume ++ * ++ * Consume a task from the non-local DSQ identified by @dsq_id and transfer it ++ * to the current CPU's local DSQ for execution. Can only be called from ++ * ops.dispatch(). ++ * ++ * This function flushes the in-flight dispatches from scx_bpf_dispatch() before ++ * trying to consume the specified DSQ. It may also grab rq locks and thus can't ++ * be called under any BPF locks. ++ * ++ * Returns %true if a task has been consumed, %false if there isn't any task to ++ * consume. ++ */ ++__bpf_kfunc bool scx_bpf_consume(u64 dsq_id) ++{ ++ struct scx_dsp_ctx *dspc = this_cpu_ptr(scx_dsp_ctx); ++ struct scx_dispatch_q *dsq; ++ ++ if (!scx_kf_allowed(SCX_KF_DISPATCH)) ++ return false; ++ ++ flush_dispatch_buf(dspc->rq); ++ ++ dsq = find_non_local_dsq(dsq_id); ++ if (unlikely(!dsq)) { ++ scx_ops_error("invalid DSQ ID 0x%016llx", dsq_id); ++ return false; ++ } ++ ++ if (consume_dispatch_q(dspc->rq, dsq)) { ++ /* ++ * A successfully consumed task can be dequeued before it starts ++ * running while the CPU is trying to migrate other dispatched ++ * tasks. Bump nr_tasks to tell balance_scx() to retry on empty ++ * local DSQ. 
++ */ ++ dspc->nr_tasks++; ++ return true; ++ } else { ++ return false; ++ } ++} ++ ++__bpf_kfunc_end_defs(); ++ ++BTF_KFUNCS_START(scx_kfunc_ids_dispatch) ++BTF_ID_FLAGS(func, scx_bpf_dispatch_nr_slots) ++BTF_ID_FLAGS(func, scx_bpf_dispatch_cancel) ++BTF_ID_FLAGS(func, scx_bpf_consume) ++BTF_KFUNCS_END(scx_kfunc_ids_dispatch) ++ ++static const struct btf_kfunc_id_set scx_kfunc_set_dispatch = { ++ .owner = THIS_MODULE, ++ .set = &scx_kfunc_ids_dispatch, ++}; ++ ++__bpf_kfunc_start_defs(); ++ ++/** ++ * scx_bpf_reenqueue_local - Re-enqueue tasks on a local DSQ ++ * ++ * Iterate over all of the tasks currently enqueued on the local DSQ of the ++ * caller's CPU, and re-enqueue them in the BPF scheduler. Returns the number of ++ * processed tasks. Can only be called from ops.cpu_release(). ++ */ ++__bpf_kfunc u32 scx_bpf_reenqueue_local(void) ++{ ++ LIST_HEAD(tasks); ++ u32 nr_enqueued = 0; ++ struct rq *rq; ++ struct task_struct *p, *n; ++ ++ if (!scx_kf_allowed(SCX_KF_CPU_RELEASE)) ++ return 0; ++ ++ rq = cpu_rq(smp_processor_id()); ++ lockdep_assert_rq_held(rq); ++ ++ /* ++ * The BPF scheduler may choose to dispatch tasks back to ++ * @rq->scx.local_dsq. Move all candidate tasks off to a private list ++ * first to avoid processing the same tasks repeatedly. ++ */ ++ list_for_each_entry_safe(p, n, &rq->scx.local_dsq.list, ++ scx.dsq_list.node) { ++ /* ++ * If @p is being migrated, @p's current CPU may not agree with ++ * its allowed CPUs and the migration_cpu_stop is about to ++ * deactivate and re-activate @p anyway. Skip re-enqueueing. ++ * ++ * While racing sched property changes may also dequeue and ++ * re-enqueue a migrating task while its current CPU and allowed ++ * CPUs disagree, they use %ENQUEUE_RESTORE which is bypassed to ++ * the current local DSQ for running tasks and thus are not ++ * visible to the BPF scheduler. 
++ */ ++ if (p->migration_pending) ++ continue; ++ ++ dispatch_dequeue(rq, p); ++ list_add_tail(&p->scx.dsq_list.node, &tasks); ++ } ++ ++ list_for_each_entry_safe(p, n, &tasks, scx.dsq_list.node) { ++ list_del_init(&p->scx.dsq_list.node); ++ do_enqueue_task(rq, p, SCX_ENQ_REENQ, -1); ++ nr_enqueued++; ++ } ++ ++ return nr_enqueued; ++} ++ ++__bpf_kfunc_end_defs(); ++ ++BTF_KFUNCS_START(scx_kfunc_ids_cpu_release) ++BTF_ID_FLAGS(func, scx_bpf_reenqueue_local) ++BTF_KFUNCS_END(scx_kfunc_ids_cpu_release) ++ ++static const struct btf_kfunc_id_set scx_kfunc_set_cpu_release = { ++ .owner = THIS_MODULE, ++ .set = &scx_kfunc_ids_cpu_release, ++}; ++ ++__bpf_kfunc_start_defs(); ++ ++/** ++ * scx_bpf_kick_cpu - Trigger reschedule on a CPU ++ * @cpu: cpu to kick ++ * @flags: %SCX_KICK_* flags ++ * ++ * Kick @cpu into rescheduling. This can be used to wake up an idle CPU or ++ * trigger rescheduling on a busy CPU. This can be called from any online ++ * scx_ops operation and the actual kicking is performed asynchronously through ++ * an irq work. ++ */ ++__bpf_kfunc void scx_bpf_kick_cpu(s32 cpu, u64 flags) ++{ ++ struct rq *this_rq; ++ unsigned long irq_flags; ++ ++ if (!ops_cpu_valid(cpu, NULL)) ++ return; ++ ++ /* ++ * While bypassing for PM ops, IRQ handling may not be online which can ++ * lead to irq_work_queue() malfunction such as infinite busy wait for ++ * IRQ status update. Suppress kicking. ++ */ ++ if (scx_ops_bypassing()) ++ return; ++ ++ local_irq_save(irq_flags); ++ ++ this_rq = this_rq(); ++ ++ /* ++ * Actual kicking is bounced to kick_cpus_irq_workfn() to avoid nesting ++ * rq locks. We can probably be smarter and avoid bouncing if called ++ * from ops which don't hold a rq lock. 
++ */ ++ if (flags & SCX_KICK_IDLE) { ++ struct rq *target_rq = cpu_rq(cpu); ++ ++ if (unlikely(flags & (SCX_KICK_PREEMPT | SCX_KICK_WAIT))) ++ scx_ops_error("PREEMPT/WAIT cannot be used with SCX_KICK_IDLE"); ++ ++ if (raw_spin_rq_trylock(target_rq)) { ++ if (can_skip_idle_kick(target_rq)) { ++ raw_spin_rq_unlock(target_rq); ++ goto out; ++ } ++ raw_spin_rq_unlock(target_rq); ++ } ++ cpumask_set_cpu(cpu, this_rq->scx.cpus_to_kick_if_idle); ++ } else { ++ cpumask_set_cpu(cpu, this_rq->scx.cpus_to_kick); ++ ++ if (flags & SCX_KICK_PREEMPT) ++ cpumask_set_cpu(cpu, this_rq->scx.cpus_to_preempt); ++ if (flags & SCX_KICK_WAIT) ++ cpumask_set_cpu(cpu, this_rq->scx.cpus_to_wait); ++ } ++ ++ irq_work_queue(&this_rq->scx.kick_cpus_irq_work); ++out: ++ local_irq_restore(irq_flags); ++} ++ ++/** ++ * scx_bpf_dsq_nr_queued - Return the number of queued tasks ++ * @dsq_id: id of the DSQ ++ * ++ * Return the number of tasks in the DSQ matching @dsq_id. If not found, ++ * -%ENOENT is returned. ++ */ ++__bpf_kfunc s32 scx_bpf_dsq_nr_queued(u64 dsq_id) ++{ ++ struct scx_dispatch_q *dsq; ++ s32 ret; ++ ++ preempt_disable(); ++ ++ if (dsq_id == SCX_DSQ_LOCAL) { ++ ret = READ_ONCE(this_rq()->scx.local_dsq.nr); ++ goto out; ++ } else if ((dsq_id & SCX_DSQ_LOCAL_ON) == SCX_DSQ_LOCAL_ON) { ++ s32 cpu = dsq_id & SCX_DSQ_LOCAL_CPU_MASK; ++ ++ if (ops_cpu_valid(cpu, NULL)) { ++ ret = READ_ONCE(cpu_rq(cpu)->scx.local_dsq.nr); ++ goto out; ++ } ++ } else { ++ dsq = find_non_local_dsq(dsq_id); ++ if (dsq) { ++ ret = READ_ONCE(dsq->nr); ++ goto out; ++ } ++ } ++ ret = -ENOENT; ++out: ++ preempt_enable(); ++ return ret; ++} ++ ++/** ++ * scx_bpf_destroy_dsq - Destroy a custom DSQ ++ * @dsq_id: DSQ to destroy ++ * ++ * Destroy the custom DSQ identified by @dsq_id. Only DSQs created with ++ * scx_bpf_create_dsq() can be destroyed. The caller must ensure that the DSQ is ++ * empty and no further tasks are dispatched to it. Ignored if called on a DSQ ++ * which doesn't exist. 
Can be called from any online scx_ops operations. ++ */ ++__bpf_kfunc void scx_bpf_destroy_dsq(u64 dsq_id) ++{ ++ destroy_dsq(dsq_id); ++} ++ ++/** ++ * bpf_iter_scx_dsq_new - Create a DSQ iterator ++ * @it: iterator to initialize ++ * @dsq_id: DSQ to iterate ++ * @flags: %SCX_DSQ_ITER_* ++ * ++ * Initialize BPF iterator @it which can be used with bpf_for_each() to walk ++ * tasks in the DSQ specified by @dsq_id. Iteration using @it only includes ++ * tasks which are already queued when this function is invoked. ++ */ ++__bpf_kfunc int bpf_iter_scx_dsq_new(struct bpf_iter_scx_dsq *it, u64 dsq_id, ++ u64 flags) ++{ ++ struct bpf_iter_scx_dsq_kern *kit = (void *)it; ++ ++ BUILD_BUG_ON(sizeof(struct bpf_iter_scx_dsq_kern) > ++ sizeof(struct bpf_iter_scx_dsq)); ++ BUILD_BUG_ON(__alignof__(struct bpf_iter_scx_dsq_kern) != ++ __alignof__(struct bpf_iter_scx_dsq)); ++ ++ if (flags & ~__SCX_DSQ_ITER_ALL_FLAGS) ++ return -EINVAL; ++ ++ kit->dsq = find_non_local_dsq(dsq_id); ++ if (!kit->dsq) ++ return -ENOENT; ++ ++ INIT_LIST_HEAD(&kit->cursor.node); ++ kit->cursor.is_bpf_iter_cursor = true; ++ kit->dsq_seq = READ_ONCE(kit->dsq->seq); ++ kit->flags = flags; ++ ++ return 0; ++} ++ ++/** ++ * bpf_iter_scx_dsq_next - Progress a DSQ iterator ++ * @it: iterator to progress ++ * ++ * Return the next task. See bpf_iter_scx_dsq_new(). ++ */ ++__bpf_kfunc struct task_struct *bpf_iter_scx_dsq_next(struct bpf_iter_scx_dsq *it) ++{ ++ struct bpf_iter_scx_dsq_kern *kit = (void *)it; ++ bool rev = kit->flags & SCX_DSQ_ITER_REV; ++ struct task_struct *p; ++ unsigned long flags; ++ ++ if (!kit->dsq) ++ return NULL; ++ ++ raw_spin_lock_irqsave(&kit->dsq->lock, flags); ++ ++ if (list_empty(&kit->cursor.node)) ++ p = NULL; ++ else ++ p = container_of(&kit->cursor, struct task_struct, scx.dsq_list); ++ ++ /* ++ * Only tasks which were queued before the iteration started are ++ * visible. 
This bounds BPF iterations and guarantees that vtime never ++ * jumps in the other direction while iterating. ++ */ ++ do { ++ p = nldsq_next_task(kit->dsq, p, rev); ++ } while (p && unlikely(u32_before(kit->dsq_seq, p->scx.dsq_seq))); ++ ++ if (p) { ++ if (rev) ++ list_move_tail(&kit->cursor.node, &p->scx.dsq_list.node); ++ else ++ list_move(&kit->cursor.node, &p->scx.dsq_list.node); ++ } else { ++ list_del_init(&kit->cursor.node); ++ } ++ ++ raw_spin_unlock_irqrestore(&kit->dsq->lock, flags); ++ ++ return p; ++} ++ ++/** ++ * bpf_iter_scx_dsq_destroy - Destroy a DSQ iterator ++ * @it: iterator to destroy ++ * ++ * Undo bpf_iter_scx_dsq_new(). ++ */ ++__bpf_kfunc void bpf_iter_scx_dsq_destroy(struct bpf_iter_scx_dsq *it) ++{ ++ struct bpf_iter_scx_dsq_kern *kit = (void *)it; ++ ++ if (!kit->dsq) ++ return; ++ ++ if (!list_empty(&kit->cursor.node)) { ++ unsigned long flags; ++ ++ raw_spin_lock_irqsave(&kit->dsq->lock, flags); ++ list_del_init(&kit->cursor.node); ++ raw_spin_unlock_irqrestore(&kit->dsq->lock, flags); ++ } ++ kit->dsq = NULL; ++} ++ ++__bpf_kfunc_end_defs(); ++ ++static s32 __bstr_format(u64 *data_buf, char *line_buf, size_t line_size, ++ char *fmt, unsigned long long *data, u32 data__sz) ++{ ++ struct bpf_bprintf_data bprintf_data = { .get_bin_args = true }; ++ s32 ret; ++ ++ if (data__sz % 8 || data__sz > MAX_BPRINTF_VARARGS * 8 || ++ (data__sz && !data)) { ++ scx_ops_error("invalid data=%p and data__sz=%u", ++ (void *)data, data__sz); ++ return -EINVAL; ++ } ++ ++ ret = copy_from_kernel_nofault(data_buf, data, data__sz); ++ if (ret < 0) { ++ scx_ops_error("failed to read data fields (%d)", ret); ++ return ret; ++ } ++ ++ ret = bpf_bprintf_prepare(fmt, UINT_MAX, data_buf, data__sz / 8, ++ &bprintf_data); ++ if (ret < 0) { ++ scx_ops_error("format preparation failed (%d)", ret); ++ return ret; ++ } ++ ++ ret = bstr_printf(line_buf, line_size, fmt, ++ bprintf_data.bin_args); ++ bpf_bprintf_cleanup(&bprintf_data); ++ if (ret < 0) { ++ 
scx_ops_error("(\"%s\", %p, %u) failed to format", ++ fmt, data, data__sz); ++ return ret; ++ } ++ ++ return ret; ++} ++ ++static s32 bstr_format(struct scx_bstr_buf *buf, ++ char *fmt, unsigned long long *data, u32 data__sz) ++{ ++ return __bstr_format(buf->data, buf->line, sizeof(buf->line), ++ fmt, data, data__sz); ++} ++ ++__bpf_kfunc_start_defs(); ++ ++/** ++ * scx_bpf_exit_bstr - Gracefully exit the BPF scheduler. ++ * @exit_code: Exit value to pass to user space via struct scx_exit_info. ++ * @fmt: error message format string ++ * @data: format string parameters packaged using ___bpf_fill() macro ++ * @data__sz: @data len, must end in '__sz' for the verifier ++ * ++ * Indicate that the BPF scheduler wants to exit gracefully, and initiate ops ++ * disabling. ++ */ ++__bpf_kfunc void scx_bpf_exit_bstr(s64 exit_code, char *fmt, ++ unsigned long long *data, u32 data__sz) ++{ ++ unsigned long flags; ++ ++ raw_spin_lock_irqsave(&scx_exit_bstr_buf_lock, flags); ++ if (bstr_format(&scx_exit_bstr_buf, fmt, data, data__sz) >= 0) ++ scx_ops_exit_kind(SCX_EXIT_UNREG_BPF, exit_code, "%s", ++ scx_exit_bstr_buf.line); ++ raw_spin_unlock_irqrestore(&scx_exit_bstr_buf_lock, flags); ++} ++ ++/** ++ * scx_bpf_error_bstr - Indicate fatal error ++ * @fmt: error message format string ++ * @data: format string parameters packaged using ___bpf_fill() macro ++ * @data__sz: @data len, must end in '__sz' for the verifier ++ * ++ * Indicate that the BPF scheduler encountered a fatal error and initiate ops ++ * disabling. 
++ */ ++__bpf_kfunc void scx_bpf_error_bstr(char *fmt, unsigned long long *data, ++ u32 data__sz) ++{ ++ unsigned long flags; ++ ++ raw_spin_lock_irqsave(&scx_exit_bstr_buf_lock, flags); ++ if (bstr_format(&scx_exit_bstr_buf, fmt, data, data__sz) >= 0) ++ scx_ops_exit_kind(SCX_EXIT_ERROR_BPF, 0, "%s", ++ scx_exit_bstr_buf.line); ++ raw_spin_unlock_irqrestore(&scx_exit_bstr_buf_lock, flags); ++} ++ ++/** ++ * scx_bpf_dump_bstr - Generate extra debug dump specific to the BPF scheduler ++ * @fmt: format string ++ * @data: format string parameters packaged using ___bpf_fill() macro ++ * @data__sz: @data len, must end in '__sz' for the verifier ++ * ++ * To be called through scx_bpf_dump() helper from ops.dump(), dump_cpu() and ++ * dump_task() to generate extra debug dump specific to the BPF scheduler. ++ * ++ * The extra dump may be multiple lines. A single line may be split over ++ * multiple calls. The last line is automatically terminated. ++ */ ++__bpf_kfunc void scx_bpf_dump_bstr(char *fmt, unsigned long long *data, ++ u32 data__sz) ++{ ++ struct scx_dump_data *dd = &scx_dump_data; ++ struct scx_bstr_buf *buf = &dd->buf; ++ s32 ret; ++ ++ if (raw_smp_processor_id() != dd->cpu) { ++ scx_ops_error("scx_bpf_dump() must only be called from ops.dump() and friends"); ++ return; ++ } ++ ++ /* append the formatted string to the line buf */ ++ ret = __bstr_format(buf->data, buf->line + dd->cursor, ++ sizeof(buf->line) - dd->cursor, fmt, data, data__sz); ++ if (ret < 0) { ++ dump_line(dd->s, "%s[!] (\"%s\", %p, %u) failed to format (%d)", ++ dd->prefix, fmt, data, data__sz, ret); ++ return; ++ } ++ ++ dd->cursor += ret; ++ dd->cursor = min_t(s32, dd->cursor, sizeof(buf->line)); ++ ++ if (!dd->cursor) ++ return; ++ ++ /* ++ * If the line buf overflowed or ends in a newline, flush it into the ++ * dump. This is to allow the caller to generate a single line over ++ * multiple calls. 
As ops_dump_flush() can also handle multiple lines in ++ * the line buf, the only case which can lead to an unexpected ++ * truncation is when the caller keeps generating newlines in the middle ++ * instead of the end consecutively. Don't do that. ++ */ ++ if (dd->cursor >= sizeof(buf->line) || buf->line[dd->cursor - 1] == '\n') ++ ops_dump_flush(); ++} ++ ++/** ++ * scx_bpf_cpuperf_cap - Query the maximum relative capacity of a CPU ++ * @cpu: CPU of interest ++ * ++ * Return the maximum relative capacity of @cpu in relation to the most ++ * performant CPU in the system. The return value is in the range [1, ++ * %SCX_CPUPERF_ONE]. See scx_bpf_cpuperf_cur(). ++ */ ++__bpf_kfunc u32 scx_bpf_cpuperf_cap(s32 cpu) ++{ ++ if (ops_cpu_valid(cpu, NULL)) ++ return arch_scale_cpu_capacity(cpu); ++ else ++ return SCX_CPUPERF_ONE; ++} ++ ++/** ++ * scx_bpf_cpuperf_cur - Query the current relative performance of a CPU ++ * @cpu: CPU of interest ++ * ++ * Return the current relative performance of @cpu in relation to its maximum. ++ * The return value is in the range [1, %SCX_CPUPERF_ONE]. ++ * ++ * The current performance level of a CPU in relation to the maximum performance ++ * available in the system can be calculated as follows: ++ * ++ * scx_bpf_cpuperf_cap() * scx_bpf_cpuperf_cur() / %SCX_CPUPERF_ONE ++ * ++ * The result is in the range [1, %SCX_CPUPERF_ONE]. ++ */ ++__bpf_kfunc u32 scx_bpf_cpuperf_cur(s32 cpu) ++{ ++ if (ops_cpu_valid(cpu, NULL)) ++ return arch_scale_freq_capacity(cpu); ++ else ++ return SCX_CPUPERF_ONE; ++} ++ ++/** ++ * scx_bpf_cpuperf_set - Set the relative performance target of a CPU ++ * @cpu: CPU of interest ++ * @perf: target performance level [0, %SCX_CPUPERF_ONE] ++ * @flags: %SCX_CPUPERF_* flags ++ * ++ * Set the target performance level of @cpu to @perf. @perf is in linear ++ * relative scale between 0 and %SCX_CPUPERF_ONE. This determines how the ++ * schedutil cpufreq governor chooses the target frequency. 
++ * ++ * The actual performance level chosen, CPU grouping, and the overhead and ++ * latency of the operations are dependent on the hardware and cpufreq driver in ++ * use. Consult hardware and cpufreq documentation for more information. The ++ * current performance level can be monitored using scx_bpf_cpuperf_cur(). ++ */ ++__bpf_kfunc void scx_bpf_cpuperf_set(s32 cpu, u32 perf) ++{ ++ if (unlikely(perf > SCX_CPUPERF_ONE)) { ++ scx_ops_error("Invalid cpuperf target %u for CPU %d", perf, cpu); ++ return; ++ } ++ ++ if (ops_cpu_valid(cpu, NULL)) { ++ struct rq *rq = cpu_rq(cpu); ++ ++ rq->scx.cpuperf_target = perf; ++ ++ rcu_read_lock_sched_notrace(); ++ cpufreq_update_util(cpu_rq(cpu), 0); ++ rcu_read_unlock_sched_notrace(); ++ } ++} ++ ++/** ++ * scx_bpf_nr_cpu_ids - Return the number of possible CPU IDs ++ * ++ * All valid CPU IDs in the system are smaller than the returned value. ++ */ ++__bpf_kfunc u32 scx_bpf_nr_cpu_ids(void) ++{ ++ return nr_cpu_ids; ++} ++ ++/** ++ * scx_bpf_get_possible_cpumask - Get a referenced kptr to cpu_possible_mask ++ */ ++__bpf_kfunc const struct cpumask *scx_bpf_get_possible_cpumask(void) ++{ ++ return cpu_possible_mask; ++} ++ ++/** ++ * scx_bpf_get_online_cpumask - Get a referenced kptr to cpu_online_mask ++ */ ++__bpf_kfunc const struct cpumask *scx_bpf_get_online_cpumask(void) ++{ ++ return cpu_online_mask; ++} ++ ++/** ++ * scx_bpf_put_cpumask - Release a possible/online cpumask ++ * @cpumask: cpumask to release ++ */ ++__bpf_kfunc void scx_bpf_put_cpumask(const struct cpumask *cpumask) ++{ ++ /* ++ * Empty function body because we aren't actually acquiring or releasing ++ * a reference to a global cpumask, which is read-only in the caller and ++ * is never released. The acquire / release semantics here are just used ++ * to make the cpumask a trusted pointer in the caller. ++ */ ++} ++ ++/** ++ * scx_bpf_get_idle_cpumask - Get a referenced kptr to the idle-tracking ++ * per-CPU cpumask.
++ * ++ * Returns an empty mask if idle tracking is not enabled, or running on a UP kernel. ++ */ ++__bpf_kfunc const struct cpumask *scx_bpf_get_idle_cpumask(void) ++{ ++ if (!static_branch_likely(&scx_builtin_idle_enabled)) { ++ scx_ops_error("built-in idle tracking is disabled"); ++ return cpu_none_mask; ++ } ++ ++#ifdef CONFIG_SMP ++ return idle_masks.cpu; ++#else ++ return cpu_none_mask; ++#endif ++} ++ ++/** ++ * scx_bpf_get_idle_smtmask - Get a referenced kptr to the idle-tracking, ++ * per-physical-core cpumask. Can be used to determine if an entire physical ++ * core is free. ++ * ++ * Returns an empty mask if idle tracking is not enabled, or running on a UP kernel. ++ */ ++__bpf_kfunc const struct cpumask *scx_bpf_get_idle_smtmask(void) ++{ ++ if (!static_branch_likely(&scx_builtin_idle_enabled)) { ++ scx_ops_error("built-in idle tracking is disabled"); ++ return cpu_none_mask; ++ } ++ ++#ifdef CONFIG_SMP ++ if (sched_smt_active()) ++ return idle_masks.smt; ++ else ++ return idle_masks.cpu; ++#else ++ return cpu_none_mask; ++#endif ++} ++ ++/** ++ * scx_bpf_put_idle_cpumask - Release a previously acquired referenced kptr to ++ * either the percpu, or SMT idle-tracking cpumask. ++ */ ++__bpf_kfunc void scx_bpf_put_idle_cpumask(const struct cpumask *idle_mask) ++{ ++ /* ++ * Empty function body because we aren't actually acquiring or releasing ++ * a reference to a global idle cpumask, which is read-only in the ++ * caller and is never released. The acquire / release semantics here ++ * are just used to make the cpumask a trusted pointer in the caller. ++ */ ++} ++ ++/** ++ * scx_bpf_test_and_clear_cpu_idle - Test and clear @cpu's idle state ++ * @cpu: cpu to test and clear idle for ++ * ++ * Returns %true if @cpu was idle and its idle state was successfully cleared. ++ * %false otherwise. ++ * ++ * Unavailable if ops.update_idle() is implemented and ++ * %SCX_OPS_KEEP_BUILTIN_IDLE is not set.
++ */ ++__bpf_kfunc bool scx_bpf_test_and_clear_cpu_idle(s32 cpu) ++{ ++ if (!static_branch_likely(&scx_builtin_idle_enabled)) { ++ scx_ops_error("built-in idle tracking is disabled"); ++ return false; ++ } ++ ++ if (ops_cpu_valid(cpu, NULL)) ++ return test_and_clear_cpu_idle(cpu); ++ else ++ return false; ++} ++ ++/** ++ * scx_bpf_pick_idle_cpu - Pick and claim an idle cpu ++ * @cpus_allowed: Allowed cpumask ++ * @flags: %SCX_PICK_IDLE_CPU_* flags ++ * ++ * Pick and claim an idle cpu in @cpus_allowed. Returns the picked idle cpu ++ * number on success. -%EBUSY if no matching cpu was found. ++ * ++ * Idle CPU tracking may race against CPU scheduling state transitions. For ++ * example, this function may return -%EBUSY as CPUs are transitioning into the ++ * idle state. If the caller then assumes that there will be dispatch events on ++ * the CPUs as they were all busy, the scheduler may end up stalling with CPUs ++ * idling while there are pending tasks. Use scx_bpf_pick_any_cpu() and ++ * scx_bpf_kick_cpu() to guarantee that there will be at least one dispatch ++ * event in the near future. ++ * ++ * Unavailable if ops.update_idle() is implemented and ++ * %SCX_OPS_KEEP_BUILTIN_IDLE is not set. ++ */ ++__bpf_kfunc s32 scx_bpf_pick_idle_cpu(const struct cpumask *cpus_allowed, ++ u64 flags) ++{ ++ if (!static_branch_likely(&scx_builtin_idle_enabled)) { ++ scx_ops_error("built-in idle tracking is disabled"); ++ return -EBUSY; ++ } ++ ++ return scx_pick_idle_cpu(cpus_allowed, flags); ++} ++ ++/** ++ * scx_bpf_pick_any_cpu - Pick and claim an idle cpu if available or pick any CPU ++ * @cpus_allowed: Allowed cpumask ++ * @flags: %SCX_PICK_IDLE_CPU_* flags ++ * ++ * Pick and claim an idle cpu in @cpus_allowed. If none is available, pick any ++ * CPU in @cpus_allowed. Guaranteed to succeed and returns the picked idle cpu ++ * number if @cpus_allowed is not empty. -%EBUSY is returned if @cpus_allowed is ++ * empty. 
++ * ++ * If ops.update_idle() is implemented and %SCX_OPS_KEEP_BUILTIN_IDLE is not ++ * set, this function can't tell which CPUs are idle and will always pick any ++ * CPU. ++ */ ++__bpf_kfunc s32 scx_bpf_pick_any_cpu(const struct cpumask *cpus_allowed, ++ u64 flags) ++{ ++ s32 cpu; ++ ++ if (static_branch_likely(&scx_builtin_idle_enabled)) { ++ cpu = scx_pick_idle_cpu(cpus_allowed, flags); ++ if (cpu >= 0) ++ return cpu; ++ } ++ ++ cpu = cpumask_any_distribute(cpus_allowed); ++ if (cpu < nr_cpu_ids) ++ return cpu; ++ else ++ return -EBUSY; ++} ++ ++/** ++ * scx_bpf_task_running - Is task currently running? ++ * @p: task of interest ++ */ ++__bpf_kfunc bool scx_bpf_task_running(const struct task_struct *p) ++{ ++ return task_rq(p)->curr == p; ++} ++ ++/** ++ * scx_bpf_task_cpu - CPU a task is currently associated with ++ * @p: task of interest ++ */ ++__bpf_kfunc s32 scx_bpf_task_cpu(const struct task_struct *p) ++{ ++ return task_cpu(p); ++} ++ ++/** ++ * scx_bpf_cpu_rq - Fetch the rq of a CPU ++ * @cpu: CPU of the rq ++ */ ++__bpf_kfunc struct rq *scx_bpf_cpu_rq(s32 cpu) ++{ ++ if (!ops_cpu_valid(cpu, NULL)) ++ return NULL; ++ ++ return cpu_rq(cpu); ++} ++ ++__bpf_kfunc_end_defs(); ++ ++BTF_KFUNCS_START(scx_kfunc_ids_any) ++BTF_ID_FLAGS(func, scx_bpf_kick_cpu) ++BTF_ID_FLAGS(func, scx_bpf_dsq_nr_queued) ++BTF_ID_FLAGS(func, scx_bpf_destroy_dsq) ++BTF_ID_FLAGS(func, bpf_iter_scx_dsq_new, KF_ITER_NEW | KF_RCU_PROTECTED) ++BTF_ID_FLAGS(func, bpf_iter_scx_dsq_next, KF_ITER_NEXT | KF_RET_NULL) ++BTF_ID_FLAGS(func, bpf_iter_scx_dsq_destroy, KF_ITER_DESTROY) ++BTF_ID_FLAGS(func, scx_bpf_exit_bstr, KF_TRUSTED_ARGS) ++BTF_ID_FLAGS(func, scx_bpf_error_bstr, KF_TRUSTED_ARGS) ++BTF_ID_FLAGS(func, scx_bpf_dump_bstr, KF_TRUSTED_ARGS) ++BTF_ID_FLAGS(func, scx_bpf_cpuperf_cap) ++BTF_ID_FLAGS(func, scx_bpf_cpuperf_cur) ++BTF_ID_FLAGS(func, scx_bpf_cpuperf_set) ++BTF_ID_FLAGS(func, scx_bpf_nr_cpu_ids) ++BTF_ID_FLAGS(func, scx_bpf_get_possible_cpumask, KF_ACQUIRE) 
++BTF_ID_FLAGS(func, scx_bpf_get_online_cpumask, KF_ACQUIRE) ++BTF_ID_FLAGS(func, scx_bpf_put_cpumask, KF_RELEASE) ++BTF_ID_FLAGS(func, scx_bpf_get_idle_cpumask, KF_ACQUIRE) ++BTF_ID_FLAGS(func, scx_bpf_get_idle_smtmask, KF_ACQUIRE) ++BTF_ID_FLAGS(func, scx_bpf_put_idle_cpumask, KF_RELEASE) ++BTF_ID_FLAGS(func, scx_bpf_test_and_clear_cpu_idle) ++BTF_ID_FLAGS(func, scx_bpf_pick_idle_cpu, KF_RCU) ++BTF_ID_FLAGS(func, scx_bpf_pick_any_cpu, KF_RCU) ++BTF_ID_FLAGS(func, scx_bpf_task_running, KF_RCU) ++BTF_ID_FLAGS(func, scx_bpf_task_cpu, KF_RCU) ++BTF_ID_FLAGS(func, scx_bpf_cpu_rq) ++BTF_KFUNCS_END(scx_kfunc_ids_any) ++ ++static const struct btf_kfunc_id_set scx_kfunc_set_any = { ++ .owner = THIS_MODULE, ++ .set = &scx_kfunc_ids_any, ++}; ++ ++static int __init scx_init(void) ++{ ++ int ret; ++ ++ /* ++ * kfunc registration can't be done from init_sched_ext_class() as ++ * register_btf_kfunc_id_set() needs most of the system to be up. ++ * ++ * Some kfuncs are context-sensitive and can only be called from ++ * specific SCX ops. They are grouped into BTF sets accordingly. ++ * Unfortunately, BPF currently doesn't have a way of enforcing such ++ * restrictions. Eventually, the verifier should be able to enforce ++ * them. For now, register them the same and make each kfunc explicitly ++ * check using scx_kf_allowed(). 
++ */ ++ if ((ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, ++ &scx_kfunc_set_sleepable)) || ++ (ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_SYSCALL, ++ &scx_kfunc_set_sleepable)) || ++ (ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, ++ &scx_kfunc_set_select_cpu)) || ++ (ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, ++ &scx_kfunc_set_enqueue_dispatch)) || ++ (ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, ++ &scx_kfunc_set_dispatch)) || ++ (ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, ++ &scx_kfunc_set_cpu_release)) || ++ (ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, ++ &scx_kfunc_set_any)) || ++ (ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_TRACING, ++ &scx_kfunc_set_any)) || ++ (ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_SYSCALL, ++ &scx_kfunc_set_any))) { ++ pr_err("sched_ext: Failed to register kfunc sets (%d)\n", ret); ++ return ret; ++ } ++ ++ ret = register_bpf_struct_ops(&bpf_sched_ext_ops, sched_ext_ops); ++ if (ret) { ++ pr_err("sched_ext: Failed to register struct_ops (%d)\n", ret); ++ return ret; ++ } ++ ++ ret = register_pm_notifier(&scx_pm_notifier); ++ if (ret) { ++ pr_err("sched_ext: Failed to register PM notifier (%d)\n", ret); ++ return ret; ++ } ++ ++ scx_kset = kset_create_and_add("sched_ext", &scx_uevent_ops, kernel_kobj); ++ if (!scx_kset) { ++ pr_err("sched_ext: Failed to create /sys/kernel/sched_ext\n"); ++ return -ENOMEM; ++ } ++ ++ ret = sysfs_create_group(&scx_kset->kobj, &scx_global_attr_group); ++ if (ret < 0) { ++ pr_err("sched_ext: Failed to add global attributes\n"); ++ return ret; ++ } ++ ++ return 0; ++} ++__initcall(scx_init); +diff --git a/kernel/sched/ext.h b/kernel/sched/ext.h +new file mode 100644 +index 000000000000..32d3a51f591a +--- /dev/null ++++ b/kernel/sched/ext.h +@@ -0,0 +1,69 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * BPF extensible scheduler class: Documentation/scheduler/sched-ext.rst ++ * ++ * Copyright (c) 2022 Meta 
Platforms, Inc. and affiliates. ++ * Copyright (c) 2022 Tejun Heo ++ * Copyright (c) 2022 David Vernet ++ */ ++#ifdef CONFIG_SCHED_CLASS_EXT ++ ++void scx_tick(struct rq *rq); ++void init_scx_entity(struct sched_ext_entity *scx); ++void scx_pre_fork(struct task_struct *p); ++int scx_fork(struct task_struct *p); ++void scx_post_fork(struct task_struct *p); ++void scx_cancel_fork(struct task_struct *p); ++bool scx_can_stop_tick(struct rq *rq); ++void scx_rq_activate(struct rq *rq); ++void scx_rq_deactivate(struct rq *rq); ++int scx_check_setscheduler(struct task_struct *p, int policy); ++bool task_should_scx(struct task_struct *p); ++void init_sched_ext_class(void); ++ ++static inline u32 scx_cpuperf_target(s32 cpu) ++{ ++ if (scx_enabled()) ++ return cpu_rq(cpu)->scx.cpuperf_target; ++ else ++ return 0; ++} ++ ++static inline bool task_on_scx(const struct task_struct *p) ++{ ++ return scx_enabled() && p->sched_class == &ext_sched_class; ++} ++ ++#ifdef CONFIG_SCHED_CORE ++bool scx_prio_less(const struct task_struct *a, const struct task_struct *b, ++ bool in_fi); ++#endif ++ ++#else /* CONFIG_SCHED_CLASS_EXT */ ++ ++static inline void scx_tick(struct rq *rq) {} ++static inline void scx_pre_fork(struct task_struct *p) {} ++static inline int scx_fork(struct task_struct *p) { return 0; } ++static inline void scx_post_fork(struct task_struct *p) {} ++static inline void scx_cancel_fork(struct task_struct *p) {} ++static inline u32 scx_cpuperf_target(s32 cpu) { return 0; } ++static inline bool scx_can_stop_tick(struct rq *rq) { return true; } ++static inline void scx_rq_activate(struct rq *rq) {} ++static inline void scx_rq_deactivate(struct rq *rq) {} ++static inline int scx_check_setscheduler(struct task_struct *p, int policy) { return 0; } ++static inline bool task_on_scx(const struct task_struct *p) { return false; } ++static inline void init_sched_ext_class(void) {} ++ ++#endif /* CONFIG_SCHED_CLASS_EXT */ ++ ++#if defined(CONFIG_SCHED_CLASS_EXT) && 
defined(CONFIG_SMP) ++void __scx_update_idle(struct rq *rq, bool idle); ++ ++static inline void scx_update_idle(struct rq *rq, bool idle) ++{ ++ if (scx_enabled()) ++ __scx_update_idle(rq, idle); ++} ++#else ++static inline void scx_update_idle(struct rq *rq, bool idle) {} ++#endif +diff --git a/kernel/sched/fair.c b/kernel/sched/fair.c +index 1fee282d40aa..32f68ec1e528 100644 +--- a/kernel/sched/fair.c ++++ b/kernel/sched/fair.c +@@ -3848,7 +3848,8 @@ static void reweight_entity(struct cfs_rq *cfs_rq, struct sched_entity *se, + } + } + +-void reweight_task(struct task_struct *p, const struct load_weight *lw) ++static void reweight_task_fair(struct rq *rq, struct task_struct *p, ++ const struct load_weight *lw) + { + struct sched_entity *se = &p->se; + struct cfs_rq *cfs_rq = cfs_rq_of(se); +@@ -8403,7 +8404,7 @@ static void check_preempt_wakeup_fair(struct rq *rq, struct task_struct *p, int + * Batch and idle tasks do not preempt non-idle tasks (their preemption + * is driven by the tick): + */ +- if (unlikely(p->policy != SCHED_NORMAL) || !sched_feat(WAKEUP_PREEMPTION)) ++ if (unlikely(!normal_policy(p->policy)) || !sched_feat(WAKEUP_PREEMPTION)) + return; + + find_matching_se(&se, &pse); +@@ -9360,28 +9361,18 @@ static inline void update_blocked_load_status(struct rq *rq, bool has_blocked) { + + static bool __update_blocked_others(struct rq *rq, bool *done) + { +- const struct sched_class *curr_class; +- u64 now = rq_clock_pelt(rq); +- unsigned long hw_pressure; +- bool decayed; ++ bool updated; + + /* + * update_load_avg() can call cpufreq_update_util(). Make sure that RT, + * DL and IRQ signals have been updated before updating CFS. 
+ */ +- curr_class = rq->curr->sched_class; +- +- hw_pressure = arch_scale_hw_pressure(cpu_of(rq)); +- +- decayed = update_rt_rq_load_avg(now, rq, curr_class == &rt_sched_class) | +- update_dl_rq_load_avg(now, rq, curr_class == &dl_sched_class) | +- update_hw_load_avg(now, rq, hw_pressure) | +- update_irq_load_avg(rq, 0); ++ updated = update_other_load_avgs(rq); + + if (others_have_blocked(rq)) + *done = false; + +- return decayed; ++ return updated; + } + + #ifdef CONFIG_FAIR_GROUP_SCHED +@@ -13220,6 +13211,7 @@ DEFINE_SCHED_CLASS(fair) = { + .task_tick = task_tick_fair, + .task_fork = task_fork_fair, + ++ .reweight_task = reweight_task_fair, + .prio_changed = prio_changed_fair, + .switched_from = switched_from_fair, + .switched_to = switched_to_fair, +diff --git a/kernel/sched/idle.c b/kernel/sched/idle.c +index 6135fbe83d68..3b6540cc436a 100644 +--- a/kernel/sched/idle.c ++++ b/kernel/sched/idle.c +@@ -458,11 +458,13 @@ static void wakeup_preempt_idle(struct rq *rq, struct task_struct *p, int flags) + + static void put_prev_task_idle(struct rq *rq, struct task_struct *prev) + { ++ scx_update_idle(rq, false); + } + + static void set_next_task_idle(struct rq *rq, struct task_struct *next, bool first) + { + update_idle_core(rq); ++ scx_update_idle(rq, true); + schedstat_inc(rq->sched_goidle); + } + +diff --git a/kernel/sched/sched.h b/kernel/sched/sched.h +index 556466836cd5..fbcd2ddbf887 100644 +--- a/kernel/sched/sched.h ++++ b/kernel/sched/sched.h +@@ -187,9 +187,19 @@ static inline int idle_policy(int policy) + { + return policy == SCHED_IDLE; + } ++ ++static inline int normal_policy(int policy) ++{ ++#ifdef CONFIG_SCHED_CLASS_EXT ++ if (policy == SCHED_EXT) ++ return true; ++#endif ++ return policy == SCHED_NORMAL; ++} ++ + static inline int fair_policy(int policy) + { +- return policy == SCHED_NORMAL || policy == SCHED_BATCH; ++ return normal_policy(policy) || policy == SCHED_BATCH; + } + + static inline int rt_policy(int policy) +@@ -237,6 +247,24 @@ static 
inline void update_avg(u64 *avg, u64 sample) + #define shr_bound(val, shift) \ + (val >> min_t(typeof(shift), shift, BITS_PER_TYPE(typeof(val)) - 1)) + ++/* ++ * cgroup weight knobs should use the common MIN, DFL and MAX values which are ++ * 1, 100 and 10000 respectively. While it loses a bit of range on both ends, it ++ * maps pretty well onto the shares value used by scheduler and the round-trip ++ * conversions preserve the original value over the entire range. ++ */ ++static inline unsigned long sched_weight_from_cgroup(unsigned long cgrp_weight) ++{ ++ return DIV_ROUND_CLOSEST_ULL(cgrp_weight * 1024, CGROUP_WEIGHT_DFL); ++} ++ ++static inline unsigned long sched_weight_to_cgroup(unsigned long weight) ++{ ++ return clamp_t(unsigned long, ++ DIV_ROUND_CLOSEST_ULL(weight * CGROUP_WEIGHT_DFL, 1024), ++ CGROUP_WEIGHT_MIN, CGROUP_WEIGHT_MAX); ++} ++ + /* + * !! For sched_setattr_nocheck() (kernel) only !! + * +@@ -475,6 +503,11 @@ static inline int walk_tg_tree(tg_visitor down, tg_visitor up, void *data) + return walk_tg_tree_from(&root_task_group, down, up, data); + } + ++static inline struct task_group *css_tg(struct cgroup_subsys_state *css) ++{ ++ return css ? container_of(css, struct task_group, css) : NULL; ++} ++ + extern int tg_nop(struct task_group *tg, void *data); + + #ifdef CONFIG_FAIR_GROUP_SCHED +@@ -583,6 +616,12 @@ do { \ + # define u64_u32_load(var) u64_u32_load_copy(var, var##_copy) + # define u64_u32_store(var, val) u64_u32_store_copy(var, var##_copy, val) + ++struct rq; ++struct balance_callback { ++ struct balance_callback *next; ++ void (*func)(struct rq *rq); ++}; ++ + /* CFS-related fields in a runqueue */ + struct cfs_rq { + struct load_weight load; +@@ -691,6 +730,42 @@ struct cfs_rq { + #endif /* CONFIG_FAIR_GROUP_SCHED */ + }; + ++#ifdef CONFIG_SCHED_CLASS_EXT ++/* scx_rq->flags, protected by the rq lock */ ++enum scx_rq_flags { ++ /* ++ * A hotplugged CPU starts scheduling before rq_online_scx(). 
Track ++ * ops.cpu_on/offline() state so that ops.enqueue/dispatch() are called ++ * only while the BPF scheduler considers the CPU to be online. ++ */ ++ SCX_RQ_ONLINE = 1 << 0, ++ SCX_RQ_CAN_STOP_TICK = 1 << 1, ++ ++ SCX_RQ_IN_WAKEUP = 1 << 16, ++ SCX_RQ_IN_BALANCE = 1 << 17, ++}; ++ ++struct scx_rq { ++ struct scx_dispatch_q local_dsq; ++ struct list_head runnable_list; /* runnable tasks on this rq */ ++ struct list_head ddsp_deferred_locals; /* deferred ddsps from enq */ ++ unsigned long ops_qseq; ++ u64 extra_enq_flags; /* see move_task_to_local_dsq() */ ++ u32 nr_running; ++ u32 flags; ++ u32 cpuperf_target; /* [0, SCHED_CAPACITY_SCALE] */ ++ bool cpu_released; ++ cpumask_var_t cpus_to_kick; ++ cpumask_var_t cpus_to_kick_if_idle; ++ cpumask_var_t cpus_to_preempt; ++ cpumask_var_t cpus_to_wait; ++ unsigned long pnt_seq; ++ struct balance_callback deferred_bal_cb; ++ struct irq_work deferred_irq_work; ++ struct irq_work kick_cpus_irq_work; ++}; ++#endif /* CONFIG_SCHED_CLASS_EXT */ ++ + static inline int rt_bandwidth_enabled(void) + { + return sysctl_sched_rt_runtime >= 0; +@@ -988,12 +1063,6 @@ struct uclamp_rq { + DECLARE_STATIC_KEY_FALSE(sched_uclamp_used); + #endif /* CONFIG_UCLAMP_TASK */ + +-struct rq; +-struct balance_callback { +- struct balance_callback *next; +- void (*func)(struct rq *rq); +-}; +- + /* + * This is the main, per-CPU runqueue data structure. 
+ * +@@ -1036,6 +1105,9 @@ struct rq { + struct cfs_rq cfs; + struct rt_rq rt; + struct dl_rq dl; ++#ifdef CONFIG_SCHED_CLASS_EXT ++ struct scx_rq scx; ++#endif + + #ifdef CONFIG_FAIR_GROUP_SCHED + /* list of leaf cfs_rq on this CPU: */ +@@ -2278,6 +2350,8 @@ struct sched_class { + void (*put_prev_task)(struct rq *rq, struct task_struct *p); + void (*set_next_task)(struct rq *rq, struct task_struct *p, bool first); + ++ void (*switch_class)(struct rq *rq, struct task_struct *next); ++ + #ifdef CONFIG_SMP + int (*balance)(struct rq *rq, struct task_struct *prev, struct rq_flags *rf); + int (*select_task_rq)(struct task_struct *p, int task_cpu, int flags); +@@ -2305,8 +2379,11 @@ struct sched_class { + * cannot assume the switched_from/switched_to pair is serialized by + * rq->lock. They are however serialized by p->pi_lock. + */ ++ void (*switching_to) (struct rq *this_rq, struct task_struct *task); + void (*switched_from)(struct rq *this_rq, struct task_struct *task); + void (*switched_to) (struct rq *this_rq, struct task_struct *task); ++ void (*reweight_task)(struct rq *this_rq, struct task_struct *task, ++ const struct load_weight *lw); + void (*prio_changed) (struct rq *this_rq, struct task_struct *task, + int oldprio); + +@@ -2355,19 +2432,54 @@ const struct sched_class name##_sched_class \ + extern struct sched_class __sched_class_highest[]; + extern struct sched_class __sched_class_lowest[]; + ++extern const struct sched_class stop_sched_class; ++extern const struct sched_class dl_sched_class; ++extern const struct sched_class rt_sched_class; ++extern const struct sched_class fair_sched_class; ++extern const struct sched_class idle_sched_class; ++ ++#ifdef CONFIG_SCHED_CLASS_EXT ++extern const struct sched_class ext_sched_class; ++ ++DECLARE_STATIC_KEY_FALSE(__scx_ops_enabled); /* SCX BPF scheduler loaded */ ++DECLARE_STATIC_KEY_FALSE(__scx_switched_all); /* all fair class tasks on SCX */ ++ ++#define scx_enabled() static_branch_unlikely(&__scx_ops_enabled) 
++#define scx_switched_all() static_branch_unlikely(&__scx_switched_all) ++#else /* !CONFIG_SCHED_CLASS_EXT */ ++#define scx_enabled() false ++#define scx_switched_all() false ++#endif /* !CONFIG_SCHED_CLASS_EXT */ ++ ++/* ++ * Iterate only active classes. SCX can take over all fair tasks or be ++ * completely disabled. If the former, skip fair. If the latter, skip SCX. ++ */ ++static inline const struct sched_class *next_active_class(const struct sched_class *class) ++{ ++ class++; ++#ifdef CONFIG_SCHED_CLASS_EXT ++ if (scx_switched_all() && class == &fair_sched_class) ++ class++; ++ if (!scx_enabled() && class == &ext_sched_class) ++ class++; ++#endif ++ return class; ++} ++ + #define for_class_range(class, _from, _to) \ + for (class = (_from); class < (_to); class++) + + #define for_each_class(class) \ + for_class_range(class, __sched_class_highest, __sched_class_lowest) + +-#define sched_class_above(_a, _b) ((_a) < (_b)) ++#define for_active_class_range(class, _from, _to) \ ++ for (class = (_from); class != (_to); class = next_active_class(class)) + +-extern const struct sched_class stop_sched_class; +-extern const struct sched_class dl_sched_class; +-extern const struct sched_class rt_sched_class; +-extern const struct sched_class fair_sched_class; +-extern const struct sched_class idle_sched_class; ++#define for_each_active_class(class) \ ++ for_active_class_range(class, __sched_class_highest, __sched_class_lowest) ++ ++#define sched_class_above(_a, _b) ((_a) < (_b)) + + static inline bool sched_stop_runnable(struct rq *rq) + { +@@ -2464,7 +2576,7 @@ extern void init_sched_dl_class(void); + extern void init_sched_rt_class(void); + extern void init_sched_fair_class(void); + +-extern void reweight_task(struct task_struct *p, const struct load_weight *lw); ++extern void __setscheduler_prio(struct task_struct *p, int prio); + + extern void resched_curr(struct rq *rq); + extern void resched_cpu(int cpu); +@@ -2542,6 +2654,12 @@ static inline void 
sub_nr_running(struct rq *rq, unsigned count) + extern void activate_task(struct rq *rq, struct task_struct *p, int flags); + extern void deactivate_task(struct rq *rq, struct task_struct *p, int flags); + ++extern void check_class_changing(struct rq *rq, struct task_struct *p, ++ const struct sched_class *prev_class); ++extern void check_class_changed(struct rq *rq, struct task_struct *p, ++ const struct sched_class *prev_class, ++ int oldprio); ++ + extern void wakeup_preempt(struct rq *rq, struct task_struct *p, int flags); + + #if defined(CONFIG_PREEMPT_RT) || defined(CONFIG_CACHY) +@@ -3007,6 +3125,9 @@ static inline void cpufreq_update_util(struct rq *rq, unsigned int flags) {} + #endif + + #ifdef CONFIG_SMP ++ ++bool update_other_load_avgs(struct rq *rq); ++ + unsigned long effective_cpu_util(int cpu, unsigned long util_cfs, + unsigned long *min, + unsigned long *max); +@@ -3049,6 +3170,8 @@ static inline unsigned long cpu_util_rt(struct rq *rq) + { + return READ_ONCE(rq->avg_rt.util_avg); + } ++#else /* !CONFIG_SMP */ ++static inline bool update_other_load_avgs(struct rq *rq) { return false; } + #endif + + #ifdef CONFIG_UCLAMP_TASK +@@ -3481,4 +3604,24 @@ static inline void init_sched_mm_cid(struct task_struct *t) { } + extern u64 avg_vruntime(struct cfs_rq *cfs_rq); + extern int entity_eligible(struct cfs_rq *cfs_rq, struct sched_entity *se); + ++#ifdef CONFIG_SCHED_CLASS_EXT ++/* ++ * Used by SCX in the enable/disable paths to move tasks between sched_classes ++ * and establish invariants. 
++ */ ++struct sched_enq_and_set_ctx { ++ struct task_struct *p; ++ int queue_flags; ++ bool queued; ++ bool running; ++}; ++ ++void sched_deq_and_put_task(struct task_struct *p, int queue_flags, ++ struct sched_enq_and_set_ctx *ctx); ++void sched_enq_and_set_task(struct sched_enq_and_set_ctx *ctx); ++ ++#endif /* CONFIG_SCHED_CLASS_EXT */ ++ ++#include "ext.h" ++ + #endif /* _KERNEL_SCHED_SCHED_H */ +diff --git a/lib/dump_stack.c b/lib/dump_stack.c +index 222c6d6c8281..9581ef4efec5 100644 +--- a/lib/dump_stack.c ++++ b/lib/dump_stack.c +@@ -68,6 +68,7 @@ void dump_stack_print_info(const char *log_lvl) + + print_worker_info(log_lvl, current); + print_stop_info(log_lvl, current); ++ print_scx_info(log_lvl, current); + } + + /** +diff --git a/tools/Makefile b/tools/Makefile +index 276f5d0d53a4..278d24723b74 100644 +--- a/tools/Makefile ++++ b/tools/Makefile +@@ -28,6 +28,7 @@ help: + @echo ' pci - PCI tools' + @echo ' perf - Linux performance measurement and analysis tool' + @echo ' selftests - various kernel selftests' ++ @echo ' sched_ext - sched_ext example schedulers' + @echo ' bootconfig - boot config tool' + @echo ' spi - spi tools' + @echo ' tmon - thermal monitoring and tuning tool' +@@ -91,6 +92,9 @@ perf: FORCE + $(Q)mkdir -p $(PERF_O) . + $(Q)$(MAKE) --no-print-directory -C perf O=$(PERF_O) subdir= + ++sched_ext: FORCE ++ $(call descend,sched_ext) ++ + selftests: FORCE + $(call descend,testing/$@) + +@@ -184,6 +188,9 @@ perf_clean: + $(Q)mkdir -p $(PERF_O) . 
+ $(Q)$(MAKE) --no-print-directory -C perf O=$(PERF_O) subdir= clean + ++sched_ext_clean: ++ $(call descend,sched_ext,clean) ++ + selftests_clean: + $(call descend,testing/$(@:_clean=),clean) + +@@ -213,6 +220,7 @@ clean: acpi_clean counter_clean cpupower_clean hv_clean firewire_clean \ + mm_clean bpf_clean iio_clean x86_energy_perf_policy_clean tmon_clean \ + freefall_clean build_clean libbpf_clean libsubcmd_clean \ + gpio_clean objtool_clean leds_clean wmi_clean pci_clean firmware_clean debugging_clean \ +- intel-speed-select_clean tracing_clean thermal_clean thermometer_clean thermal-engine_clean ++ intel-speed-select_clean tracing_clean thermal_clean thermometer_clean thermal-engine_clean \ ++ sched_ext_clean + + .PHONY: FORCE +diff --git a/tools/sched_ext/.gitignore b/tools/sched_ext/.gitignore +new file mode 100644 +index 000000000000..d6264fe1c8cd +--- /dev/null ++++ b/tools/sched_ext/.gitignore +@@ -0,0 +1,2 @@ ++tools/ ++build/ +diff --git a/tools/sched_ext/Makefile b/tools/sched_ext/Makefile +new file mode 100644 +index 000000000000..bf7e108f5ae1 +--- /dev/null ++++ b/tools/sched_ext/Makefile +@@ -0,0 +1,246 @@ ++# SPDX-License-Identifier: GPL-2.0 ++# Copyright (c) 2022 Meta Platforms, Inc. and affiliates. 
++include ../build/Build.include ++include ../scripts/Makefile.arch ++include ../scripts/Makefile.include ++ ++all: all_targets ++ ++ifneq ($(LLVM),) ++ifneq ($(filter %/,$(LLVM)),) ++LLVM_PREFIX := $(LLVM) ++else ifneq ($(filter -%,$(LLVM)),) ++LLVM_SUFFIX := $(LLVM) ++endif ++ ++CLANG_TARGET_FLAGS_arm := arm-linux-gnueabi ++CLANG_TARGET_FLAGS_arm64 := aarch64-linux-gnu ++CLANG_TARGET_FLAGS_hexagon := hexagon-linux-musl ++CLANG_TARGET_FLAGS_m68k := m68k-linux-gnu ++CLANG_TARGET_FLAGS_mips := mipsel-linux-gnu ++CLANG_TARGET_FLAGS_powerpc := powerpc64le-linux-gnu ++CLANG_TARGET_FLAGS_riscv := riscv64-linux-gnu ++CLANG_TARGET_FLAGS_s390 := s390x-linux-gnu ++CLANG_TARGET_FLAGS_x86 := x86_64-linux-gnu ++CLANG_TARGET_FLAGS := $(CLANG_TARGET_FLAGS_$(ARCH)) ++ ++ifeq ($(CROSS_COMPILE),) ++ifeq ($(CLANG_TARGET_FLAGS),) ++$(error Specify CROSS_COMPILE or add '--target=' option to lib.mk) ++else ++CLANG_FLAGS += --target=$(CLANG_TARGET_FLAGS) ++endif # CLANG_TARGET_FLAGS ++else ++CLANG_FLAGS += --target=$(notdir $(CROSS_COMPILE:%-=%)) ++endif # CROSS_COMPILE ++ ++CC := $(LLVM_PREFIX)clang$(LLVM_SUFFIX) $(CLANG_FLAGS) -fintegrated-as ++else ++CC := $(CROSS_COMPILE)gcc ++endif # LLVM ++ ++CURDIR := $(abspath .) ++TOOLSDIR := $(abspath ..) 
++LIBDIR := $(TOOLSDIR)/lib ++BPFDIR := $(LIBDIR)/bpf ++TOOLSINCDIR := $(TOOLSDIR)/include ++BPFTOOLDIR := $(TOOLSDIR)/bpf/bpftool ++APIDIR := $(TOOLSINCDIR)/uapi ++GENDIR := $(abspath ../../include/generated) ++GENHDR := $(GENDIR)/autoconf.h ++ ++ifeq ($(O),) ++OUTPUT_DIR := $(CURDIR)/build ++else ++OUTPUT_DIR := $(O)/build ++endif # O ++OBJ_DIR := $(OUTPUT_DIR)/obj ++INCLUDE_DIR := $(OUTPUT_DIR)/include ++BPFOBJ_DIR := $(OBJ_DIR)/libbpf ++SCXOBJ_DIR := $(OBJ_DIR)/sched_ext ++BINDIR := $(OUTPUT_DIR)/bin ++BPFOBJ := $(BPFOBJ_DIR)/libbpf.a ++ifneq ($(CROSS_COMPILE),) ++HOST_BUILD_DIR := $(OBJ_DIR)/host ++HOST_OUTPUT_DIR := host-tools ++HOST_INCLUDE_DIR := $(HOST_OUTPUT_DIR)/include ++else ++HOST_BUILD_DIR := $(OBJ_DIR) ++HOST_OUTPUT_DIR := $(OUTPUT_DIR) ++HOST_INCLUDE_DIR := $(INCLUDE_DIR) ++endif ++HOST_BPFOBJ := $(HOST_BUILD_DIR)/libbpf/libbpf.a ++RESOLVE_BTFIDS := $(HOST_BUILD_DIR)/resolve_btfids/resolve_btfids ++DEFAULT_BPFTOOL := $(HOST_OUTPUT_DIR)/sbin/bpftool ++ ++VMLINUX_BTF_PATHS ?= $(if $(O),$(O)/vmlinux) \ ++ $(if $(KBUILD_OUTPUT),$(KBUILD_OUTPUT)/vmlinux) \ ++ ../../vmlinux \ ++ /sys/kernel/btf/vmlinux \ ++ /boot/vmlinux-$(shell uname -r) ++VMLINUX_BTF ?= $(abspath $(firstword $(wildcard $(VMLINUX_BTF_PATHS)))) ++ifeq ($(VMLINUX_BTF),) ++$(error Cannot find a vmlinux for VMLINUX_BTF at any of "$(VMLINUX_BTF_PATHS)") ++endif ++ ++BPFTOOL ?= $(DEFAULT_BPFTOOL) ++ ++ifneq ($(wildcard $(GENHDR)),) ++ GENFLAGS := -DHAVE_GENHDR ++endif ++ ++CFLAGS += -g -O2 -rdynamic -pthread -Wall -Werror $(GENFLAGS) \ ++ -I$(INCLUDE_DIR) -I$(GENDIR) -I$(LIBDIR) \ ++ -I$(TOOLSINCDIR) -I$(APIDIR) -I$(CURDIR)/include ++ ++# Silence some warnings when compiled with clang ++ifneq ($(LLVM),) ++CFLAGS += -Wno-unused-command-line-argument ++endif ++ ++LDFLAGS = -lelf -lz -lpthread ++ ++IS_LITTLE_ENDIAN = $(shell $(CC) -dM -E - &1 \ ++ | sed -n '/<...> search starts here:/,/End of search list./{ s| \(/.*\)|-idirafter \1|p }') \ ++$(shell $(1) -dM -E - $@ ++else ++ $(call msg,CP,,$@) 
++ $(Q)cp "$(VMLINUX_H)" $@ ++endif ++ ++$(SCXOBJ_DIR)/%.bpf.o: %.bpf.c $(INCLUDE_DIR)/vmlinux.h include/scx/*.h \ ++ | $(BPFOBJ) $(SCXOBJ_DIR) ++ $(call msg,CLNG-BPF,,$(notdir $@)) ++ $(Q)$(CLANG) $(BPF_CFLAGS) -target bpf -c $< -o $@ ++ ++$(INCLUDE_DIR)/%.bpf.skel.h: $(SCXOBJ_DIR)/%.bpf.o $(INCLUDE_DIR)/vmlinux.h $(BPFTOOL) ++ $(eval sched=$(notdir $@)) ++ $(call msg,GEN-SKEL,,$(sched)) ++ $(Q)$(BPFTOOL) gen object $(<:.o=.linked1.o) $< ++ $(Q)$(BPFTOOL) gen object $(<:.o=.linked2.o) $(<:.o=.linked1.o) ++ $(Q)$(BPFTOOL) gen object $(<:.o=.linked3.o) $(<:.o=.linked2.o) ++ $(Q)diff $(<:.o=.linked2.o) $(<:.o=.linked3.o) ++ $(Q)$(BPFTOOL) gen skeleton $(<:.o=.linked3.o) name $(subst .bpf.skel.h,,$(sched)) > $@ ++ $(Q)$(BPFTOOL) gen subskeleton $(<:.o=.linked3.o) name $(subst .bpf.skel.h,,$(sched)) > $(@:.skel.h=.subskel.h) ++ ++SCX_COMMON_DEPS := include/scx/common.h include/scx/user_exit_info.h | $(BINDIR) ++ ++c-sched-targets = scx_simple scx_qmap scx_central ++ ++$(addprefix $(BINDIR)/,$(c-sched-targets)): \ ++ $(BINDIR)/%: \ ++ $(filter-out %.bpf.c,%.c) \ ++ $(INCLUDE_DIR)/%.bpf.skel.h \ ++ $(SCX_COMMON_DEPS) ++ $(eval sched=$(notdir $@)) ++ $(CC) $(CFLAGS) -c $(sched).c -o $(SCXOBJ_DIR)/$(sched).o ++ $(CC) -o $@ $(SCXOBJ_DIR)/$(sched).o $(HOST_BPFOBJ) $(LDFLAGS) ++ ++$(c-sched-targets): %: $(BINDIR)/% ++ ++install: all ++ $(Q)mkdir -p $(DESTDIR)/usr/local/bin/ ++ $(Q)cp $(BINDIR)/* $(DESTDIR)/usr/local/bin/ ++ ++clean: ++ rm -rf $(OUTPUT_DIR) $(HOST_OUTPUT_DIR) ++ rm -f *.o *.bpf.o *.bpf.skel.h *.bpf.subskel.h ++ rm -f $(c-sched-targets) ++ ++help: ++ @echo 'Building targets' ++ @echo '================' ++ @echo '' ++ @echo ' all - Compile all schedulers' ++ @echo '' ++ @echo 'Alternatively, you may compile individual schedulers:' ++ @echo '' ++ @printf ' %s\n' $(c-sched-targets) ++ @echo '' ++ @echo 'For any scheduler build target, you may specify an alternative' ++ @echo 'build output path with the O= environment variable. 
For example:' ++ @echo '' ++ @echo ' O=/tmp/sched_ext make all' ++ @echo '' ++ @echo 'will compile all schedulers, and emit the build artifacts to' ++ @echo '/tmp/sched_ext/build.' ++ @echo '' ++ @echo '' ++ @echo 'Installing targets' ++ @echo '==================' ++ @echo '' ++ @echo ' install - Compile and install all schedulers to /usr/bin.' ++ @echo ' You may specify the DESTDIR= environment variable' ++ @echo ' to indicate a prefix for /usr/bin. For example:' ++ @echo '' ++ @echo ' DESTDIR=/tmp/sched_ext make install' ++ @echo '' ++ @echo ' will build the schedulers in CWD/build, and' ++ @echo ' install the schedulers to /tmp/sched_ext/usr/bin.' ++ @echo '' ++ @echo '' ++ @echo 'Cleaning targets' ++ @echo '================' ++ @echo '' ++ @echo ' clean - Remove all generated files' ++ ++all_targets: $(c-sched-targets) ++ ++.PHONY: all all_targets $(c-sched-targets) clean help ++ ++# delete failed targets ++.DELETE_ON_ERROR: ++ ++# keep intermediate (.bpf.skel.h, .bpf.o, etc) targets ++.SECONDARY: +diff --git a/tools/sched_ext/README.md b/tools/sched_ext/README.md +new file mode 100644 +index 000000000000..8efe70cc4363 +--- /dev/null ++++ b/tools/sched_ext/README.md +@@ -0,0 +1,258 @@ ++SCHED_EXT EXAMPLE SCHEDULERS ++============================ ++ ++# Introduction ++ ++This directory contains a number of example sched_ext schedulers. These ++schedulers are meant to provide examples of different types of schedulers ++that can be built using sched_ext, and illustrate how various features of ++sched_ext can be used. ++ ++Some of the examples are performant, production-ready schedulers. That is, for ++the correct workload and with the correct tuning, they may be deployed in a ++production environment with acceptable or possibly even improved performance. ++Others are just examples that in practice, would not provide acceptable ++performance (though they could be improved to get there). 
++ ++This README will describe these example schedulers, including describing the ++types of workloads or scenarios they're designed to accommodate, and whether or ++not they're production ready. For more details on any of these schedulers, ++please see the header comment in their .bpf.c file. ++ ++ ++# Compiling the examples ++ ++There are a few toolchain dependencies for compiling the example schedulers. ++ ++## Toolchain dependencies ++ ++1. clang >= 16.0.0 ++ ++The schedulers are BPF programs, and therefore must be compiled with clang. gcc ++is actively working on adding a BPF backend compiler as well, but is still ++missing some features such as BTF type tags which are necessary for using ++kptrs. ++ ++2. pahole >= 1.25 ++ ++You may need pahole in order to generate BTF from DWARF. ++ ++3. rust >= 1.70.0 ++ ++Rust schedulers use features present in the rust toolchain >= 1.70.0. You ++should be able to use the stable build from rustup, but if that doesn't ++work, try using the rustup nightly build. ++ ++There are other requirements as well, such as make, but these are the main / ++non-trivial ones. ++ ++## Compiling the kernel ++ ++In order to run a sched_ext scheduler, you'll have to run a kernel compiled ++with the patches in this repository, and with a minimum set of necessary ++Kconfig options: ++ ++``` ++CONFIG_BPF=y ++CONFIG_SCHED_CLASS_EXT=y ++CONFIG_BPF_SYSCALL=y ++CONFIG_BPF_JIT=y ++CONFIG_DEBUG_INFO_BTF=y ++``` ++ ++It's also recommended that you include the following Kconfig options: ++ ++``` ++CONFIG_BPF_JIT_ALWAYS_ON=y ++CONFIG_BPF_JIT_DEFAULT_ON=y ++CONFIG_PAHOLE_HAS_SPLIT_BTF=y ++CONFIG_PAHOLE_HAS_BTF_TAG=y ++``` ++ ++There is a `Kconfig` file in this directory whose contents you can append to ++your local `.config` file, as long as there are no conflicts with any existing ++options in the file. ++ ++## Getting a vmlinux.h file ++ ++You may notice that most of the example schedulers include a "vmlinux.h" file.
++This is a large, auto-generated header file that contains all of the types ++defined in some vmlinux binary that was compiled with ++[BTF](https://docs.kernel.org/bpf/btf.html) (i.e. with the BTF-related Kconfig ++options specified above). ++ ++The header file is created using `bpftool`, by passing it a vmlinux binary ++compiled with BTF as follows: ++ ++```bash ++$ bpftool btf dump file /path/to/vmlinux format c > vmlinux.h ++``` ++ ++`bpftool` analyzes all of the BTF encodings in the binary, and produces a ++header file that can be included by BPF programs to access those types. For ++example, using vmlinux.h allows a scheduler to access fields defined directly ++in vmlinux as follows: ++ ++```c ++#include "vmlinux.h" ++// vmlinux.h is also implicitly included by scx/common.bpf.h. ++#include "scx/common.bpf.h" ++ ++/* ++ * vmlinux.h provides definitions for struct task_struct and ++ * struct scx_enable_args. ++ */ ++void BPF_STRUCT_OPS(example_enable, struct task_struct *p, ++ struct scx_enable_args *args) ++{ ++ bpf_printk("Task %s enabled in example scheduler", p->comm); ++} ++ ++// vmlinux.h provides the definition for struct sched_ext_ops. ++SEC(".struct_ops.link") ++struct sched_ext_ops example_ops = { ++ .enable = (void *)example_enable, ++ .name = "example", ++}; ++``` ++ ++The scheduler build system will generate this vmlinux.h file as part of the ++scheduler build pipeline. It looks for a vmlinux file in the following ++dependency order: ++ ++1. If the O= environment variable is defined, at `$O/vmlinux` ++2. If the KBUILD_OUTPUT= environment variable is defined, at ++ `$KBUILD_OUTPUT/vmlinux` ++3. At `../../vmlinux` (i.e. at the root of the kernel tree where you're ++ compiling the schedulers) ++4. `/sys/kernel/btf/vmlinux` ++5. `/boot/vmlinux-$(uname -r)` ++ ++In other words, if you have compiled a kernel in your local repo, its vmlinux ++file will be used to generate vmlinux.h.
Otherwise, it will be the vmlinux of ++the kernel you're currently running on. This means that if you're running on a ++kernel with sched_ext support, you may not need to compile a local kernel at ++all. ++ ++### Aside on CO-RE ++ ++One of the cooler features of BPF is that it supports ++[CO-RE](https://nakryiko.com/posts/bpf-core-reference-guide/) (Compile Once Run ++Everywhere). This feature allows you to reference fields inside of structs with ++types defined internal to the kernel, and not have to recompile if you load the ++BPF program on a different kernel with the field at a different offset. In our ++example above, we print out a task name with `p->comm`. CO-RE would perform ++relocations for that access when the program is loaded to ensure that it's ++referencing the correct offset for the currently running kernel. ++ ++## Compiling the schedulers ++ ++Once you have your toolchain set up, and a vmlinux that can be used to generate ++a full vmlinux.h file, you can compile the schedulers using `make`: ++ ++```bash ++$ make -j$(nproc) ++``` ++ ++# Example schedulers ++ ++This directory contains the following example schedulers. These schedulers are ++for testing and demonstrating different aspects of sched_ext. While some may be ++useful in limited scenarios, they are not intended to be practical. ++ ++For more scheduler implementations, tools and documentation, visit ++https://github.com/sched-ext/scx. ++ ++## scx_simple ++ ++A simple scheduler that provides an example of a minimal sched_ext scheduler. ++scx_simple can be run in either global weighted vtime mode, or FIFO mode. ++ ++Though very simple, in limited scenarios, this scheduler can perform reasonably ++well on single-socket systems with a unified L3 cache. ++ ++## scx_qmap ++ ++Another simple, yet slightly more complex scheduler that provides an example of ++a basic weighted FIFO queuing policy.
It also provides examples of some common ++useful BPF features, such as sleepable per-task storage allocation in the ++`ops.prep_enable()` callback, and using the `BPF_MAP_TYPE_QUEUE` map type to ++enqueue tasks. It also illustrates how core-sched support could be implemented. ++ ++## scx_central ++ ++A "central" scheduler where scheduling decisions are made from a single CPU. ++This scheduler illustrates how scheduling decisions can be dispatched from a ++single CPU, allowing other cores to run with infinite slices, without timer ++ticks, and without having to incur the overhead of making scheduling decisions. ++ ++The approach demonstrated by this scheduler may be useful for any workload that ++benefits from minimizing scheduling overhead and timer ticks. An example of ++where this could be particularly useful is running VMs, where running with ++infinite slices and no timer ticks allows the VM to avoid unnecessary expensive ++vmexits. ++ ++ ++# Troubleshooting ++ ++There are a number of common issues that you may run into when building the ++schedulers. We'll go over some of the common ones here. ++ ++## Build Failures ++ ++### Old version of clang ++ ++``` ++error: static assertion failed due to requirement 'SCX_DSQ_FLAG_BUILTIN': bpftool generated vmlinux.h is missing high bits for 64bit enums, upgrade clang and pahole ++ _Static_assert(SCX_DSQ_FLAG_BUILTIN, ++ ^~~~~~~~~~~~~~~~~~~~ ++1 error generated. ++``` ++ ++This means you built the kernel or the schedulers with an older version of ++clang than what's supported (i.e. older than 16.0.0). To remediate this: ++ ++1. `which clang` to make sure you're using a sufficiently new version of clang. ++ ++2. `make fullclean` in the root path of the repository, and rebuild the kernel ++ and schedulers. ++ ++3. Rebuild the kernel, and then your example schedulers. ++ ++The schedulers are also cleaned if you invoke `make mrproper` in the root ++directory of the tree. 
++ ++### Stale kernel build / incomplete vmlinux.h file ++ ++As described above, you'll need a `vmlinux.h` file that was generated from a ++vmlinux built with BTF, and with sched_ext support enabled. If you don't, ++you'll see errors such as the following, which indicate that a type being ++referenced in a scheduler is unknown: ++ ++``` ++/path/to/sched_ext/tools/sched_ext/user_exit_info.h:25:23: note: forward declaration of 'struct scx_exit_info' ++ ++const struct scx_exit_info *ei) ++ ++^ ++``` ++ ++In order to resolve this, please follow the steps above in ++[Getting a vmlinux.h file](#getting-a-vmlinuxh-file) in order to ensure your ++schedulers are using a vmlinux.h file that includes the requisite types. ++ ++## Misc ++ ++### llvm: [OFF] ++ ++You may see the following output when building the schedulers: ++ ++``` ++Auto-detecting system features: ++... clang-bpf-co-re: [ on ] ++... llvm: [ OFF ] ++... libcap: [ on ] ++... libbfd: [ on ] ++``` ++ ++Seeing `llvm: [ OFF ]` here is not an issue. You can safely ignore it. +diff --git a/tools/sched_ext/include/bpf-compat/gnu/stubs.h b/tools/sched_ext/include/bpf-compat/gnu/stubs.h +new file mode 100644 +index 000000000000..ad7d139ce907 +--- /dev/null ++++ b/tools/sched_ext/include/bpf-compat/gnu/stubs.h +@@ -0,0 +1,11 @@ ++/* ++ * Dummy gnu/stubs.h. clang can end up including /usr/include/gnu/stubs.h when ++ * compiling BPF files although its content doesn't play any role. The file in ++ * turn includes stubs-64.h or stubs-32.h depending on whether __x86_64__ is ++ * defined. When compiling a BPF source, __x86_64__ isn't set and thus ++ * stubs-32.h is selected. However, the file is not there if the system doesn't ++ * have 32bit glibc devel package installed leading to a build failure. ++ * ++ * The problem is worked around by making this file available in the include ++ * search paths before the system one when building BPF.
++ */ +diff --git a/tools/sched_ext/include/scx/common.bpf.h b/tools/sched_ext/include/scx/common.bpf.h +new file mode 100644 +index 000000000000..20280df62857 +--- /dev/null ++++ b/tools/sched_ext/include/scx/common.bpf.h +@@ -0,0 +1,401 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2022 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2022 Tejun Heo ++ * Copyright (c) 2022 David Vernet ++ */ ++#ifndef __SCX_COMMON_BPF_H ++#define __SCX_COMMON_BPF_H ++ ++#include "vmlinux.h" ++#include ++#include ++#include ++#include "user_exit_info.h" ++ ++#define PF_WQ_WORKER 0x00000020 /* I'm a workqueue worker */ ++#define PF_KTHREAD 0x00200000 /* I am a kernel thread */ ++#define PF_EXITING 0x00000004 ++#define CLOCK_MONOTONIC 1 ++ ++/* ++ * Earlier versions of clang/pahole lost upper 32bits in 64bit enums which can ++ * lead to really confusing misbehaviors. Let's trigger a build failure. ++ */ ++static inline void ___vmlinux_h_sanity_check___(void) ++{ ++ _Static_assert(SCX_DSQ_FLAG_BUILTIN, ++ "bpftool generated vmlinux.h is missing high bits for 64bit enums, upgrade clang and pahole"); ++} ++ ++s32 scx_bpf_create_dsq(u64 dsq_id, s32 node) __ksym; ++s32 scx_bpf_select_cpu_dfl(struct task_struct *p, s32 prev_cpu, u64 wake_flags, bool *is_idle) __ksym; ++void scx_bpf_dispatch(struct task_struct *p, u64 dsq_id, u64 slice, u64 enq_flags) __ksym; ++void scx_bpf_dispatch_vtime(struct task_struct *p, u64 dsq_id, u64 slice, u64 vtime, u64 enq_flags) __ksym; ++u32 scx_bpf_dispatch_nr_slots(void) __ksym; ++void scx_bpf_dispatch_cancel(void) __ksym; ++bool scx_bpf_consume(u64 dsq_id) __ksym; ++u32 scx_bpf_reenqueue_local(void) __ksym; ++void scx_bpf_kick_cpu(s32 cpu, u64 flags) __ksym; ++s32 scx_bpf_dsq_nr_queued(u64 dsq_id) __ksym; ++void scx_bpf_destroy_dsq(u64 dsq_id) __ksym; ++int bpf_iter_scx_dsq_new(struct bpf_iter_scx_dsq *it, u64 dsq_id, u64 flags) __ksym __weak; ++struct task_struct *bpf_iter_scx_dsq_next(struct bpf_iter_scx_dsq *it) __ksym 
__weak; ++void bpf_iter_scx_dsq_destroy(struct bpf_iter_scx_dsq *it) __ksym __weak; ++void scx_bpf_exit_bstr(s64 exit_code, char *fmt, unsigned long long *data, u32 data__sz) __ksym __weak; ++void scx_bpf_error_bstr(char *fmt, unsigned long long *data, u32 data_len) __ksym; ++void scx_bpf_dump_bstr(char *fmt, unsigned long long *data, u32 data_len) __ksym __weak; ++u32 scx_bpf_cpuperf_cap(s32 cpu) __ksym __weak; ++u32 scx_bpf_cpuperf_cur(s32 cpu) __ksym __weak; ++void scx_bpf_cpuperf_set(s32 cpu, u32 perf) __ksym __weak; ++u32 scx_bpf_nr_cpu_ids(void) __ksym __weak; ++const struct cpumask *scx_bpf_get_possible_cpumask(void) __ksym __weak; ++const struct cpumask *scx_bpf_get_online_cpumask(void) __ksym __weak; ++void scx_bpf_put_cpumask(const struct cpumask *cpumask) __ksym __weak; ++const struct cpumask *scx_bpf_get_idle_cpumask(void) __ksym; ++const struct cpumask *scx_bpf_get_idle_smtmask(void) __ksym; ++void scx_bpf_put_idle_cpumask(const struct cpumask *cpumask) __ksym; ++bool scx_bpf_test_and_clear_cpu_idle(s32 cpu) __ksym; ++s32 scx_bpf_pick_idle_cpu(const cpumask_t *cpus_allowed, u64 flags) __ksym; ++s32 scx_bpf_pick_any_cpu(const cpumask_t *cpus_allowed, u64 flags) __ksym; ++bool scx_bpf_task_running(const struct task_struct *p) __ksym; ++s32 scx_bpf_task_cpu(const struct task_struct *p) __ksym; ++struct rq *scx_bpf_cpu_rq(s32 cpu) __ksym; ++ ++static inline __attribute__((format(printf, 1, 2))) ++void ___scx_bpf_bstr_format_checker(const char *fmt, ...) {} ++ ++/* ++ * Helper macro for initializing the fmt and variadic argument inputs to both ++ * bstr exit kfuncs. Callers to this function should use ___fmt and ___param to ++ * refer to the initialized list of inputs to the bstr kfunc. ++ */ ++#define scx_bpf_bstr_preamble(fmt, args...) \ ++ static char ___fmt[] = fmt; \ ++ /* \ ++ * Note that __param[] must have at least one \ ++ * element to keep the verifier happy. 
\ ++ */ \ ++ unsigned long long ___param[___bpf_narg(args) ?: 1] = {}; \ ++ \ ++ _Pragma("GCC diagnostic push") \ ++ _Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \ ++ ___bpf_fill(___param, args); \ ++ _Pragma("GCC diagnostic pop") \ ++ ++/* ++ * scx_bpf_exit() wraps the scx_bpf_exit_bstr() kfunc with variadic arguments ++ * instead of an array of u64. Using this macro will cause the scheduler to ++ * exit cleanly with the specified exit code being passed to user space. ++ */ ++#define scx_bpf_exit(code, fmt, args...) \ ++({ \ ++ scx_bpf_bstr_preamble(fmt, args) \ ++ scx_bpf_exit_bstr(code, ___fmt, ___param, sizeof(___param)); \ ++ ___scx_bpf_bstr_format_checker(fmt, ##args); \ ++}) ++ ++/* ++ * scx_bpf_error() wraps the scx_bpf_error_bstr() kfunc with variadic arguments ++ * instead of an array of u64. Invoking this macro will cause the scheduler to ++ * exit in an erroneous state, with diagnostic information being passed to the ++ * user. ++ */ ++#define scx_bpf_error(fmt, args...) \ ++({ \ ++ scx_bpf_bstr_preamble(fmt, args) \ ++ scx_bpf_error_bstr(___fmt, ___param, sizeof(___param)); \ ++ ___scx_bpf_bstr_format_checker(fmt, ##args); \ ++}) ++ ++/* ++ * scx_bpf_dump() wraps the scx_bpf_dump_bstr() kfunc with variadic arguments ++ * instead of an array of u64. To be used from ops.dump() and friends. ++ */ ++#define scx_bpf_dump(fmt, args...) \ ++({ \ ++ scx_bpf_bstr_preamble(fmt, args) \ ++ scx_bpf_dump_bstr(___fmt, ___param, sizeof(___param)); \ ++ ___scx_bpf_bstr_format_checker(fmt, ##args); \ ++}) ++ ++#define BPF_STRUCT_OPS(name, args...) \ ++SEC("struct_ops/"#name) \ ++BPF_PROG(name, ##args) ++ ++#define BPF_STRUCT_OPS_SLEEPABLE(name, args...) \ ++SEC("struct_ops.s/"#name) \ ++BPF_PROG(name, ##args) ++ ++/** ++ * RESIZABLE_ARRAY - Generates annotations for an array that may be resized ++ * @elfsec: the data section of the BPF program in which to place the array ++ * @arr: the name of the array ++ * ++ * libbpf has an API for setting map value sizes. 
Since data sections (i.e. ++ * bss, data, rodata) themselves are maps, a data section can be resized. If ++ * a data section has an array as its last element, the BTF info for that ++ * array will be adjusted so that the length of the array is extended to meet the ++ * new length of the data section. This macro annotates an array to have an ++ * element count of one with the assumption that this array can be resized ++ * within the userspace program. It also annotates the section specifier so ++ * this array exists in a custom sub data section which can be resized ++ * independently. ++ * ++ * See RESIZE_ARRAY() for the userspace convenience macro for resizing an ++ * array declared with RESIZABLE_ARRAY(). ++ */ ++#define RESIZABLE_ARRAY(elfsec, arr) arr[1] SEC("."#elfsec"."#arr) ++ ++/** ++ * MEMBER_VPTR - Obtain the verified pointer to a struct or array member ++ * @base: struct or array to index ++ * @member: dereferenced member (e.g. .field, [idx0][idx1], .field[idx0] ...) ++ * ++ * The verifier often gets confused by the instruction sequence the compiler ++ * generates for indexing struct fields or arrays. This macro forces the ++ * compiler to generate a code sequence which first calculates the byte offset, ++ * checks it against the struct or array size and adds that byte offset to ++ * generate the pointer to the member to help the verifier. ++ * ++ * Ideally, we want to abort if the calculated offset is out-of-bounds. However, ++ * BPF currently doesn't support abort, so evaluate to %NULL instead. The caller ++ * must check for %NULL and take appropriate action to appease the verifier. To ++ * avoid confusing the verifier, it's best to check for %NULL and dereference ++ * immediately. ++ * ++ * vptr = MEMBER_VPTR(my_array, [i][j]); ++ * if (!vptr) ++ * return error; ++ * *vptr = new_value; ++ * ++ * sizeof(@base) should encompass the memory area to be accessed and thus can't ++ * be a pointer to the area.
Use `MEMBER_VPTR(*ptr, .member)` instead of ++ * `MEMBER_VPTR(ptr, ->member)`. ++ */ ++#define MEMBER_VPTR(base, member) (typeof((base) member) *) \ ++({ \ ++ u64 __base = (u64)&(base); \ ++ u64 __addr = (u64)&((base) member) - __base; \ ++ _Static_assert(sizeof(base) >= sizeof((base) member), \ ++ "@base is smaller than @member, is @base a pointer?"); \ ++ asm volatile ( \ ++ "if %0 <= %[max] goto +2\n" \ ++ "%0 = 0\n" \ ++ "goto +1\n" \ ++ "%0 += %1\n" \ ++ : "+r"(__addr) \ ++ : "r"(__base), \ ++ [max]"i"(sizeof(base) - sizeof((base) member))); \ ++ __addr; \ ++}) ++ ++/** ++ * ARRAY_ELEM_PTR - Obtain the verified pointer to an array element ++ * @arr: array to index into ++ * @i: array index ++ * @n: number of elements in array ++ * ++ * Similar to MEMBER_VPTR() but is intended for use with arrays where the ++ * element count needs to be explicit. ++ * It can be used in cases where a global array is defined with an initial ++ * size but is intended to be resized before loading the BPF program. ++ * Without this version of the macro, MEMBER_VPTR() will use the compile time ++ * size of the array to compute the max, which will result in rejection by ++ * the verifier. ++ */ ++#define ARRAY_ELEM_PTR(arr, i, n) (typeof(arr[i]) *) \ ++({ \ ++ u64 __base = (u64)arr; \ ++ u64 __addr = (u64)&(arr[i]) - __base; \ ++ asm volatile ( \ ++ "if %0 <= %[max] goto +2\n" \ ++ "%0 = 0\n" \ ++ "goto +1\n" \ ++ "%0 += %1\n" \ ++ : "+r"(__addr) \ ++ : "r"(__base), \ ++ [max]"r"(sizeof(arr[0]) * ((n) - 1))); \ ++ __addr; \ ++}) ++ ++ ++/* ++ * BPF declarations and helpers ++ */ ++ ++/* list and rbtree */ ++#define __contains(name, node) __attribute__((btf_decl_tag("contains:" #name ":" #node))) ++#define private(name) SEC(".data."
#name) __hidden __attribute__((aligned(8))) ++ ++void *bpf_obj_new_impl(__u64 local_type_id, void *meta) __ksym; ++void bpf_obj_drop_impl(void *kptr, void *meta) __ksym; ++ ++#define bpf_obj_new(type) ((type *)bpf_obj_new_impl(bpf_core_type_id_local(type), NULL)) ++#define bpf_obj_drop(kptr) bpf_obj_drop_impl(kptr, NULL) ++ ++void bpf_list_push_front(struct bpf_list_head *head, struct bpf_list_node *node) __ksym; ++void bpf_list_push_back(struct bpf_list_head *head, struct bpf_list_node *node) __ksym; ++struct bpf_list_node *bpf_list_pop_front(struct bpf_list_head *head) __ksym; ++struct bpf_list_node *bpf_list_pop_back(struct bpf_list_head *head) __ksym; ++struct bpf_rb_node *bpf_rbtree_remove(struct bpf_rb_root *root, ++ struct bpf_rb_node *node) __ksym; ++int bpf_rbtree_add_impl(struct bpf_rb_root *root, struct bpf_rb_node *node, ++ bool (less)(struct bpf_rb_node *a, const struct bpf_rb_node *b), ++ void *meta, __u64 off) __ksym; ++#define bpf_rbtree_add(head, node, less) bpf_rbtree_add_impl(head, node, less, NULL, 0) ++ ++struct bpf_rb_node *bpf_rbtree_first(struct bpf_rb_root *root) __ksym; ++ ++void *bpf_refcount_acquire_impl(void *kptr, void *meta) __ksym; ++#define bpf_refcount_acquire(kptr) bpf_refcount_acquire_impl(kptr, NULL) ++ ++/* task */ ++struct task_struct *bpf_task_from_pid(s32 pid) __ksym; ++struct task_struct *bpf_task_acquire(struct task_struct *p) __ksym; ++void bpf_task_release(struct task_struct *p) __ksym; ++ ++/* cgroup */ ++struct cgroup *bpf_cgroup_ancestor(struct cgroup *cgrp, int level) __ksym; ++void bpf_cgroup_release(struct cgroup *cgrp) __ksym; ++struct cgroup *bpf_cgroup_from_id(u64 cgid) __ksym; ++ ++/* css iteration */ ++struct bpf_iter_css; ++struct cgroup_subsys_state; ++extern int bpf_iter_css_new(struct bpf_iter_css *it, ++ struct cgroup_subsys_state *start, ++ unsigned int flags) __weak __ksym; ++extern struct cgroup_subsys_state * ++bpf_iter_css_next(struct bpf_iter_css *it) __weak __ksym; ++extern void 
bpf_iter_css_destroy(struct bpf_iter_css *it) __weak __ksym; ++ ++/* cpumask */ ++struct bpf_cpumask *bpf_cpumask_create(void) __ksym; ++struct bpf_cpumask *bpf_cpumask_acquire(struct bpf_cpumask *cpumask) __ksym; ++void bpf_cpumask_release(struct bpf_cpumask *cpumask) __ksym; ++u32 bpf_cpumask_first(const struct cpumask *cpumask) __ksym; ++u32 bpf_cpumask_first_zero(const struct cpumask *cpumask) __ksym; ++void bpf_cpumask_set_cpu(u32 cpu, struct bpf_cpumask *cpumask) __ksym; ++void bpf_cpumask_clear_cpu(u32 cpu, struct bpf_cpumask *cpumask) __ksym; ++bool bpf_cpumask_test_cpu(u32 cpu, const struct cpumask *cpumask) __ksym; ++bool bpf_cpumask_test_and_set_cpu(u32 cpu, struct bpf_cpumask *cpumask) __ksym; ++bool bpf_cpumask_test_and_clear_cpu(u32 cpu, struct bpf_cpumask *cpumask) __ksym; ++void bpf_cpumask_setall(struct bpf_cpumask *cpumask) __ksym; ++void bpf_cpumask_clear(struct bpf_cpumask *cpumask) __ksym; ++bool bpf_cpumask_and(struct bpf_cpumask *dst, const struct cpumask *src1, ++ const struct cpumask *src2) __ksym; ++void bpf_cpumask_or(struct bpf_cpumask *dst, const struct cpumask *src1, ++ const struct cpumask *src2) __ksym; ++void bpf_cpumask_xor(struct bpf_cpumask *dst, const struct cpumask *src1, ++ const struct cpumask *src2) __ksym; ++bool bpf_cpumask_equal(const struct cpumask *src1, const struct cpumask *src2) __ksym; ++bool bpf_cpumask_intersects(const struct cpumask *src1, const struct cpumask *src2) __ksym; ++bool bpf_cpumask_subset(const struct cpumask *src1, const struct cpumask *src2) __ksym; ++bool bpf_cpumask_empty(const struct cpumask *cpumask) __ksym; ++bool bpf_cpumask_full(const struct cpumask *cpumask) __ksym; ++void bpf_cpumask_copy(struct bpf_cpumask *dst, const struct cpumask *src) __ksym; ++u32 bpf_cpumask_any_distribute(const struct cpumask *cpumask) __ksym; ++u32 bpf_cpumask_any_and_distribute(const struct cpumask *src1, ++ const struct cpumask *src2) __ksym; ++ ++/* rcu */ ++void bpf_rcu_read_lock(void) __ksym; ++void 
bpf_rcu_read_unlock(void) __ksym; ++ ++ ++/* ++ * Other helpers ++ */ ++ ++/* useful compiler attributes */ ++#define likely(x) __builtin_expect(!!(x), 1) ++#define unlikely(x) __builtin_expect(!!(x), 0) ++#define __maybe_unused __attribute__((__unused__)) ++ ++/* ++ * READ/WRITE_ONCE() are from kernel (include/asm-generic/rwonce.h). They ++ * prevent compiler from caching, redoing or reordering reads or writes. ++ */ ++typedef __u8 __attribute__((__may_alias__)) __u8_alias_t; ++typedef __u16 __attribute__((__may_alias__)) __u16_alias_t; ++typedef __u32 __attribute__((__may_alias__)) __u32_alias_t; ++typedef __u64 __attribute__((__may_alias__)) __u64_alias_t; ++ ++static __always_inline void __read_once_size(const volatile void *p, void *res, int size) ++{ ++ switch (size) { ++ case 1: *(__u8_alias_t *) res = *(volatile __u8_alias_t *) p; break; ++ case 2: *(__u16_alias_t *) res = *(volatile __u16_alias_t *) p; break; ++ case 4: *(__u32_alias_t *) res = *(volatile __u32_alias_t *) p; break; ++ case 8: *(__u64_alias_t *) res = *(volatile __u64_alias_t *) p; break; ++ default: ++ barrier(); ++ __builtin_memcpy((void *)res, (const void *)p, size); ++ barrier(); ++ } ++} ++ ++static __always_inline void __write_once_size(volatile void *p, void *res, int size) ++{ ++ switch (size) { ++ case 1: *(volatile __u8_alias_t *) p = *(__u8_alias_t *) res; break; ++ case 2: *(volatile __u16_alias_t *) p = *(__u16_alias_t *) res; break; ++ case 4: *(volatile __u32_alias_t *) p = *(__u32_alias_t *) res; break; ++ case 8: *(volatile __u64_alias_t *) p = *(__u64_alias_t *) res; break; ++ default: ++ barrier(); ++ __builtin_memcpy((void *)p, (const void *)res, size); ++ barrier(); ++ } ++} ++ ++#define READ_ONCE(x) \ ++({ \ ++ union { typeof(x) __val; char __c[1]; } __u = \ ++ { .__c = { 0 } }; \ ++ __read_once_size(&(x), __u.__c, sizeof(x)); \ ++ __u.__val; \ ++}) ++ ++#define WRITE_ONCE(x, val) \ ++({ \ ++ union { typeof(x) __val; char __c[1]; } __u = \ ++ { .__val = (val) }; \ ++ 
__write_once_size(&(x), __u.__c, sizeof(x)); \ ++ __u.__val; \ ++}) ++ ++/* ++ * log2_u32 - Compute the base 2 logarithm of a 32-bit exponential value. ++ * @v: The value for which we're computing the base 2 logarithm. ++ */ ++static inline u32 log2_u32(u32 v) ++{ ++ u32 r; ++ u32 shift; ++ ++ r = (v > 0xFFFF) << 4; v >>= r; ++ shift = (v > 0xFF) << 3; v >>= shift; r |= shift; ++ shift = (v > 0xF) << 2; v >>= shift; r |= shift; ++ shift = (v > 0x3) << 1; v >>= shift; r |= shift; ++ r |= (v >> 1); ++ return r; ++} ++ ++/* ++ * log2_u64 - Compute the base 2 logarithm of a 64-bit exponential value. ++ * @v: The value for which we're computing the base 2 logarithm. ++ */ ++static inline u32 log2_u64(u64 v) ++{ ++ u32 hi = v >> 32; ++ if (hi) ++ return log2_u32(hi) + 32 + 1; ++ else ++ return log2_u32(v) + 1; ++} ++ ++#include "compat.bpf.h" ++ ++#endif /* __SCX_COMMON_BPF_H */ +diff --git a/tools/sched_ext/include/scx/common.h b/tools/sched_ext/include/scx/common.h +new file mode 100644 +index 000000000000..5b0f90152152 +--- /dev/null ++++ b/tools/sched_ext/include/scx/common.h +@@ -0,0 +1,75 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2023 Tejun Heo ++ * Copyright (c) 2023 David Vernet ++ */ ++#ifndef __SCHED_EXT_COMMON_H ++#define __SCHED_EXT_COMMON_H ++ ++#ifdef __KERNEL__ ++#error "Should not be included by BPF programs" ++#endif ++ ++#include ++#include ++#include ++#include ++#include ++ ++typedef uint8_t u8; ++typedef uint16_t u16; ++typedef uint32_t u32; ++typedef uint64_t u64; ++typedef int8_t s8; ++typedef int16_t s16; ++typedef int32_t s32; ++typedef int64_t s64; ++ ++#define SCX_BUG(__fmt, ...) 
\ ++ do { \ ++ fprintf(stderr, "[SCX_BUG] %s:%d", __FILE__, __LINE__); \ ++ if (errno) \ ++ fprintf(stderr, " (%s)\n", strerror(errno)); \ ++ else \ ++ fprintf(stderr, "\n"); \ ++ fprintf(stderr, __fmt __VA_OPT__(,) __VA_ARGS__); \ ++ fprintf(stderr, "\n"); \ ++ \ ++ exit(EXIT_FAILURE); \ ++ } while (0) ++ ++#define SCX_BUG_ON(__cond, __fmt, ...) \ ++ do { \ ++ if (__cond) \ ++ SCX_BUG((__fmt) __VA_OPT__(,) __VA_ARGS__); \ ++ } while (0) ++ ++/** ++ * RESIZE_ARRAY - Convenience macro for resizing a BPF array ++ * @__skel: the skeleton containing the array ++ * @elfsec: the data section of the BPF program in which the array exists ++ * @arr: the name of the array ++ * @n: the desired array element count ++ * ++ * For BPF arrays declared with RESIZABLE_ARRAY(), this macro performs two ++ * operations. It resizes the map which corresponds to the custom data ++ * section that contains the target array. As a side effect, the BTF info for ++ * the array is adjusted so that the array length is sized to cover the new ++ * data section size. The second operation is reassigning the skeleton pointer ++ * for that custom data section so that it points to the newly memory mapped ++ * region. ++ */ ++#define RESIZE_ARRAY(__skel, elfsec, arr, n) \ ++ do { \ ++ size_t __sz; \ ++ bpf_map__set_value_size((__skel)->maps.elfsec##_##arr, \ ++ sizeof((__skel)->elfsec##_##arr->arr[0]) * (n)); \ ++ (__skel)->elfsec##_##arr = \ ++ bpf_map__initial_value((__skel)->maps.elfsec##_##arr, &__sz); \ ++ } while (0) ++ ++#include "user_exit_info.h" ++#include "compat.h" ++ ++#endif /* __SCHED_EXT_COMMON_H */ +diff --git a/tools/sched_ext/include/scx/compat.bpf.h b/tools/sched_ext/include/scx/compat.bpf.h +new file mode 100644 +index 000000000000..3d2fe1208900 +--- /dev/null ++++ b/tools/sched_ext/include/scx/compat.bpf.h +@@ -0,0 +1,28 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 Tejun Heo ++ * Copyright (c) 2024 David Vernet ++ */ ++#ifndef __SCX_COMPAT_BPF_H ++#define __SCX_COMPAT_BPF_H ++ ++#define __COMPAT_ENUM_OR_ZERO(__type, __ent) \ ++({ \ ++ __type __ret = 0; \ ++ if (bpf_core_enum_value_exists(__type, __ent)) \ ++ __ret = __ent; \ ++ __ret; \ ++}) ++ ++/* ++ * Define sched_ext_ops. This may be expanded to define multiple variants for ++ * backward compatibility. See compat.h::SCX_OPS_LOAD/ATTACH(). ++ */ ++#define SCX_OPS_DEFINE(__name, ...) \ ++ SEC(".struct_ops.link") \ ++ struct sched_ext_ops __name = { \ ++ __VA_ARGS__, \ ++ }; ++ ++#endif /* __SCX_COMPAT_BPF_H */ +diff --git a/tools/sched_ext/include/scx/compat.h b/tools/sched_ext/include/scx/compat.h +new file mode 100644 +index 000000000000..1bf8eddf20c2 +--- /dev/null ++++ b/tools/sched_ext/include/scx/compat.h +@@ -0,0 +1,187 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2024 Tejun Heo ++ * Copyright (c) 2024 David Vernet ++ */ ++#ifndef __SCX_COMPAT_H ++#define __SCX_COMPAT_H ++ ++#include ++#include ++#include ++#include ++ ++struct btf *__COMPAT_vmlinux_btf __attribute__((weak)); ++ ++static inline void __COMPAT_load_vmlinux_btf(void) ++{ ++ if (!__COMPAT_vmlinux_btf) { ++ __COMPAT_vmlinux_btf = btf__load_vmlinux_btf(); ++ SCX_BUG_ON(!__COMPAT_vmlinux_btf, "btf__load_vmlinux_btf()"); ++ } ++} ++ ++static inline bool __COMPAT_read_enum(const char *type, const char *name, u64 *v) ++{ ++ const struct btf_type *t; ++ const char *n; ++ s32 tid; ++ int i; ++ ++ __COMPAT_load_vmlinux_btf(); ++ ++ tid = btf__find_by_name(__COMPAT_vmlinux_btf, type); ++ if (tid < 0) ++ return false; ++ ++ t = btf__type_by_id(__COMPAT_vmlinux_btf, tid); ++ SCX_BUG_ON(!t, "btf__type_by_id(%d)", tid); ++ ++ if (btf_is_enum(t)) { ++ struct btf_enum *e = btf_enum(t); ++ ++ for (i = 0; i < BTF_INFO_VLEN(t->info); i++) { ++ n = btf__name_by_offset(__COMPAT_vmlinux_btf, e[i].name_off); ++ 
SCX_BUG_ON(!n, "btf__name_by_offset()"); ++ if (!strcmp(n, name)) { ++ *v = e[i].val; ++ return true; ++ } ++ } ++ } else if (btf_is_enum64(t)) { ++ struct btf_enum64 *e = btf_enum64(t); ++ ++ for (i = 0; i < BTF_INFO_VLEN(t->info); i++) { ++ n = btf__name_by_offset(__COMPAT_vmlinux_btf, e[i].name_off); ++ SCX_BUG_ON(!n, "btf__name_by_offset()"); ++ if (!strcmp(n, name)) { ++ *v = btf_enum64_value(&e[i]); ++ return true; ++ } ++ } ++ } ++ ++ return false; ++} ++ ++#define __COMPAT_ENUM_OR_ZERO(__type, __ent) \ ++({ \ ++ u64 __val = 0; \ ++ __COMPAT_read_enum(__type, __ent, &__val); \ ++ __val; \ ++}) ++ ++static inline bool __COMPAT_has_ksym(const char *ksym) ++{ ++ __COMPAT_load_vmlinux_btf(); ++ return btf__find_by_name(__COMPAT_vmlinux_btf, ksym) >= 0; ++} ++ ++static inline bool __COMPAT_struct_has_field(const char *type, const char *field) ++{ ++ const struct btf_type *t; ++ const struct btf_member *m; ++ const char *n; ++ s32 tid; ++ int i; ++ ++ __COMPAT_load_vmlinux_btf(); ++ tid = btf__find_by_name_kind(__COMPAT_vmlinux_btf, type, BTF_KIND_STRUCT); ++ if (tid < 0) ++ return false; ++ ++ t = btf__type_by_id(__COMPAT_vmlinux_btf, tid); ++ SCX_BUG_ON(!t, "btf__type_by_id(%d)", tid); ++ ++ m = btf_members(t); ++ ++ for (i = 0; i < BTF_INFO_VLEN(t->info); i++) { ++ n = btf__name_by_offset(__COMPAT_vmlinux_btf, m[i].name_off); ++ SCX_BUG_ON(!n, "btf__name_by_offset()"); ++ if (!strcmp(n, field)) ++ return true; ++ } ++ ++ return false; ++} ++ ++#define SCX_OPS_SWITCH_PARTIAL \ ++ __COMPAT_ENUM_OR_ZERO("scx_ops_flags", "SCX_OPS_SWITCH_PARTIAL") ++ ++static inline long scx_hotplug_seq(void) ++{ ++ int fd; ++ char buf[32]; ++ ssize_t len; ++ long val; ++ ++ fd = open("/sys/kernel/sched_ext/hotplug_seq", O_RDONLY); ++ if (fd < 0) ++ return -ENOENT; ++ ++ len = read(fd, buf, sizeof(buf) - 1); ++ SCX_BUG_ON(len <= 0, "read failed (%ld)", len); ++ buf[len] = 0; ++ close(fd); ++ ++ val = strtoul(buf, NULL, 10); ++ SCX_BUG_ON(val < 0, "invalid num hotplug events: %lu", 
val); ++ ++ return val; ++} ++ ++/* ++ * struct sched_ext_ops can change over time. If compat.bpf.h::SCX_OPS_DEFINE() ++ * is used to define ops and compat.h::SCX_OPS_LOAD/ATTACH() are used to load ++ * and attach it, backward compatibility is automatically maintained where ++ * reasonable. ++ * ++ * ec7e3b0463e1 ("implement-ops") in https://github.com/sched-ext/sched_ext is ++ * the current minimum required kernel version. ++ */ ++#define SCX_OPS_OPEN(__ops_name, __scx_name) ({ \ ++ struct __scx_name *__skel; \ ++ \ ++ SCX_BUG_ON(!__COMPAT_struct_has_field("sched_ext_ops", "dump"), \ ++ "sched_ext_ops.dump() missing, kernel too old?"); \ ++ \ ++ __skel = __scx_name##__open(); \ ++ SCX_BUG_ON(!__skel, "Could not open " #__scx_name); \ ++ __skel->struct_ops.__ops_name->hotplug_seq = scx_hotplug_seq(); \ ++ __skel; \ ++}) ++ ++#define SCX_OPS_LOAD(__skel, __ops_name, __scx_name, __uei_name) ({ \ ++ UEI_SET_SIZE(__skel, __ops_name, __uei_name); \ ++ SCX_BUG_ON(__scx_name##__load((__skel)), "Failed to load skel"); \ ++}) ++ ++/* ++ * New versions of bpftool now emit additional link placeholders for BPF maps, ++ * and set up BPF skeleton in such a way that libbpf will auto-attach BPF maps ++ * automatically, assuming libbpf is recent enough (v1.5+). Old libbpf will do ++ * nothing with those links and won't attempt to auto-attach maps. ++ * ++ * To maintain compatibility with older libbpf while avoiding trying to attach ++ * twice, disable the autoattach feature on newer libbpf. 
++ */ ++/* BACKPORT - bpf_map__set_autoattach() not available yet, commented out */ ++/*#if LIBBPF_MAJOR_VERSION > 1 || \ ++ (LIBBPF_MAJOR_VERSION == 1 && LIBBPF_MINOR_VERSION >= 5) ++#define __SCX_OPS_DISABLE_AUTOATTACH(__skel, __ops_name) \ ++ bpf_map__set_autoattach((__skel)->maps.__ops_name, false) ++#else*/ ++#define __SCX_OPS_DISABLE_AUTOATTACH(__skel, __ops_name) do {} while (0) ++/*#endif*/ ++ ++#define SCX_OPS_ATTACH(__skel, __ops_name, __scx_name) ({ \ ++ struct bpf_link *__link; \ ++ __SCX_OPS_DISABLE_AUTOATTACH(__skel, __ops_name); \ ++ SCX_BUG_ON(__scx_name##__attach((__skel)), "Failed to attach skel"); \ ++ __link = bpf_map__attach_struct_ops((__skel)->maps.__ops_name); \ ++ SCX_BUG_ON(!__link, "Failed to attach struct_ops"); \ ++ __link; \ ++}) ++ ++#endif /* __SCX_COMPAT_H */ +diff --git a/tools/sched_ext/include/scx/user_exit_info.h b/tools/sched_ext/include/scx/user_exit_info.h +new file mode 100644 +index 000000000000..891693ee604e +--- /dev/null ++++ b/tools/sched_ext/include/scx/user_exit_info.h +@@ -0,0 +1,111 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Define struct user_exit_info which is shared between BPF and userspace parts ++ * to communicate exit status and other information. ++ * ++ * Copyright (c) 2022 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2022 Tejun Heo ++ * Copyright (c) 2022 David Vernet ++ */ ++#ifndef __USER_EXIT_INFO_H ++#define __USER_EXIT_INFO_H ++ ++enum uei_sizes { ++ UEI_REASON_LEN = 128, ++ UEI_MSG_LEN = 1024, ++ UEI_DUMP_DFL_LEN = 32768, ++}; ++ ++struct user_exit_info { ++ int kind; ++ s64 exit_code; ++ char reason[UEI_REASON_LEN]; ++ char msg[UEI_MSG_LEN]; ++}; ++ ++#ifdef __bpf__ ++ ++#include "vmlinux.h" ++#include ++ ++#define UEI_DEFINE(__name) \ ++ char RESIZABLE_ARRAY(data, __name##_dump); \ ++ const volatile u32 __name##_dump_len; \ ++ struct user_exit_info __name SEC(".data") ++ ++#define UEI_RECORD(__uei_name, __ei) ({ \ ++ bpf_probe_read_kernel_str(__uei_name.reason, \ ++ sizeof(__uei_name.reason), (__ei)->reason); \ ++ bpf_probe_read_kernel_str(__uei_name.msg, \ ++ sizeof(__uei_name.msg), (__ei)->msg); \ ++ bpf_probe_read_kernel_str(__uei_name##_dump, \ ++ __uei_name##_dump_len, (__ei)->dump); \ ++ if (bpf_core_field_exists((__ei)->exit_code)) \ ++ __uei_name.exit_code = (__ei)->exit_code; \ ++ /* use __sync to force memory barrier */ \ ++ __sync_val_compare_and_swap(&__uei_name.kind, __uei_name.kind, \ ++ (__ei)->kind); \ ++}) ++ ++#else /* !__bpf__ */ ++ ++#include ++#include ++ ++/* no need to call the following explicitly if SCX_OPS_LOAD() is used */ ++#define UEI_SET_SIZE(__skel, __ops_name, __uei_name) ({ \ ++ u32 __len = (__skel)->struct_ops.__ops_name->exit_dump_len ?: UEI_DUMP_DFL_LEN; \ ++ (__skel)->rodata->__uei_name##_dump_len = __len; \ ++ RESIZE_ARRAY((__skel), data, __uei_name##_dump, __len); \ ++}) ++ ++#define UEI_EXITED(__skel, __uei_name) ({ \ ++ /* use __sync to force memory barrier */ \ ++ __sync_val_compare_and_swap(&(__skel)->data->__uei_name.kind, -1, -1); \ ++}) ++ ++#define UEI_REPORT(__skel, __uei_name) ({ \ ++ struct user_exit_info *__uei = &(__skel)->data->__uei_name; \ ++ char *__uei_dump = (__skel)->data_##__uei_name##_dump->__uei_name##_dump; \ ++ if (__uei_dump[0] != '\0') { \ ++ fputs("\nDEBUG DUMP\n", stderr); \ ++ 
fputs("================================================================================\n\n", stderr); \ ++ fputs(__uei_dump, stderr); \ ++ fputs("\n================================================================================\n\n", stderr); \ ++ } \ ++ fprintf(stderr, "EXIT: %s", __uei->reason); \ ++ if (__uei->msg[0] != '\0') \ ++ fprintf(stderr, " (%s)", __uei->msg); \ ++ fputs("\n", stderr); \ ++ __uei->exit_code; \ ++}) ++ ++/* ++ * We can't import vmlinux.h while compiling user C code. Let's duplicate ++ * scx_exit_code definition. ++ */ ++enum scx_exit_code { ++ /* Reasons */ ++ SCX_ECODE_RSN_HOTPLUG = 1LLU << 32, ++ ++ /* Actions */ ++ SCX_ECODE_ACT_RESTART = 1LLU << 48, ++}; ++ ++enum uei_ecode_mask { ++ UEI_ECODE_USER_MASK = ((1LLU << 32) - 1), ++ UEI_ECODE_SYS_RSN_MASK = ((1LLU << 16) - 1) << 32, ++ UEI_ECODE_SYS_ACT_MASK = ((1LLU << 16) - 1) << 48, ++}; ++ ++/* ++ * These macro interpret the ecode returned from UEI_REPORT(). ++ */ ++#define UEI_ECODE_USER(__ecode) ((__ecode) & UEI_ECODE_USER_MASK) ++#define UEI_ECODE_SYS_RSN(__ecode) ((__ecode) & UEI_ECODE_SYS_RSN_MASK) ++#define UEI_ECODE_SYS_ACT(__ecode) ((__ecode) & UEI_ECODE_SYS_ACT_MASK) ++ ++#define UEI_ECODE_RESTART(__ecode) (UEI_ECODE_SYS_ACT((__ecode)) == SCX_ECODE_ACT_RESTART) ++ ++#endif /* __bpf__ */ ++#endif /* __USER_EXIT_INFO_H */ +diff --git a/tools/sched_ext/scx_central.bpf.c b/tools/sched_ext/scx_central.bpf.c +new file mode 100644 +index 000000000000..1d8fd570eaa7 +--- /dev/null ++++ b/tools/sched_ext/scx_central.bpf.c +@@ -0,0 +1,361 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A central FIFO sched_ext scheduler which demonstrates the followings: ++ * ++ * a. Making all scheduling decisions from one CPU: ++ * ++ * The central CPU is the only one making scheduling decisions. All other ++ * CPUs kick the central CPU when they run out of tasks to run. 
++ * ++ * There is one global BPF queue and the central CPU schedules all CPUs by ++ * dispatching from the global queue to each CPU's local dsq from dispatch(). ++ * This isn't the most straightforward. e.g. It'd be easier to bounce ++ * through per-CPU BPF queues. The current design is chosen to maximally ++ * utilize and verify various SCX mechanisms such as LOCAL_ON dispatching. ++ * ++ * b. Tickless operation ++ * ++ * All tasks are dispatched with the infinite slice which allows stopping the ++ * ticks on CONFIG_NO_HZ_FULL kernels running with the proper nohz_full ++ * parameter. The tickless operation can be observed through ++ * /proc/interrupts. ++ * ++ * Periodic switching is enforced by a periodic timer checking all CPUs and ++ * preempting them as necessary. Unfortunately, BPF timer currently doesn't ++ * have a way to pin to a specific CPU, so the periodic timer isn't pinned to ++ * the central CPU. ++ * ++ * c. Preemption ++ * ++ * Kthreads are unconditionally queued to the head of a matching local dsq ++ * and dispatched with SCX_DSQ_PREEMPT. This ensures that a kthread is always ++ * prioritized over user threads, which is required for ensuring forward ++ * progress as e.g. the periodic timer may run on a ksoftirqd and if the ++ * ksoftirqd gets starved by a user thread, there may not be anything else to ++ * vacate that user thread. ++ * ++ * SCX_KICK_PREEMPT is used to trigger scheduling and CPUs to move to the ++ * next tasks. ++ * ++ * This scheduler is designed to maximize usage of various SCX mechanisms. A ++ * more practical implementation would likely put the scheduling loop outside ++ * the central CPU's dispatch() path and add some form of priority mechanism. ++ * ++ * Copyright (c) 2022 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2022 Tejun Heo ++ * Copyright (c) 2022 David Vernet ++ */ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++enum { ++ FALLBACK_DSQ_ID = 0, ++ MS_TO_NS = 1000LLU * 1000, ++ TIMER_INTERVAL_NS = 1 * MS_TO_NS, ++}; ++ ++const volatile s32 central_cpu; ++const volatile u32 nr_cpu_ids = 1; /* !0 for veristat, set during init */ ++const volatile u64 slice_ns = SCX_SLICE_DFL; ++ ++bool timer_pinned = true; ++u64 nr_total, nr_locals, nr_queued, nr_lost_pids; ++u64 nr_timers, nr_dispatches, nr_mismatches, nr_retries; ++u64 nr_overflows; ++ ++UEI_DEFINE(uei); ++ ++struct { ++ __uint(type, BPF_MAP_TYPE_QUEUE); ++ __uint(max_entries, 4096); ++ __type(value, s32); ++} central_q SEC(".maps"); ++ ++/* can't use percpu map due to bad lookups */ ++bool RESIZABLE_ARRAY(data, cpu_gimme_task); ++u64 RESIZABLE_ARRAY(data, cpu_started_at); ++ ++struct central_timer { ++ struct bpf_timer timer; ++}; ++ ++struct { ++ __uint(type, BPF_MAP_TYPE_ARRAY); ++ __uint(max_entries, 1); ++ __type(key, u32); ++ __type(value, struct central_timer); ++} central_timer SEC(".maps"); ++ ++static bool vtime_before(u64 a, u64 b) ++{ ++ return (s64)(a - b) < 0; ++} ++ ++s32 BPF_STRUCT_OPS(central_select_cpu, struct task_struct *p, ++ s32 prev_cpu, u64 wake_flags) ++{ ++ /* ++ * Steer wakeups to the central CPU as much as possible to avoid ++ * disturbing other CPUs. It's safe to blindly return the central cpu as ++ * select_cpu() is a hint and if @p can't be on it, the kernel will ++ * automatically pick a fallback CPU. ++ */ ++ return central_cpu; ++} ++ ++void BPF_STRUCT_OPS(central_enqueue, struct task_struct *p, u64 enq_flags) ++{ ++ s32 pid = p->pid; ++ ++ __sync_fetch_and_add(&nr_total, 1); ++ ++ /* ++ * Push per-cpu kthreads at the head of local dsq's and preempt the ++ * corresponding CPU. This ensures that e.g. 
ksoftirqd isn't blocked ++ * behind other threads which is necessary for forward progress ++ * guarantee as we depend on the BPF timer which may run from ksoftirqd. ++ */ ++ if ((p->flags & PF_KTHREAD) && p->nr_cpus_allowed == 1) { ++ __sync_fetch_and_add(&nr_locals, 1); ++ scx_bpf_dispatch(p, SCX_DSQ_LOCAL, SCX_SLICE_INF, ++ enq_flags | SCX_ENQ_PREEMPT); ++ return; ++ } ++ ++ if (bpf_map_push_elem(&central_q, &pid, 0)) { ++ __sync_fetch_and_add(&nr_overflows, 1); ++ scx_bpf_dispatch(p, FALLBACK_DSQ_ID, SCX_SLICE_INF, enq_flags); ++ return; ++ } ++ ++ __sync_fetch_and_add(&nr_queued, 1); ++ ++ if (!scx_bpf_task_running(p)) ++ scx_bpf_kick_cpu(central_cpu, SCX_KICK_PREEMPT); ++} ++ ++static bool dispatch_to_cpu(s32 cpu) ++{ ++ struct task_struct *p; ++ s32 pid; ++ ++ bpf_repeat(BPF_MAX_LOOPS) { ++ if (bpf_map_pop_elem(&central_q, &pid)) ++ break; ++ ++ __sync_fetch_and_sub(&nr_queued, 1); ++ ++ p = bpf_task_from_pid(pid); ++ if (!p) { ++ __sync_fetch_and_add(&nr_lost_pids, 1); ++ continue; ++ } ++ ++ /* ++ * If we can't run the task at the top, do the dumb thing and ++ * bounce it to the fallback dsq. ++ */ ++ if (!bpf_cpumask_test_cpu(cpu, p->cpus_ptr)) { ++ __sync_fetch_and_add(&nr_mismatches, 1); ++ scx_bpf_dispatch(p, FALLBACK_DSQ_ID, SCX_SLICE_INF, 0); ++ bpf_task_release(p); ++ /* ++ * We might run out of dispatch buffer slots if we continue dispatching ++ * to the fallback DSQ, without dispatching to the local DSQ of the ++ * target CPU. In such a case, break the loop now as the next ++ * dispatch operation will fail. 
++ */ ++ if (!scx_bpf_dispatch_nr_slots()) ++ break; ++ continue; ++ } ++ ++ /* dispatch to local and mark that @cpu doesn't need more */ ++ scx_bpf_dispatch(p, SCX_DSQ_LOCAL_ON | cpu, SCX_SLICE_INF, 0); ++ ++ if (cpu != central_cpu) ++ scx_bpf_kick_cpu(cpu, SCX_KICK_IDLE); ++ ++ bpf_task_release(p); ++ return true; ++ } ++ ++ return false; ++} ++ ++void BPF_STRUCT_OPS(central_dispatch, s32 cpu, struct task_struct *prev) ++{ ++ if (cpu == central_cpu) { ++ /* dispatch for all other CPUs first */ ++ __sync_fetch_and_add(&nr_dispatches, 1); ++ ++ bpf_for(cpu, 0, nr_cpu_ids) { ++ bool *gimme; ++ ++ if (!scx_bpf_dispatch_nr_slots()) ++ break; ++ ++ /* central's gimme is never set; skip a NULL slot so it is never written below */ ++ gimme = ARRAY_ELEM_PTR(cpu_gimme_task, cpu, nr_cpu_ids); ++ if (!gimme || !*gimme) ++ continue; ++ ++ if (dispatch_to_cpu(cpu)) ++ *gimme = false; ++ } ++ ++ /* ++ * Retry if we ran out of dispatch buffer slots as we might have ++ * skipped some CPUs and also need to dispatch for self. The ext ++ * core automatically retries if the local dsq is empty but we ++ * can't rely on that as we're dispatching for other CPUs too. ++ * Kick self explicitly to retry. ++ */ ++ if (!scx_bpf_dispatch_nr_slots()) { ++ __sync_fetch_and_add(&nr_retries, 1); ++ scx_bpf_kick_cpu(central_cpu, SCX_KICK_PREEMPT); ++ return; ++ } ++ ++ /* look for a task to run on the central CPU */ ++ if (scx_bpf_consume(FALLBACK_DSQ_ID)) ++ return; ++ dispatch_to_cpu(central_cpu); ++ } else { ++ bool *gimme; ++ ++ if (scx_bpf_consume(FALLBACK_DSQ_ID)) ++ return; ++ ++ gimme = ARRAY_ELEM_PTR(cpu_gimme_task, cpu, nr_cpu_ids); ++ if (gimme) ++ *gimme = true; ++ ++ /* ++ * Force dispatch on the scheduling CPU so that it finds a task ++ * to run for us. 
++ */ ++ scx_bpf_kick_cpu(central_cpu, SCX_KICK_PREEMPT); ++ } ++} ++ ++void BPF_STRUCT_OPS(central_running, struct task_struct *p) ++{ ++ s32 cpu = scx_bpf_task_cpu(p); ++ u64 *started_at = ARRAY_ELEM_PTR(cpu_started_at, cpu, nr_cpu_ids); ++ if (started_at) ++ *started_at = bpf_ktime_get_ns() ?: 1; /* 0 indicates idle */ ++} ++ ++void BPF_STRUCT_OPS(central_stopping, struct task_struct *p, bool runnable) ++{ ++ s32 cpu = scx_bpf_task_cpu(p); ++ u64 *started_at = ARRAY_ELEM_PTR(cpu_started_at, cpu, nr_cpu_ids); ++ if (started_at) ++ *started_at = 0; ++} ++ ++static int central_timerfn(void *map, int *key, struct bpf_timer *timer) ++{ ++ u64 now = bpf_ktime_get_ns(); ++ u64 nr_to_kick = nr_queued; ++ s32 i, curr_cpu; ++ ++ curr_cpu = bpf_get_smp_processor_id(); ++ if (timer_pinned && (curr_cpu != central_cpu)) { ++ scx_bpf_error("Central timer ran on CPU %d, not central CPU %d", ++ curr_cpu, central_cpu); ++ return 0; ++ } ++ ++ bpf_for(i, 0, nr_cpu_ids) { ++ s32 cpu = (nr_timers + i) % nr_cpu_ids; ++ u64 *started_at; ++ ++ if (cpu == central_cpu) ++ continue; ++ ++ /* kick iff the current one exhausted its slice */ ++ started_at = ARRAY_ELEM_PTR(cpu_started_at, cpu, nr_cpu_ids); ++ if (started_at && *started_at && ++ vtime_before(now, *started_at + slice_ns)) ++ continue; ++ ++ /* and there's something pending */ ++ if (scx_bpf_dsq_nr_queued(FALLBACK_DSQ_ID) || ++ scx_bpf_dsq_nr_queued(SCX_DSQ_LOCAL_ON | cpu)) ++ ; ++ else if (nr_to_kick) ++ nr_to_kick--; ++ else ++ continue; ++ ++ scx_bpf_kick_cpu(cpu, SCX_KICK_PREEMPT); ++ } ++ ++ bpf_timer_start(timer, TIMER_INTERVAL_NS, timer_pinned ? BPF_F_TIMER_CPU_PIN : 0); ++ __sync_fetch_and_add(&nr_timers, 1); ++ return 0; ++} ++ ++int BPF_STRUCT_OPS_SLEEPABLE(central_init) ++{ ++ u32 key = 0; ++ struct bpf_timer *timer; ++ int ret; ++ ++ ret = scx_bpf_create_dsq(FALLBACK_DSQ_ID, -1); ++ if (ret) ++ return ret; ++ ++ timer = bpf_map_lookup_elem(&central_timer, &key); ++ if (!timer) ++ return -ESRCH; ++ ++ if (bpf_get_smp_processor_id() != 
central_cpu) { ++ scx_bpf_error("init from non-central CPU"); ++ return -EINVAL; ++ } ++ ++ bpf_timer_init(timer, &central_timer, CLOCK_MONOTONIC); ++ bpf_timer_set_callback(timer, central_timerfn); ++ ++ ret = bpf_timer_start(timer, TIMER_INTERVAL_NS, BPF_F_TIMER_CPU_PIN); ++ /* ++ * BPF_F_TIMER_CPU_PIN is pretty new (>=6.7). If we're running in a ++ * kernel which doesn't have it, bpf_timer_start() will return -EINVAL. ++ * Retry without the PIN. This would be the perfect use case for ++ * bpf_core_enum_value_exists() but the enum type doesn't have a name ++ * and can't be used with bpf_core_enum_value_exists(). Oh well... ++ */ ++ if (ret == -EINVAL) { ++ timer_pinned = false; ++ ret = bpf_timer_start(timer, TIMER_INTERVAL_NS, 0); ++ } ++ if (ret) ++ scx_bpf_error("bpf_timer_start failed (%d)", ret); ++ return ret; ++} ++ ++void BPF_STRUCT_OPS(central_exit, struct scx_exit_info *ei) ++{ ++ UEI_RECORD(uei, ei); ++} ++ ++SCX_OPS_DEFINE(central_ops, ++ /* ++ * We are offloading all scheduling decisions to the central CPU ++ * and thus being the last task on a given CPU doesn't mean ++ * anything special. Enqueue the last tasks like any other tasks. ++ */ ++ .flags = SCX_OPS_ENQ_LAST, ++ ++ .select_cpu = (void *)central_select_cpu, ++ .enqueue = (void *)central_enqueue, ++ .dispatch = (void *)central_dispatch, ++ .running = (void *)central_running, ++ .stopping = (void *)central_stopping, ++ .init = (void *)central_init, ++ .exit = (void *)central_exit, ++ .name = "central"); +diff --git a/tools/sched_ext/scx_central.c b/tools/sched_ext/scx_central.c +new file mode 100644 +index 000000000000..21deea320bd7 +--- /dev/null ++++ b/tools/sched_ext/scx_central.c +@@ -0,0 +1,135 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2022 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2022 Tejun Heo ++ * Copyright (c) 2022 David Vernet ++ */ ++#define _GNU_SOURCE ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include "scx_central.bpf.skel.h" ++ ++const char help_fmt[] = ++"A central FIFO sched_ext scheduler.\n" ++"\n" ++"See the top-level comment in .bpf.c for more details.\n" ++"\n" ++"Usage: %s [-s SLICE_US] [-c CPU]\n" ++"\n" ++" -s SLICE_US Override slice duration\n" ++" -c CPU Override the central CPU (default: 0)\n" ++" -v Print libbpf debug messages\n" ++" -h Display this help and exit\n"; ++ ++static bool verbose; ++static volatile int exit_req; ++ ++static int libbpf_print_fn(enum libbpf_print_level level, const char *format, va_list args) ++{ ++ if (level == LIBBPF_DEBUG && !verbose) ++ return 0; ++ return vfprintf(stderr, format, args); ++} ++ ++static void sigint_handler(int dummy) ++{ ++ exit_req = 1; ++} ++ ++int main(int argc, char **argv) ++{ ++ struct scx_central *skel; ++ struct bpf_link *link; ++ __u64 seq = 0, ecode; ++ __s32 opt; ++ cpu_set_t *cpuset; ++ ++ libbpf_set_print(libbpf_print_fn); ++ signal(SIGINT, sigint_handler); ++ signal(SIGTERM, sigint_handler); ++restart: ++ skel = SCX_OPS_OPEN(central_ops, scx_central); ++ ++ skel->rodata->central_cpu = 0; ++ skel->rodata->nr_cpu_ids = libbpf_num_possible_cpus(); ++ ++ while ((opt = getopt(argc, argv, "s:c:pvh")) != -1) { ++ switch (opt) { ++ case 's': ++ skel->rodata->slice_ns = strtoull(optarg, NULL, 0) * 1000; ++ break; ++ case 'c': ++ skel->rodata->central_cpu = strtoul(optarg, NULL, 0); ++ break; ++ case 'v': ++ verbose = true; ++ break; ++ default: ++ fprintf(stderr, help_fmt, basename(argv[0])); ++ return opt != 'h'; ++ } ++ } ++ ++ /* Resize arrays so their element count is equal to cpu count. 
*/ ++ RESIZE_ARRAY(skel, data, cpu_gimme_task, skel->rodata->nr_cpu_ids); ++ RESIZE_ARRAY(skel, data, cpu_started_at, skel->rodata->nr_cpu_ids); ++ ++ SCX_OPS_LOAD(skel, central_ops, scx_central, uei); ++ ++ /* ++ * Affinitize the loading thread to the central CPU, as: ++ * - That's where the BPF timer is first invoked in the BPF program. ++ * - We probably don't want this user space component to take up a core ++ * from a task that would benefit from avoiding preemption on one of ++ * the tickless cores. ++ * ++ * Until BPF supports pinning the timer, it's not guaranteed that it ++ * will always be invoked on the central CPU. In practice, this ++ * suffices the majority of the time. ++ */ ++ cpuset = CPU_ALLOC(skel->rodata->nr_cpu_ids); ++ SCX_BUG_ON(!cpuset, "Failed to allocate cpuset"); ++ CPU_ZERO_S(CPU_ALLOC_SIZE(skel->rodata->nr_cpu_ids), cpuset); ++ CPU_SET_S(skel->rodata->central_cpu, CPU_ALLOC_SIZE(skel->rodata->nr_cpu_ids), cpuset); ++ SCX_BUG_ON(sched_setaffinity(0, CPU_ALLOC_SIZE(skel->rodata->nr_cpu_ids), cpuset), ++ "Failed to affinitize to central CPU %d (max %d)", ++ skel->rodata->central_cpu, skel->rodata->nr_cpu_ids - 1); ++ CPU_FREE(cpuset); ++ ++ link = SCX_OPS_ATTACH(skel, central_ops, scx_central); ++ ++ if (!skel->data->timer_pinned) ++ printf("WARNING : BPF_F_TIMER_CPU_PIN not available, timer not pinned to central\n"); ++ ++ while (!exit_req && !UEI_EXITED(skel, uei)) { ++ printf("[SEQ %llu]\n", seq++); ++ printf("total :%10" PRIu64 " local:%10" PRIu64 " queued:%10" PRIu64 " lost:%10" PRIu64 "\n", ++ skel->bss->nr_total, ++ skel->bss->nr_locals, ++ skel->bss->nr_queued, ++ skel->bss->nr_lost_pids); ++ printf("timer :%10" PRIu64 " dispatch:%10" PRIu64 " mismatch:%10" PRIu64 " retry:%10" PRIu64 "\n", ++ skel->bss->nr_timers, ++ skel->bss->nr_dispatches, ++ skel->bss->nr_mismatches, ++ skel->bss->nr_retries); ++ printf("overflow:%10" PRIu64 "\n", ++ skel->bss->nr_overflows); ++ fflush(stdout); ++ sleep(1); ++ } ++ ++ bpf_link__destroy(link); ++ ecode = UEI_REPORT(skel, uei); ++ scx_central__destroy(skel); ++ ++ if (UEI_ECODE_RESTART(ecode)) ++ goto 
restart; ++ return 0; ++} +diff --git a/tools/sched_ext/scx_qmap.bpf.c b/tools/sched_ext/scx_qmap.bpf.c +new file mode 100644 +index 000000000000..892278f12dce +--- /dev/null ++++ b/tools/sched_ext/scx_qmap.bpf.c +@@ -0,0 +1,706 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A simple five-level FIFO queue scheduler. ++ * ++ * There are five FIFOs implemented using BPF_MAP_TYPE_QUEUE. A task gets ++ * assigned to one depending on its compound weight. Each CPU round robins ++ * through the FIFOs and dispatches more from FIFOs with higher indices - 1 from ++ * queue0, 2 from queue1, 4 from queue2 and so on. ++ * ++ * This scheduler demonstrates: ++ * ++ * - BPF-side queueing using PIDs. ++ * - Sleepable per-task storage allocation using ops.prep_enable(). ++ * - Using ops.cpu_release() to handle a higher priority scheduling class taking ++ * the CPU away. ++ * - Core-sched support. ++ * ++ * This scheduler is primarily for demonstration and testing of sched_ext ++ * features and unlikely to be useful for actual workloads. ++ * ++ * Copyright (c) 2022 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2022 Tejun Heo ++ * Copyright (c) 2022 David Vernet ++ */ ++#include ++ ++enum consts { ++ ONE_SEC_IN_NS = 1000000000, ++ SHARED_DSQ = 0, ++}; ++ ++char _license[] SEC("license") = "GPL"; ++ ++const volatile u64 slice_ns = SCX_SLICE_DFL; ++const volatile u32 stall_user_nth; ++const volatile u32 stall_kernel_nth; ++const volatile u32 dsp_inf_loop_after; ++const volatile u32 dsp_batch; ++const volatile bool print_shared_dsq; ++const volatile s32 disallow_tgid; ++const volatile bool suppress_dump; ++ ++u32 test_error_cnt; ++ ++UEI_DEFINE(uei); ++ ++struct qmap { ++ __uint(type, BPF_MAP_TYPE_QUEUE); ++ __uint(max_entries, 4096); ++ __type(value, u32); ++} queue0 SEC(".maps"), ++ queue1 SEC(".maps"), ++ queue2 SEC(".maps"), ++ queue3 SEC(".maps"), ++ queue4 SEC(".maps"); ++ ++struct { ++ __uint(type, BPF_MAP_TYPE_ARRAY_OF_MAPS); ++ __uint(max_entries, 5); ++ __type(key, int); ++ __array(values, struct qmap); ++} queue_arr SEC(".maps") = { ++ .values = { ++ [0] = &queue0, ++ [1] = &queue1, ++ [2] = &queue2, ++ [3] = &queue3, ++ [4] = &queue4, ++ }, ++}; ++ ++/* ++ * If enabled, CPU performance target is set according to the queue index ++ * according to the following table. ++ */ ++static const u32 qidx_to_cpuperf_target[] = { ++ [0] = SCX_CPUPERF_ONE * 0 / 4, ++ [1] = SCX_CPUPERF_ONE * 1 / 4, ++ [2] = SCX_CPUPERF_ONE * 2 / 4, ++ [3] = SCX_CPUPERF_ONE * 3 / 4, ++ [4] = SCX_CPUPERF_ONE * 4 / 4, ++}; ++ ++/* ++ * Per-queue sequence numbers to implement core-sched ordering. ++ * ++ * Tail seq is assigned to each queued task and incremented. Head seq tracks the ++ * sequence number of the latest dispatched task. The distance between a ++ * task's seq and the associated queue's head seq is called the queue distance ++ * and used when comparing two tasks for ordering. See qmap_core_sched_before(). 
++ */ ++static u64 core_sched_head_seqs[5]; ++static u64 core_sched_tail_seqs[5]; ++ ++/* Per-task scheduling context */ ++struct task_ctx { ++ bool force_local; /* Dispatch directly to local_dsq */ ++ u64 core_sched_seq; ++}; ++ ++struct { ++ __uint(type, BPF_MAP_TYPE_TASK_STORAGE); ++ __uint(map_flags, BPF_F_NO_PREALLOC); ++ __type(key, int); ++ __type(value, struct task_ctx); ++} task_ctx_stor SEC(".maps"); ++ ++struct cpu_ctx { ++ u64 dsp_idx; /* dispatch index */ ++ u64 dsp_cnt; /* remaining count */ ++ u32 avg_weight; ++ u32 cpuperf_target; ++}; ++ ++struct { ++ __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY); ++ __uint(max_entries, 1); ++ __type(key, u32); ++ __type(value, struct cpu_ctx); ++} cpu_ctx_stor SEC(".maps"); ++ ++/* Statistics */ ++u64 nr_enqueued, nr_dispatched, nr_reenqueued, nr_dequeued, nr_ddsp_from_enq; ++u64 nr_core_sched_execed; ++u32 cpuperf_min, cpuperf_avg, cpuperf_max; ++u32 cpuperf_target_min, cpuperf_target_avg, cpuperf_target_max; ++ ++static s32 pick_direct_dispatch_cpu(struct task_struct *p, s32 prev_cpu) ++{ ++ s32 cpu; ++ ++ if (p->nr_cpus_allowed == 1 || ++ scx_bpf_test_and_clear_cpu_idle(prev_cpu)) ++ return prev_cpu; ++ ++ cpu = scx_bpf_pick_idle_cpu(p->cpus_ptr, 0); ++ if (cpu >= 0) ++ return cpu; ++ ++ return -1; ++} ++ ++s32 BPF_STRUCT_OPS(qmap_select_cpu, struct task_struct *p, ++ s32 prev_cpu, u64 wake_flags) ++{ ++ struct task_ctx *tctx; ++ s32 cpu; ++ ++ tctx = bpf_task_storage_get(&task_ctx_stor, p, 0, 0); ++ if (!tctx) { ++ scx_bpf_error("task_ctx lookup failed"); ++ return -ESRCH; ++ } ++ ++ cpu = pick_direct_dispatch_cpu(p, prev_cpu); ++ ++ if (cpu >= 0) { ++ tctx->force_local = true; ++ return cpu; ++ } else { ++ return prev_cpu; ++ } ++} ++ ++static int weight_to_idx(u32 weight) ++{ ++ /* Coarsely map the compound weight to a FIFO. 
*/ ++ if (weight <= 25) ++ return 0; ++ else if (weight <= 50) ++ return 1; ++ else if (weight < 200) ++ return 2; ++ else if (weight < 400) ++ return 3; ++ else ++ return 4; ++} ++ ++void BPF_STRUCT_OPS(qmap_enqueue, struct task_struct *p, u64 enq_flags) ++{ ++ static u32 user_cnt, kernel_cnt; ++ struct task_ctx *tctx; ++ u32 pid = p->pid; ++ int idx = weight_to_idx(p->scx.weight); ++ void *ring; ++ s32 cpu; ++ ++ if (p->flags & PF_KTHREAD) { ++ if (stall_kernel_nth && !(++kernel_cnt % stall_kernel_nth)) ++ return; ++ } else { ++ if (stall_user_nth && !(++user_cnt % stall_user_nth)) ++ return; ++ } ++ ++ if (test_error_cnt && !--test_error_cnt) ++ scx_bpf_error("test triggering error"); ++ ++ tctx = bpf_task_storage_get(&task_ctx_stor, p, 0, 0); ++ if (!tctx) { ++ scx_bpf_error("task_ctx lookup failed"); ++ return; ++ } ++ ++ /* ++ * All enqueued tasks must have their core_sched_seq updated for correct ++ * core-sched ordering, which is why %SCX_OPS_ENQ_LAST is specified in ++ * qmap_ops.flags. ++ */ ++ tctx->core_sched_seq = core_sched_tail_seqs[idx]++; ++ ++ /* ++ * If qmap_select_cpu() is telling us to or this is the last runnable ++ * task on the CPU, enqueue locally. ++ */ ++ if (tctx->force_local || (enq_flags & SCX_ENQ_LAST)) { ++ tctx->force_local = false; ++ scx_bpf_dispatch(p, SCX_DSQ_LOCAL, slice_ns, enq_flags); ++ return; ++ } ++ ++ /* if !WAKEUP, select_cpu() wasn't called, try direct dispatch */ ++ if (!(enq_flags & SCX_ENQ_WAKEUP) && ++ (cpu = pick_direct_dispatch_cpu(p, scx_bpf_task_cpu(p))) >= 0) { ++ __sync_fetch_and_add(&nr_ddsp_from_enq, 1); ++ scx_bpf_dispatch(p, SCX_DSQ_LOCAL_ON | cpu, slice_ns, enq_flags); ++ return; ++ } ++ ++ /* ++ * If the task was re-enqueued due to the CPU being preempted by a ++ * higher priority scheduling class, just re-enqueue the task directly ++ * on the global DSQ. As we want another CPU to pick it up, find and ++ * kick an idle CPU. 
++ */ ++ if (enq_flags & SCX_ENQ_REENQ) { ++ s32 cpu; ++ ++ scx_bpf_dispatch(p, SHARED_DSQ, 0, enq_flags); ++ cpu = scx_bpf_pick_idle_cpu(p->cpus_ptr, 0); ++ if (cpu >= 0) ++ scx_bpf_kick_cpu(cpu, SCX_KICK_IDLE); ++ return; ++ } ++ ++ ring = bpf_map_lookup_elem(&queue_arr, &idx); ++ if (!ring) { ++ scx_bpf_error("failed to find ring %d", idx); ++ return; ++ } ++ ++ /* Queue on the selected FIFO. If the FIFO overflows, punt to global. */ ++ if (bpf_map_push_elem(ring, &pid, 0)) { ++ scx_bpf_dispatch(p, SHARED_DSQ, slice_ns, enq_flags); ++ return; ++ } ++ ++ __sync_fetch_and_add(&nr_enqueued, 1); ++} ++ ++/* ++ * The BPF queue map doesn't support removal and sched_ext can handle spurious ++ * dispatches. qmap_dequeue() is only used to collect statistics. ++ */ ++void BPF_STRUCT_OPS(qmap_dequeue, struct task_struct *p, u64 deq_flags) ++{ ++ __sync_fetch_and_add(&nr_dequeued, 1); ++ if (deq_flags & SCX_DEQ_CORE_SCHED_EXEC) ++ __sync_fetch_and_add(&nr_core_sched_execed, 1); ++} ++ ++static void update_core_sched_head_seq(struct task_struct *p) ++{ ++ struct task_ctx *tctx = bpf_task_storage_get(&task_ctx_stor, p, 0, 0); ++ int idx = weight_to_idx(p->scx.weight); ++ ++ if (tctx) ++ core_sched_head_seqs[idx] = tctx->core_sched_seq; ++ else ++ scx_bpf_error("task_ctx lookup failed"); ++} ++ ++void BPF_STRUCT_OPS(qmap_dispatch, s32 cpu, struct task_struct *prev) ++{ ++ struct task_struct *p; ++ struct cpu_ctx *cpuc; ++ u32 zero = 0, batch = dsp_batch ?: 1; ++ void *fifo; ++ s32 i, pid; ++ ++ if (scx_bpf_consume(SHARED_DSQ)) ++ return; ++ ++ if (dsp_inf_loop_after && nr_dispatched > dsp_inf_loop_after) { ++ /* ++ * PID 2 should be kthreadd which should mostly be idle and off ++ * the scheduler. Let's keep dispatching it to force the kernel ++ * to call this function over and over again. 
++ */ ++ p = bpf_task_from_pid(2); ++ if (p) { ++ scx_bpf_dispatch(p, SCX_DSQ_LOCAL, slice_ns, 0); ++ bpf_task_release(p); ++ return; ++ } ++ } ++ ++ if (!(cpuc = bpf_map_lookup_elem(&cpu_ctx_stor, &zero))) { ++ scx_bpf_error("failed to look up cpu_ctx"); ++ return; ++ } ++ ++ for (i = 0; i < 5; i++) { ++ /* Advance the dispatch cursor and pick the fifo. */ ++ if (!cpuc->dsp_cnt) { ++ cpuc->dsp_idx = (cpuc->dsp_idx + 1) % 5; ++ cpuc->dsp_cnt = 1 << cpuc->dsp_idx; ++ } ++ ++ fifo = bpf_map_lookup_elem(&queue_arr, &cpuc->dsp_idx); ++ if (!fifo) { ++ scx_bpf_error("failed to find ring %llu", cpuc->dsp_idx); ++ return; ++ } ++ ++ /* Dispatch or advance. */ ++ bpf_repeat(BPF_MAX_LOOPS) { ++ if (bpf_map_pop_elem(fifo, &pid)) ++ break; ++ ++ p = bpf_task_from_pid(pid); ++ if (!p) ++ continue; ++ ++ update_core_sched_head_seq(p); ++ __sync_fetch_and_add(&nr_dispatched, 1); ++ scx_bpf_dispatch(p, SHARED_DSQ, slice_ns, 0); ++ bpf_task_release(p); ++ batch--; ++ cpuc->dsp_cnt--; ++ if (!batch || !scx_bpf_dispatch_nr_slots()) { ++ scx_bpf_consume(SHARED_DSQ); ++ return; ++ } ++ if (!cpuc->dsp_cnt) ++ break; ++ } ++ ++ cpuc->dsp_cnt = 0; ++ } ++} ++ ++void BPF_STRUCT_OPS(qmap_tick, struct task_struct *p) ++{ ++ struct cpu_ctx *cpuc; ++ u32 zero = 0; ++ int idx; ++ ++ if (!(cpuc = bpf_map_lookup_elem(&cpu_ctx_stor, &zero))) { ++ scx_bpf_error("failed to look up cpu_ctx"); ++ return; ++ } ++ ++ /* ++ * Use the running avg of weights to select the target cpuperf level. ++ * This is a demonstration of the cpuperf feature rather than a ++ * practical strategy to regulate CPU frequency. ++ */ ++ cpuc->avg_weight = cpuc->avg_weight * 3 / 4 + p->scx.weight / 4; ++ idx = weight_to_idx(cpuc->avg_weight); ++ cpuc->cpuperf_target = qidx_to_cpuperf_target[idx]; ++ ++ scx_bpf_cpuperf_set(scx_bpf_task_cpu(p), cpuc->cpuperf_target); ++} ++ ++/* ++ * The distance from the head of the queue scaled by the weight of the queue. 
++ * The lower the number, the older the task and the higher the priority. ++ */ ++static s64 task_qdist(struct task_struct *p) ++{ ++ int idx = weight_to_idx(p->scx.weight); ++ struct task_ctx *tctx; ++ s64 qdist; ++ ++ tctx = bpf_task_storage_get(&task_ctx_stor, p, 0, 0); ++ if (!tctx) { ++ scx_bpf_error("task_ctx lookup failed"); ++ return 0; ++ } ++ ++ qdist = tctx->core_sched_seq - core_sched_head_seqs[idx]; ++ ++ /* ++ * As queue index increments, the priority doubles. The queue w/ index 3 ++ * is dispatched twice more frequently than 2. Reflect the difference by ++ * scaling qdists accordingly. Note that the shift amount needs to be ++ * flipped depending on the sign to avoid flipping priority direction. ++ */ ++ if (qdist >= 0) ++ return qdist << (4 - idx); ++ else ++ return qdist << idx; ++} ++ ++/* ++ * This is called to determine the task ordering when core-sched is picking ++ * tasks to execute on SMT siblings and should encode about the same ordering as ++ * the regular scheduling path. Use the priority-scaled distances from the head ++ * of the queues to compare the two tasks which should be consistent with the ++ * dispatch path behavior. ++ */ ++bool BPF_STRUCT_OPS(qmap_core_sched_before, ++ struct task_struct *a, struct task_struct *b) ++{ ++ return task_qdist(a) > task_qdist(b); ++} ++ ++void BPF_STRUCT_OPS(qmap_cpu_release, s32 cpu, struct scx_cpu_release_args *args) ++{ ++ u32 cnt; ++ ++ /* ++ * Called when @cpu is taken by a higher priority scheduling class. This ++ * makes @cpu no longer available for executing sched_ext tasks. As we ++ * don't want the tasks in @cpu's local dsq to sit there until @cpu ++ * becomes available again, re-enqueue them into the global dsq. See ++ * %SCX_ENQ_REENQ handling in qmap_enqueue(). 
++ */ ++ cnt = scx_bpf_reenqueue_local(); ++ if (cnt) ++ __sync_fetch_and_add(&nr_reenqueued, cnt); ++} ++ ++s32 BPF_STRUCT_OPS(qmap_init_task, struct task_struct *p, ++ struct scx_init_task_args *args) ++{ ++ if (p->tgid == disallow_tgid) ++ p->scx.disallow = true; ++ ++ /* ++ * @p is new. Let's ensure that its task_ctx is available. We can sleep ++ * in this function and the following will automatically use GFP_KERNEL. ++ */ ++ if (bpf_task_storage_get(&task_ctx_stor, p, 0, ++ BPF_LOCAL_STORAGE_GET_F_CREATE)) ++ return 0; ++ else ++ return -ENOMEM; ++} ++ ++void BPF_STRUCT_OPS(qmap_dump, struct scx_dump_ctx *dctx) ++{ ++ s32 i, pid; ++ ++ if (suppress_dump) ++ return; ++ ++ bpf_for(i, 0, 5) { ++ void *fifo; ++ ++ if (!(fifo = bpf_map_lookup_elem(&queue_arr, &i))) ++ return; ++ ++ scx_bpf_dump("QMAP FIFO[%d]:", i); ++ bpf_repeat(4096) { ++ if (bpf_map_pop_elem(fifo, &pid)) ++ break; ++ scx_bpf_dump(" %d", pid); ++ } ++ scx_bpf_dump("\n"); ++ } ++} ++ ++void BPF_STRUCT_OPS(qmap_dump_cpu, struct scx_dump_ctx *dctx, s32 cpu, bool idle) ++{ ++ u32 zero = 0; ++ struct cpu_ctx *cpuc; ++ ++ if (suppress_dump || idle) ++ return; ++ if (!(cpuc = bpf_map_lookup_percpu_elem(&cpu_ctx_stor, &zero, cpu))) ++ return; ++ ++ scx_bpf_dump("QMAP: dsp_idx=%llu dsp_cnt=%llu avg_weight=%u cpuperf_target=%u", ++ cpuc->dsp_idx, cpuc->dsp_cnt, cpuc->avg_weight, ++ cpuc->cpuperf_target); ++} ++ ++void BPF_STRUCT_OPS(qmap_dump_task, struct scx_dump_ctx *dctx, struct task_struct *p) ++{ ++ struct task_ctx *taskc; ++ ++ if (suppress_dump) ++ return; ++ if (!(taskc = bpf_task_storage_get(&task_ctx_stor, p, 0, 0))) ++ return; ++ ++ scx_bpf_dump("QMAP: force_local=%d core_sched_seq=%llu", ++ taskc->force_local, taskc->core_sched_seq); ++} ++ ++/* ++ * Print out the online and possible CPU map using bpf_printk() as a ++ * demonstration of using the cpumask kfuncs and ops.cpu_on/offline(). 
++ */ ++static void print_cpus(void) ++{ ++ const struct cpumask *possible, *online; ++ s32 cpu; ++ char buf[128] = "", *p; ++ int idx; ++ ++ possible = scx_bpf_get_possible_cpumask(); ++ online = scx_bpf_get_online_cpumask(); ++ ++ idx = 0; ++ bpf_for(cpu, 0, scx_bpf_nr_cpu_ids()) { ++ if (!(p = MEMBER_VPTR(buf, [idx++]))) ++ break; ++ if (bpf_cpumask_test_cpu(cpu, online)) ++ *p++ = 'O'; ++ else if (bpf_cpumask_test_cpu(cpu, possible)) ++ *p++ = 'X'; ++ else ++ *p++ = ' '; ++ ++ if ((cpu & 7) == 7) { ++ if (!(p = MEMBER_VPTR(buf, [idx++]))) ++ break; ++ *p++ = '|'; ++ } ++ } ++ buf[sizeof(buf) - 1] = '\0'; ++ ++ scx_bpf_put_cpumask(online); ++ scx_bpf_put_cpumask(possible); ++ ++ bpf_printk("CPUS: |%s", buf); ++} ++ ++void BPF_STRUCT_OPS(qmap_cpu_online, s32 cpu) ++{ ++ bpf_printk("CPU %d coming online", cpu); ++ /* @cpu is already online at this point */ ++ print_cpus(); ++} ++ ++void BPF_STRUCT_OPS(qmap_cpu_offline, s32 cpu) ++{ ++ bpf_printk("CPU %d going offline", cpu); ++ /* @cpu is still online at this point */ ++ print_cpus(); ++} ++ ++struct monitor_timer { ++ struct bpf_timer timer; ++}; ++ ++struct { ++ __uint(type, BPF_MAP_TYPE_ARRAY); ++ __uint(max_entries, 1); ++ __type(key, u32); ++ __type(value, struct monitor_timer); ++} monitor_timer SEC(".maps"); ++ ++/* ++ * Print out the min, avg and max performance levels of CPUs every second to ++ * demonstrate the cpuperf interface. 
++ */ ++static void monitor_cpuperf(void) ++{ ++ u32 zero = 0, nr_cpu_ids; ++ u64 cap_sum = 0, cur_sum = 0, cur_min = SCX_CPUPERF_ONE, cur_max = 0; ++ u64 target_sum = 0, target_min = SCX_CPUPERF_ONE, target_max = 0; ++ const struct cpumask *online; ++ int i, nr_online_cpus = 0; ++ ++ nr_cpu_ids = scx_bpf_nr_cpu_ids(); ++ online = scx_bpf_get_online_cpumask(); ++ ++ bpf_for(i, 0, nr_cpu_ids) { ++ struct cpu_ctx *cpuc; ++ u32 cap, cur; ++ ++ if (!bpf_cpumask_test_cpu(i, online)) ++ continue; ++ nr_online_cpus++; ++ ++ /* collect the capacity and current cpuperf */ ++ cap = scx_bpf_cpuperf_cap(i); ++ cur = scx_bpf_cpuperf_cur(i); ++ ++ cur_min = cur < cur_min ? cur : cur_min; ++ cur_max = cur > cur_max ? cur : cur_max; ++ ++ /* ++ * $cur is relative to $cap. Scale it down accordingly so that ++ * it's in the same scale as other CPUs and $cur_sum/$cap_sum ++ * makes sense. ++ */ ++ cur_sum += cur * cap / SCX_CPUPERF_ONE; ++ cap_sum += cap; ++ ++ if (!(cpuc = bpf_map_lookup_percpu_elem(&cpu_ctx_stor, &zero, i))) { ++ scx_bpf_error("failed to look up cpu_ctx"); ++ goto out; ++ } ++ ++ /* collect target */ ++ cur = cpuc->cpuperf_target; ++ target_sum += cur; ++ target_min = cur < target_min ? cur : target_min; ++ target_max = cur > target_max ? cur : target_max; ++ } ++ ++ cpuperf_min = cur_min; ++ cpuperf_avg = cur_sum * SCX_CPUPERF_ONE / cap_sum; ++ cpuperf_max = cur_max; ++ ++ cpuperf_target_min = target_min; ++ cpuperf_target_avg = target_sum / nr_online_cpus; ++ cpuperf_target_max = target_max; ++out: ++ scx_bpf_put_cpumask(online); ++} ++ ++/* ++ * Dump the currently queued tasks in the shared DSQ to demonstrate the usage of ++ * scx_bpf_dsq_nr_queued() and DSQ iterator. Raise the dispatch batch count to ++ * see meaningful dumps in the trace pipe. 
++ */ ++static void dump_shared_dsq(void) ++{ ++ struct task_struct *p; ++ s32 nr; ++ ++ if (!(nr = scx_bpf_dsq_nr_queued(SHARED_DSQ))) ++ return; ++ ++ bpf_printk("Dumping %d tasks in SHARED_DSQ in reverse order", nr); ++ ++ bpf_rcu_read_lock(); ++ bpf_for_each(scx_dsq, p, SHARED_DSQ, SCX_DSQ_ITER_REV) ++ bpf_printk("%s[%d]", p->comm, p->pid); ++ bpf_rcu_read_unlock(); ++} ++ ++static int monitor_timerfn(void *map, int *key, struct bpf_timer *timer) ++{ ++ monitor_cpuperf(); ++ ++ if (print_shared_dsq) ++ dump_shared_dsq(); ++ ++ bpf_timer_start(timer, ONE_SEC_IN_NS, 0); ++ return 0; ++} ++ ++s32 BPF_STRUCT_OPS_SLEEPABLE(qmap_init) ++{ ++ u32 key = 0; ++ struct bpf_timer *timer; ++ s32 ret; ++ ++ print_cpus(); ++ ++ ret = scx_bpf_create_dsq(SHARED_DSQ, -1); ++ if (ret) ++ return ret; ++ ++ timer = bpf_map_lookup_elem(&monitor_timer, &key); ++ if (!timer) ++ return -ESRCH; ++ ++ bpf_timer_init(timer, &monitor_timer, CLOCK_MONOTONIC); ++ bpf_timer_set_callback(timer, monitor_timerfn); ++ ++ return bpf_timer_start(timer, ONE_SEC_IN_NS, 0); ++} ++ ++void BPF_STRUCT_OPS(qmap_exit, struct scx_exit_info *ei) ++{ ++ UEI_RECORD(uei, ei); ++} ++ ++SCX_OPS_DEFINE(qmap_ops, ++ .select_cpu = (void *)qmap_select_cpu, ++ .enqueue = (void *)qmap_enqueue, ++ .dequeue = (void *)qmap_dequeue, ++ .dispatch = (void *)qmap_dispatch, ++ .tick = (void *)qmap_tick, ++ .core_sched_before = (void *)qmap_core_sched_before, ++ .cpu_release = (void *)qmap_cpu_release, ++ .init_task = (void *)qmap_init_task, ++ .dump = (void *)qmap_dump, ++ .dump_cpu = (void *)qmap_dump_cpu, ++ .dump_task = (void *)qmap_dump_task, ++ .cpu_online = (void *)qmap_cpu_online, ++ .cpu_offline = (void *)qmap_cpu_offline, ++ .init = (void *)qmap_init, ++ .exit = (void *)qmap_exit, ++ .flags = SCX_OPS_ENQ_LAST, ++ .timeout_ms = 5000U, ++ .name = "qmap"); +diff --git a/tools/sched_ext/scx_qmap.c b/tools/sched_ext/scx_qmap.c +new file mode 100644 +index 000000000000..c9ca30d62b2b +--- /dev/null ++++ 
b/tools/sched_ext/scx_qmap.c +@@ -0,0 +1,144 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2022 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2022 Tejun Heo ++ * Copyright (c) 2022 David Vernet ++ */ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include "scx_qmap.bpf.skel.h" ++ ++const char help_fmt[] = ++"A simple five-level FIFO queue sched_ext scheduler.\n" ++"\n" ++"See the top-level comment in .bpf.c for more details.\n" ++"\n" ++"Usage: %s [-s SLICE_US] [-e COUNT] [-t COUNT] [-T COUNT] [-l COUNT] [-b COUNT]\n" ++" [-P] [-d PID] [-D LEN] [-p] [-v]\n" ++"\n" ++" -s SLICE_US Override slice duration\n" ++" -e COUNT Trigger scx_bpf_error() after COUNT enqueues\n" ++" -t COUNT Stall every COUNT'th user thread\n" ++" -T COUNT Stall every COUNT'th kernel thread\n" ++" -l COUNT Trigger dispatch infinite looping after COUNT dispatches\n" ++" -b COUNT Dispatch upto COUNT tasks together\n" ++" -P Print out DSQ content to trace_pipe every second, use with -b\n" ++" -d PID Disallow a process from switching into SCHED_EXT (-1 for self)\n" ++" -D LEN Set scx_exit_info.dump buffer length\n" ++" -S Suppress qmap-specific debug dump\n" ++" -p Switch only tasks on SCHED_EXT policy instead of all\n" ++" -v Print libbpf debug messages\n" ++" -h Display this help and exit\n"; ++ ++static bool verbose; ++static volatile int exit_req; ++ ++static int libbpf_print_fn(enum libbpf_print_level level, const char *format, va_list args) ++{ ++ if (level == LIBBPF_DEBUG && !verbose) ++ return 0; ++ return vfprintf(stderr, format, args); ++} ++ ++static void sigint_handler(int dummy) ++{ ++ exit_req = 1; ++} ++ ++int main(int argc, char **argv) ++{ ++ struct scx_qmap *skel; ++ struct bpf_link *link; ++ int opt; ++ ++ libbpf_set_print(libbpf_print_fn); ++ signal(SIGINT, sigint_handler); ++ signal(SIGTERM, sigint_handler); ++ ++ skel = SCX_OPS_OPEN(qmap_ops, scx_qmap); ++ ++ while ((opt = getopt(argc, argv, 
"s:e:t:T:l:b:Pd:D:Spvh")) != -1) { ++ switch (opt) { ++ case 's': ++ skel->rodata->slice_ns = strtoull(optarg, NULL, 0) * 1000; ++ break; ++ case 'e': ++ skel->bss->test_error_cnt = strtoul(optarg, NULL, 0); ++ break; ++ case 't': ++ skel->rodata->stall_user_nth = strtoul(optarg, NULL, 0); ++ break; ++ case 'T': ++ skel->rodata->stall_kernel_nth = strtoul(optarg, NULL, 0); ++ break; ++ case 'l': ++ skel->rodata->dsp_inf_loop_after = strtoul(optarg, NULL, 0); ++ break; ++ case 'b': ++ skel->rodata->dsp_batch = strtoul(optarg, NULL, 0); ++ break; ++ case 'P': ++ skel->rodata->print_shared_dsq = true; ++ break; ++ case 'd': ++ skel->rodata->disallow_tgid = strtol(optarg, NULL, 0); ++ if (skel->rodata->disallow_tgid < 0) ++ skel->rodata->disallow_tgid = getpid(); ++ break; ++ case 'D': ++ skel->struct_ops.qmap_ops->exit_dump_len = strtoul(optarg, NULL, 0); ++ break; ++ case 'S': ++ skel->rodata->suppress_dump = true; ++ break; ++ case 'p': ++ skel->struct_ops.qmap_ops->flags |= SCX_OPS_SWITCH_PARTIAL; ++ break; ++ case 'v': ++ verbose = true; ++ break; ++ default: ++ fprintf(stderr, help_fmt, basename(argv[0])); ++ return opt != 'h'; ++ } ++ } ++ ++ SCX_OPS_LOAD(skel, qmap_ops, scx_qmap, uei); ++ link = SCX_OPS_ATTACH(skel, qmap_ops, scx_qmap); ++ ++ while (!exit_req && !UEI_EXITED(skel, uei)) { ++ long nr_enqueued = skel->bss->nr_enqueued; ++ long nr_dispatched = skel->bss->nr_dispatched; ++ ++ printf("stats : enq=%lu dsp=%lu delta=%ld reenq=%"PRIu64" deq=%"PRIu64" core=%"PRIu64" enq_ddsp=%"PRIu64"\n", ++ nr_enqueued, nr_dispatched, nr_enqueued - nr_dispatched, ++ skel->bss->nr_reenqueued, skel->bss->nr_dequeued, ++ skel->bss->nr_core_sched_execed, ++ skel->bss->nr_ddsp_from_enq); ++ if (__COMPAT_has_ksym("scx_bpf_cpuperf_cur")) ++ printf("cpuperf: cur min/avg/max=%u/%u/%u target min/avg/max=%u/%u/%u\n", ++ skel->bss->cpuperf_min, ++ skel->bss->cpuperf_avg, ++ skel->bss->cpuperf_max, ++ skel->bss->cpuperf_target_min, ++ skel->bss->cpuperf_target_avg, ++ 
skel->bss->cpuperf_target_max); ++ fflush(stdout); ++ sleep(1); ++ } ++ ++ bpf_link__destroy(link); ++ UEI_REPORT(skel, uei); ++ scx_qmap__destroy(skel); ++ /* ++ * scx_qmap implements ops.cpu_on/offline() and doesn't need to restart ++ * on CPU hotplug events. ++ */ ++ return 0; ++} +diff --git a/tools/sched_ext/scx_show_state.py b/tools/sched_ext/scx_show_state.py +new file mode 100644 +index 000000000000..d457d2a74e1e +--- /dev/null ++++ b/tools/sched_ext/scx_show_state.py +@@ -0,0 +1,39 @@ ++#!/usr/bin/env drgn ++# ++# Copyright (C) 2024 Tejun Heo ++# Copyright (C) 2024 Meta Platforms, Inc. and affiliates. ++ ++desc = """ ++This is a drgn script to show the current sched_ext state. ++For more info on drgn, visit https://github.com/osandov/drgn. ++""" ++ ++import drgn ++import sys ++ ++def err(s): ++ print(s, file=sys.stderr, flush=True) ++ sys.exit(1) ++ ++def read_int(name): ++ return int(prog[name].value_()) ++ ++def read_atomic(name): ++ return prog[name].counter.value_() ++ ++def read_static_key(name): ++ return prog[name].key.enabled.counter.value_() ++ ++def ops_state_str(state): ++ return prog['scx_ops_enable_state_str'][state].string_().decode() ++ ++ops = prog['scx_ops'] ++enable_state = read_atomic("scx_ops_enable_state_var") ++ ++print(f'ops : {ops.name.string_().decode()}') ++print(f'enabled : {read_static_key("__scx_ops_enabled")}') ++print(f'switching_all : {read_int("scx_switching_all")}') ++print(f'switched_all : {read_static_key("__scx_switched_all")}') ++print(f'enable_state : {ops_state_str(enable_state)} ({enable_state})') ++print(f'bypass_depth : {read_atomic("scx_ops_bypass_depth")}') ++print(f'nr_rejected : {read_atomic("scx_nr_rejected")}') +diff --git a/tools/sched_ext/scx_simple.bpf.c b/tools/sched_ext/scx_simple.bpf.c +new file mode 100644 +index 000000000000..ed7e8d535fc5 +--- /dev/null ++++ b/tools/sched_ext/scx_simple.bpf.c +@@ -0,0 +1,156 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A simple scheduler. 
++ * ++ * By default, it operates as a simple global weighted vtime scheduler and can ++ * be switched to FIFO scheduling. It also demonstrates the following niceties. ++ * ++ * - Statistics tracking how many tasks are queued to local and global dsq's. ++ * - Termination notification for userspace. ++ * ++ * While very simple, this scheduler should work reasonably well on CPUs with a ++ * uniform L3 cache topology. While preemption is not implemented, the fact that ++ * the scheduling queue is shared across all CPUs means that whatever is at the ++ * front of the queue is likely to be executed fairly quickly given enough ++ * number of CPUs. The FIFO scheduling mode may be beneficial to some workloads ++ * but comes with the usual problems with FIFO scheduling where saturating ++ * threads can easily drown out interactive ones. ++ * ++ * Copyright (c) 2022 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2022 Tejun Heo ++ * Copyright (c) 2022 David Vernet ++ */ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++const volatile bool fifo_sched; ++ ++static u64 vtime_now; ++UEI_DEFINE(uei); ++ ++/* ++ * Built-in DSQs such as SCX_DSQ_GLOBAL cannot be used as priority queues ++ * (meaning, cannot be dispatched to with scx_bpf_dispatch_vtime()). We ++ * therefore create a separate DSQ with ID 0 that we dispatch to and consume ++ * from. If scx_simple only supported global FIFO scheduling, then we could ++ * just use SCX_DSQ_GLOBAL. 
++ */ ++#define SHARED_DSQ 0 ++ ++struct { ++ __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY); ++ __uint(key_size, sizeof(u32)); ++ __uint(value_size, sizeof(u64)); ++ __uint(max_entries, 2); /* [local, global] */ ++} stats SEC(".maps"); ++ ++static void stat_inc(u32 idx) ++{ ++ u64 *cnt_p = bpf_map_lookup_elem(&stats, &idx); ++ if (cnt_p) ++ (*cnt_p)++; ++} ++ ++static inline bool vtime_before(u64 a, u64 b) ++{ ++ return (s64)(a - b) < 0; ++} ++ ++s32 BPF_STRUCT_OPS(simple_select_cpu, struct task_struct *p, s32 prev_cpu, u64 wake_flags) ++{ ++ bool is_idle = false; ++ s32 cpu; ++ ++ cpu = scx_bpf_select_cpu_dfl(p, prev_cpu, wake_flags, &is_idle); ++ if (is_idle) { ++ stat_inc(0); /* count local queueing */ ++ scx_bpf_dispatch(p, SCX_DSQ_LOCAL, SCX_SLICE_DFL, 0); ++ } ++ ++ return cpu; ++} ++ ++void BPF_STRUCT_OPS(simple_enqueue, struct task_struct *p, u64 enq_flags) ++{ ++ stat_inc(1); /* count global queueing */ ++ ++ if (fifo_sched) { ++ scx_bpf_dispatch(p, SHARED_DSQ, SCX_SLICE_DFL, enq_flags); ++ } else { ++ u64 vtime = p->scx.dsq_vtime; ++ ++ /* ++ * Limit the amount of budget that an idling task can accumulate ++ * to one slice. ++ */ ++ if (vtime_before(vtime, vtime_now - SCX_SLICE_DFL)) ++ vtime = vtime_now - SCX_SLICE_DFL; ++ ++ scx_bpf_dispatch_vtime(p, SHARED_DSQ, SCX_SLICE_DFL, vtime, ++ enq_flags); ++ } ++} ++ ++void BPF_STRUCT_OPS(simple_dispatch, s32 cpu, struct task_struct *prev) ++{ ++ scx_bpf_consume(SHARED_DSQ); ++} ++ ++void BPF_STRUCT_OPS(simple_running, struct task_struct *p) ++{ ++ if (fifo_sched) ++ return; ++ ++ /* ++ * Global vtime always progresses forward as tasks start executing. The ++ * test and update can be performed concurrently from multiple CPUs and ++ * thus racy. Any error should be contained and temporary. Let's just ++ * live with it. 
++ */ ++ if (vtime_before(vtime_now, p->scx.dsq_vtime)) ++ vtime_now = p->scx.dsq_vtime; ++} ++ ++void BPF_STRUCT_OPS(simple_stopping, struct task_struct *p, bool runnable) ++{ ++ if (fifo_sched) ++ return; ++ ++ /* ++ * Scale the execution time by the inverse of the weight and charge. ++ * ++ * Note that the default yield implementation yields by setting ++ * @p->scx.slice to zero and the following would treat the yielding task ++ * as if it has consumed all its slice. If this penalizes yielding tasks ++ * too much, determine the execution time by taking explicit timestamps ++ * instead of depending on @p->scx.slice. ++ */ ++ p->scx.dsq_vtime += (SCX_SLICE_DFL - p->scx.slice) * 100 / p->scx.weight; ++} ++ ++void BPF_STRUCT_OPS(simple_enable, struct task_struct *p) ++{ ++ p->scx.dsq_vtime = vtime_now; ++} ++ ++s32 BPF_STRUCT_OPS_SLEEPABLE(simple_init) ++{ ++ return scx_bpf_create_dsq(SHARED_DSQ, -1); ++} ++ ++void BPF_STRUCT_OPS(simple_exit, struct scx_exit_info *ei) ++{ ++ UEI_RECORD(uei, ei); ++} ++ ++SCX_OPS_DEFINE(simple_ops, ++ .select_cpu = (void *)simple_select_cpu, ++ .enqueue = (void *)simple_enqueue, ++ .dispatch = (void *)simple_dispatch, ++ .running = (void *)simple_running, ++ .stopping = (void *)simple_stopping, ++ .enable = (void *)simple_enable, ++ .init = (void *)simple_init, ++ .exit = (void *)simple_exit, ++ .name = "simple"); +diff --git a/tools/sched_ext/scx_simple.c b/tools/sched_ext/scx_simple.c +new file mode 100644 +index 000000000000..76d83199545c +--- /dev/null ++++ b/tools/sched_ext/scx_simple.c +@@ -0,0 +1,107 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2022 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2022 Tejun Heo ++ * Copyright (c) 2022 David Vernet ++ */ ++#include ++#include ++#include ++#include ++#include ++#include ++#include "scx_simple.bpf.skel.h" ++ ++const char help_fmt[] = ++"A simple sched_ext scheduler.\n" ++"\n" ++"See the top-level comment in .bpf.c for more details.\n" ++"\n" ++"Usage: %s [-f] [-v]\n" ++"\n" ++" -f Use FIFO scheduling instead of weighted vtime scheduling\n" ++" -v Print libbpf debug messages\n" ++" -h Display this help and exit\n"; ++ ++static bool verbose; ++static volatile int exit_req; ++ ++static int libbpf_print_fn(enum libbpf_print_level level, const char *format, va_list args) ++{ ++ if (level == LIBBPF_DEBUG && !verbose) ++ return 0; ++ return vfprintf(stderr, format, args); ++} ++ ++static void sigint_handler(int simple) ++{ ++ exit_req = 1; ++} ++ ++static void read_stats(struct scx_simple *skel, __u64 *stats) ++{ ++ int nr_cpus = libbpf_num_possible_cpus(); ++ __u64 cnts[2][nr_cpus]; ++ __u32 idx; ++ ++ memset(stats, 0, sizeof(stats[0]) * 2); ++ ++ for (idx = 0; idx < 2; idx++) { ++ int ret, cpu; ++ ++ ret = bpf_map_lookup_elem(bpf_map__fd(skel->maps.stats), ++ &idx, cnts[idx]); ++ if (ret < 0) ++ continue; ++ for (cpu = 0; cpu < nr_cpus; cpu++) ++ stats[idx] += cnts[idx][cpu]; ++ } ++} ++ ++int main(int argc, char **argv) ++{ ++ struct scx_simple *skel; ++ struct bpf_link *link; ++ __u32 opt; ++ __u64 ecode; ++ ++ libbpf_set_print(libbpf_print_fn); ++ signal(SIGINT, sigint_handler); ++ signal(SIGTERM, sigint_handler); ++restart: ++ skel = SCX_OPS_OPEN(simple_ops, scx_simple); ++ ++ while ((opt = getopt(argc, argv, "fvh")) != -1) { ++ switch (opt) { ++ case 'f': ++ skel->rodata->fifo_sched = true; ++ break; ++ case 'v': ++ verbose = true; ++ break; ++ default: ++ fprintf(stderr, help_fmt, basename(argv[0])); ++ return opt != 'h'; ++ } ++ } ++ ++ SCX_OPS_LOAD(skel, simple_ops, scx_simple, uei); ++ link = SCX_OPS_ATTACH(skel, simple_ops, scx_simple); ++ ++ while (!exit_req && !UEI_EXITED(skel, uei)) 
{ ++ __u64 stats[2]; ++ ++ read_stats(skel, stats); ++ printf("local=%llu global=%llu\n", stats[0], stats[1]); ++ fflush(stdout); ++ sleep(1); ++ } ++ ++ bpf_link__destroy(link); ++ ecode = UEI_REPORT(skel, uei); ++ scx_simple__destroy(skel); ++ ++ if (UEI_ECODE_RESTART(ecode)) ++ goto restart; ++ return 0; ++} +diff --git a/tools/testing/selftests/sched_ext/.gitignore b/tools/testing/selftests/sched_ext/.gitignore +new file mode 100644 +index 000000000000..ae5491a114c0 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/.gitignore +@@ -0,0 +1,6 @@ ++* ++!*.c ++!*.h ++!Makefile ++!.gitignore ++!config +diff --git a/tools/testing/selftests/sched_ext/Makefile b/tools/testing/selftests/sched_ext/Makefile +new file mode 100644 +index 000000000000..0754a2c110a1 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/Makefile +@@ -0,0 +1,218 @@ ++# SPDX-License-Identifier: GPL-2.0 ++# Copyright (c) 2022 Meta Platforms, Inc. and affiliates. ++include ../../../build/Build.include ++include ../../../scripts/Makefile.arch ++include ../../../scripts/Makefile.include ++include ../lib.mk ++ ++ifneq ($(LLVM),) ++ifneq ($(filter %/,$(LLVM)),) ++LLVM_PREFIX := $(LLVM) ++else ifneq ($(filter -%,$(LLVM)),) ++LLVM_SUFFIX := $(LLVM) ++endif ++ ++CC := $(LLVM_PREFIX)clang$(LLVM_SUFFIX) $(CLANG_FLAGS) -fintegrated-as ++else ++CC := gcc ++endif # LLVM ++ ++ifneq ($(CROSS_COMPILE),) ++$(error CROSS_COMPILE not supported for scx selftests) ++endif # CROSS_COMPILE ++ ++CURDIR := $(abspath .) ++REPOROOT := $(abspath ../../../..) 
++TOOLSDIR := $(REPOROOT)/tools ++LIBDIR := $(TOOLSDIR)/lib ++BPFDIR := $(LIBDIR)/bpf ++TOOLSINCDIR := $(TOOLSDIR)/include ++BPFTOOLDIR := $(TOOLSDIR)/bpf/bpftool ++APIDIR := $(TOOLSINCDIR)/uapi ++GENDIR := $(REPOROOT)/include/generated ++GENHDR := $(GENDIR)/autoconf.h ++SCXTOOLSDIR := $(TOOLSDIR)/sched_ext ++SCXTOOLSINCDIR := $(TOOLSDIR)/sched_ext/include ++ ++OUTPUT_DIR := $(CURDIR)/build ++OBJ_DIR := $(OUTPUT_DIR)/obj ++INCLUDE_DIR := $(OUTPUT_DIR)/include ++BPFOBJ_DIR := $(OBJ_DIR)/libbpf ++SCXOBJ_DIR := $(OBJ_DIR)/sched_ext ++BPFOBJ := $(BPFOBJ_DIR)/libbpf.a ++LIBBPF_OUTPUT := $(OBJ_DIR)/libbpf/libbpf.a ++DEFAULT_BPFTOOL := $(OUTPUT_DIR)/sbin/bpftool ++HOST_BUILD_DIR := $(OBJ_DIR) ++HOST_OUTPUT_DIR := $(OUTPUT_DIR) ++ ++VMLINUX_BTF_PATHS ?= ../../../../vmlinux \ ++ /sys/kernel/btf/vmlinux \ ++ /boot/vmlinux-$(shell uname -r) ++VMLINUX_BTF ?= $(abspath $(firstword $(wildcard $(VMLINUX_BTF_PATHS)))) ++ifeq ($(VMLINUX_BTF),) ++$(error Cannot find a vmlinux for VMLINUX_BTF at any of "$(VMLINUX_BTF_PATHS)") ++endif ++ ++BPFTOOL ?= $(DEFAULT_BPFTOOL) ++ ++ifneq ($(wildcard $(GENHDR)),) ++ GENFLAGS := -DHAVE_GENHDR ++endif ++ ++CFLAGS += -g -O2 -rdynamic -pthread -Wall -Werror $(GENFLAGS) \ ++ -I$(INCLUDE_DIR) -I$(GENDIR) -I$(LIBDIR) \ ++ -I$(TOOLSINCDIR) -I$(APIDIR) -I$(CURDIR)/include -I$(SCXTOOLSINCDIR) ++ ++# Silence some warnings when compiled with clang ++ifneq ($(LLVM),) ++CFLAGS += -Wno-unused-command-line-argument ++endif ++ ++LDFLAGS = -lelf -lz -lpthread -lzstd ++ ++IS_LITTLE_ENDIAN = $(shell $(CC) -dM -E - &1 \ ++ | sed -n '/<...> search starts here:/,/End of search list./{ s| \(/.*\)|-idirafter \1|p }') \ ++$(shell $(1) -dM -E - $@ ++else ++ $(call msg,CP,,$@) ++ $(Q)cp "$(VMLINUX_H)" $@ ++endif ++ ++$(SCXOBJ_DIR)/%.bpf.o: %.bpf.c $(INCLUDE_DIR)/vmlinux.h | $(BPFOBJ) $(SCXOBJ_DIR) ++ $(call msg,CLNG-BPF,,$(notdir $@)) ++ $(Q)$(CLANG) $(BPF_CFLAGS) -target bpf -c $< -o $@ ++ ++$(INCLUDE_DIR)/%.bpf.skel.h: $(SCXOBJ_DIR)/%.bpf.o $(INCLUDE_DIR)/vmlinux.h 
$(BPFTOOL) | $(INCLUDE_DIR) ++ $(eval sched=$(notdir $@)) ++ $(call msg,GEN-SKEL,,$(sched)) ++ $(Q)$(BPFTOOL) gen object $(<:.o=.linked1.o) $< ++ $(Q)$(BPFTOOL) gen object $(<:.o=.linked2.o) $(<:.o=.linked1.o) ++ $(Q)$(BPFTOOL) gen object $(<:.o=.linked3.o) $(<:.o=.linked2.o) ++ $(Q)diff $(<:.o=.linked2.o) $(<:.o=.linked3.o) ++ $(Q)$(BPFTOOL) gen skeleton $(<:.o=.linked3.o) name $(subst .bpf.skel.h,,$(sched)) > $@ ++ $(Q)$(BPFTOOL) gen subskeleton $(<:.o=.linked3.o) name $(subst .bpf.skel.h,,$(sched)) > $(@:.skel.h=.subskel.h) ++ ++################ ++# C schedulers # ++################ ++ ++override define CLEAN ++ rm -rf $(OUTPUT_DIR) ++ rm -f *.o *.bpf.o *.bpf.skel.h *.bpf.subskel.h ++ rm -f $(TEST_GEN_PROGS) ++ rm -f runner ++endef ++ ++# Every testcase takes all of the BPF progs as dependencies by default. This ++# allows testcases to load any BPF scheduler, which is useful for testcases ++# that don't need their own prog to run their test. ++all_test_bpfprogs := $(foreach prog,$(wildcard *.bpf.c),$(INCLUDE_DIR)/$(patsubst %.c,%.skel.h,$(prog))) ++ ++auto-test-targets := \ ++ create_dsq \ ++ enq_last_no_enq_fails \ ++ enq_select_cpu_fails \ ++ ddsp_bogus_dsq_fail \ ++ ddsp_vtimelocal_fail \ ++ dsp_local_on \ ++ exit \ ++ hotplug \ ++ init_enable_count \ ++ maximal \ ++ maybe_null \ ++ minimal \ ++ prog_run \ ++ reload_loop \ ++ select_cpu_dfl \ ++ select_cpu_dfl_nodispatch \ ++ select_cpu_dispatch \ ++ select_cpu_dispatch_bad_dsq \ ++ select_cpu_dispatch_dbl_dsp \ ++ select_cpu_vtime \ ++ test_example \ ++ ++testcase-targets := $(addsuffix .o,$(addprefix $(SCXOBJ_DIR)/,$(auto-test-targets))) ++ ++$(SCXOBJ_DIR)/runner.o: runner.c | $(SCXOBJ_DIR) ++ $(CC) $(CFLAGS) -c $< -o $@ ++ ++# Create all of the test targets object files, whose testcase objects will be ++# registered into the runner in ELF constructors.
++# ++# Note that we must do double expansion here in order to support conditionally ++# compiling BPF object files only if one is present, as the wildcard Make ++# function doesn't support using implicit rules otherwise. ++$(testcase-targets): $(SCXOBJ_DIR)/%.o: %.c $(SCXOBJ_DIR)/runner.o $(all_test_bpfprogs) | $(SCXOBJ_DIR) ++ $(eval test=$(patsubst %.o,%.c,$(notdir $@))) ++ $(CC) $(CFLAGS) -c $< -o $@ $(SCXOBJ_DIR)/runner.o ++ ++$(SCXOBJ_DIR)/util.o: util.c | $(SCXOBJ_DIR) ++ $(CC) $(CFLAGS) -c $< -o $@ ++ ++runner: $(SCXOBJ_DIR)/runner.o $(SCXOBJ_DIR)/util.o $(BPFOBJ) $(testcase-targets) ++ @echo "$(testcase-targets)" ++ $(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) ++ ++TEST_GEN_PROGS := runner ++ ++all: runner ++ ++.PHONY: all clean help ++ ++.DEFAULT_GOAL := all ++ ++.DELETE_ON_ERROR: ++ ++.SECONDARY: +diff --git a/tools/testing/selftests/sched_ext/config b/tools/testing/selftests/sched_ext/config +new file mode 100644 +index 000000000000..0de9b4ee249d +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/config +@@ -0,0 +1,9 @@ ++CONFIG_SCHED_DEBUG=y ++CONFIG_SCHED_CLASS_EXT=y ++CONFIG_CGROUPS=y ++CONFIG_CGROUP_SCHED=y ++CONFIG_EXT_GROUP_SCHED=y ++CONFIG_BPF=y ++CONFIG_BPF_SYSCALL=y ++CONFIG_DEBUG_INFO=y ++CONFIG_DEBUG_INFO_BTF=y +diff --git a/tools/testing/selftests/sched_ext/create_dsq.bpf.c b/tools/testing/selftests/sched_ext/create_dsq.bpf.c +new file mode 100644 +index 000000000000..23f79ed343f0 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/create_dsq.bpf.c +@@ -0,0 +1,58 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Create and destroy DSQs in a loop. ++ * ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 David Vernet ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++void BPF_STRUCT_OPS(create_dsq_exit_task, struct task_struct *p, ++ struct scx_exit_task_args *args) ++{ ++ scx_bpf_destroy_dsq(p->pid); ++} ++ ++s32 BPF_STRUCT_OPS_SLEEPABLE(create_dsq_init_task, struct task_struct *p, ++ struct scx_init_task_args *args) ++{ ++ s32 err; ++ ++ err = scx_bpf_create_dsq(p->pid, -1); ++ if (err) ++ scx_bpf_error("Failed to create DSQ for %s[%d]", ++ p->comm, p->pid); ++ ++ return err; ++} ++ ++s32 BPF_STRUCT_OPS_SLEEPABLE(create_dsq_init) ++{ ++ u32 i; ++ s32 err; ++ ++ bpf_for(i, 0, 1024) { ++ err = scx_bpf_create_dsq(i, -1); ++ if (err) { ++ scx_bpf_error("Failed to create DSQ %d", i); ++ return 0; ++ } ++ } ++ ++ bpf_for(i, 0, 1024) { ++ scx_bpf_destroy_dsq(i); ++ } ++ ++ return 0; ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops create_dsq_ops = { ++ .init_task = create_dsq_init_task, ++ .exit_task = create_dsq_exit_task, ++ .init = create_dsq_init, ++ .name = "create_dsq", ++}; +diff --git a/tools/testing/selftests/sched_ext/create_dsq.c b/tools/testing/selftests/sched_ext/create_dsq.c +new file mode 100644 +index 000000000000..fa946d9146d4 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/create_dsq.c +@@ -0,0 +1,57 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 David Vernet ++ */ ++#include ++#include ++#include ++#include ++#include "create_dsq.bpf.skel.h" ++#include "scx_test.h" ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct create_dsq *skel; ++ ++ skel = create_dsq__open_and_load(); ++ if (!skel) { ++ SCX_ERR("Failed to open and load skel"); ++ return SCX_TEST_FAIL; ++ } ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct create_dsq *skel = ctx; ++ struct bpf_link *link; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.create_dsq_ops); ++ if (!link) { ++ SCX_ERR("Failed to attach scheduler"); ++ return SCX_TEST_FAIL; ++ } ++ ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct create_dsq *skel = ctx; ++ ++ create_dsq__destroy(skel); ++} ++ ++struct scx_test create_dsq = { ++ .name = "create_dsq", ++ .description = "Create and destroy a dsq in a loop", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&create_dsq) +diff --git a/tools/testing/selftests/sched_ext/ddsp_bogus_dsq_fail.bpf.c b/tools/testing/selftests/sched_ext/ddsp_bogus_dsq_fail.bpf.c +new file mode 100644 +index 000000000000..e97ad41d354a +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/ddsp_bogus_dsq_fail.bpf.c +@@ -0,0 +1,42 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2024 David Vernet ++ * Copyright (c) 2024 Tejun Heo ++ */ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++UEI_DEFINE(uei); ++ ++s32 BPF_STRUCT_OPS(ddsp_bogus_dsq_fail_select_cpu, struct task_struct *p, ++ s32 prev_cpu, u64 wake_flags) ++{ ++ s32 cpu = scx_bpf_pick_idle_cpu(p->cpus_ptr, 0); ++ ++ if (cpu >= 0) { ++ /* ++ * If we dispatch to a bogus DSQ that will fall back to the ++ * builtin global DSQ, we fail gracefully. 
++ */ ++ scx_bpf_dispatch_vtime(p, 0xcafef00d, SCX_SLICE_DFL, ++ p->scx.dsq_vtime, 0); ++ return cpu; ++ } ++ ++ return prev_cpu; ++} ++ ++void BPF_STRUCT_OPS(ddsp_bogus_dsq_fail_exit, struct scx_exit_info *ei) ++{ ++ UEI_RECORD(uei, ei); ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops ddsp_bogus_dsq_fail_ops = { ++ .select_cpu = ddsp_bogus_dsq_fail_select_cpu, ++ .exit = ddsp_bogus_dsq_fail_exit, ++ .name = "ddsp_bogus_dsq_fail", ++ .timeout_ms = 1000U, ++}; +diff --git a/tools/testing/selftests/sched_ext/ddsp_bogus_dsq_fail.c b/tools/testing/selftests/sched_ext/ddsp_bogus_dsq_fail.c +new file mode 100644 +index 000000000000..e65d22f23f3b +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/ddsp_bogus_dsq_fail.c +@@ -0,0 +1,57 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2024 David Vernet ++ * Copyright (c) 2024 Tejun Heo ++ */ ++#include ++#include ++#include ++#include ++#include "ddsp_bogus_dsq_fail.bpf.skel.h" ++#include "scx_test.h" ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct ddsp_bogus_dsq_fail *skel; ++ ++ skel = ddsp_bogus_dsq_fail__open_and_load(); ++ SCX_FAIL_IF(!skel, "Failed to open and load skel"); ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct ddsp_bogus_dsq_fail *skel = ctx; ++ struct bpf_link *link; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.ddsp_bogus_dsq_fail_ops); ++ SCX_FAIL_IF(!link, "Failed to attach struct_ops"); ++ ++ sleep(1); ++ ++ SCX_EQ(skel->data->uei.kind, EXIT_KIND(SCX_EXIT_ERROR)); ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct ddsp_bogus_dsq_fail *skel = ctx; ++ ++ ddsp_bogus_dsq_fail__destroy(skel); ++} ++ ++struct scx_test ddsp_bogus_dsq_fail = { ++ .name = "ddsp_bogus_dsq_fail", ++ .description = "Verify we gracefully fail, and fall back to using a " ++ "built-in DSQ, if we do a 
direct dispatch to an invalid" ++ " DSQ in ops.select_cpu()", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&ddsp_bogus_dsq_fail) +diff --git a/tools/testing/selftests/sched_ext/ddsp_vtimelocal_fail.bpf.c b/tools/testing/selftests/sched_ext/ddsp_vtimelocal_fail.bpf.c +new file mode 100644 +index 000000000000..dde7e7dafbfb +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/ddsp_vtimelocal_fail.bpf.c +@@ -0,0 +1,39 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2024 David Vernet ++ * Copyright (c) 2024 Tejun Heo ++ */ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++UEI_DEFINE(uei); ++ ++s32 BPF_STRUCT_OPS(ddsp_vtimelocal_fail_select_cpu, struct task_struct *p, ++ s32 prev_cpu, u64 wake_flags) ++{ ++ s32 cpu = scx_bpf_pick_idle_cpu(p->cpus_ptr, 0); ++ ++ if (cpu >= 0) { ++ /* Shouldn't be allowed to vtime dispatch to a builtin DSQ. */ ++ scx_bpf_dispatch_vtime(p, SCX_DSQ_LOCAL, SCX_SLICE_DFL, ++ p->scx.dsq_vtime, 0); ++ return cpu; ++ } ++ ++ return prev_cpu; ++} ++ ++void BPF_STRUCT_OPS(ddsp_vtimelocal_fail_exit, struct scx_exit_info *ei) ++{ ++ UEI_RECORD(uei, ei); ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops ddsp_vtimelocal_fail_ops = { ++ .select_cpu = ddsp_vtimelocal_fail_select_cpu, ++ .exit = ddsp_vtimelocal_fail_exit, ++ .name = "ddsp_vtimelocal_fail", ++ .timeout_ms = 1000U, ++}; +diff --git a/tools/testing/selftests/sched_ext/ddsp_vtimelocal_fail.c b/tools/testing/selftests/sched_ext/ddsp_vtimelocal_fail.c +new file mode 100644 +index 000000000000..abafee587cd6 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/ddsp_vtimelocal_fail.c +@@ -0,0 +1,56 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 David Vernet ++ * Copyright (c) 2024 Tejun Heo ++ */ ++#include ++#include ++#include ++#include "ddsp_vtimelocal_fail.bpf.skel.h" ++#include "scx_test.h" ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct ddsp_vtimelocal_fail *skel; ++ ++ skel = ddsp_vtimelocal_fail__open_and_load(); ++ SCX_FAIL_IF(!skel, "Failed to open and load skel"); ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct ddsp_vtimelocal_fail *skel = ctx; ++ struct bpf_link *link; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.ddsp_vtimelocal_fail_ops); ++ SCX_FAIL_IF(!link, "Failed to attach struct_ops"); ++ ++ sleep(1); ++ ++ SCX_EQ(skel->data->uei.kind, EXIT_KIND(SCX_EXIT_ERROR)); ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct ddsp_vtimelocal_fail *skel = ctx; ++ ++ ddsp_vtimelocal_fail__destroy(skel); ++} ++ ++struct scx_test ddsp_vtimelocal_fail = { ++ .name = "ddsp_vtimelocal_fail", ++ .description = "Verify we gracefully fail, and fall back to using a " ++ "built-in DSQ, if we do a direct vtime dispatch to a " ++ "built-in DSQ from DSQ in ops.select_cpu()", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&ddsp_vtimelocal_fail) +diff --git a/tools/testing/selftests/sched_ext/dsp_local_on.bpf.c b/tools/testing/selftests/sched_ext/dsp_local_on.bpf.c +new file mode 100644 +index 000000000000..efb4672decb4 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/dsp_local_on.bpf.c +@@ -0,0 +1,65 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 David Vernet ++ */ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++const volatile s32 nr_cpus; ++ ++UEI_DEFINE(uei); ++ ++struct { ++ __uint(type, BPF_MAP_TYPE_QUEUE); ++ __uint(max_entries, 8192); ++ __type(value, s32); ++} queue SEC(".maps"); ++ ++s32 BPF_STRUCT_OPS(dsp_local_on_select_cpu, struct task_struct *p, ++ s32 prev_cpu, u64 wake_flags) ++{ ++ return prev_cpu; ++} ++ ++void BPF_STRUCT_OPS(dsp_local_on_enqueue, struct task_struct *p, ++ u64 enq_flags) ++{ ++ s32 pid = p->pid; ++ ++ if (bpf_map_push_elem(&queue, &pid, 0)) ++ scx_bpf_error("Failed to enqueue %s[%d]", p->comm, p->pid); ++} ++ ++void BPF_STRUCT_OPS(dsp_local_on_dispatch, s32 cpu, struct task_struct *prev) ++{ ++ s32 pid, target; ++ struct task_struct *p; ++ ++ if (bpf_map_pop_elem(&queue, &pid)) ++ return; ++ ++ p = bpf_task_from_pid(pid); ++ if (!p) ++ return; ++ ++ target = bpf_get_prandom_u32() % nr_cpus; ++ ++ scx_bpf_dispatch(p, SCX_DSQ_LOCAL_ON | target, SCX_SLICE_DFL, 0); ++ bpf_task_release(p); ++} ++ ++void BPF_STRUCT_OPS(dsp_local_on_exit, struct scx_exit_info *ei) ++{ ++ UEI_RECORD(uei, ei); ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops dsp_local_on_ops = { ++ .select_cpu = dsp_local_on_select_cpu, ++ .enqueue = dsp_local_on_enqueue, ++ .dispatch = dsp_local_on_dispatch, ++ .exit = dsp_local_on_exit, ++ .name = "dsp_local_on", ++ .timeout_ms = 1000U, ++}; +diff --git a/tools/testing/selftests/sched_ext/dsp_local_on.c b/tools/testing/selftests/sched_ext/dsp_local_on.c +new file mode 100644 +index 000000000000..472851b56854 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/dsp_local_on.c +@@ -0,0 +1,58 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 David Vernet ++ */ ++#include ++#include ++#include ++#include "dsp_local_on.bpf.skel.h" ++#include "scx_test.h" ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct dsp_local_on *skel; ++ ++ skel = dsp_local_on__open(); ++ SCX_FAIL_IF(!skel, "Failed to open"); ++ ++ skel->rodata->nr_cpus = libbpf_num_possible_cpus(); ++ SCX_FAIL_IF(dsp_local_on__load(skel), "Failed to load skel"); ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct dsp_local_on *skel = ctx; ++ struct bpf_link *link; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.dsp_local_on_ops); ++ SCX_FAIL_IF(!link, "Failed to attach struct_ops"); ++ ++ /* Just sleeping is fine, plenty of scheduling events happening */ ++ sleep(1); ++ ++ SCX_EQ(skel->data->uei.kind, EXIT_KIND(SCX_EXIT_ERROR)); ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct dsp_local_on *skel = ctx; ++ ++ dsp_local_on__destroy(skel); ++} ++ ++struct scx_test dsp_local_on = { ++ .name = "dsp_local_on", ++ .description = "Verify we can directly dispatch tasks to local DSQs " ++ "from ops.dispatch()", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&dsp_local_on) +diff --git a/tools/testing/selftests/sched_ext/enq_last_no_enq_fails.bpf.c b/tools/testing/selftests/sched_ext/enq_last_no_enq_fails.bpf.c +new file mode 100644 +index 000000000000..b0b99531d5d5 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/enq_last_no_enq_fails.bpf.c +@@ -0,0 +1,21 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A scheduler that sets SCX_OPS_ENQ_LAST without defining ops.enqueue(), which ++ * must cause the scheduler to be rejected. ++ * ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops enq_last_no_enq_fails_ops = { ++ .name = "enq_last_no_enq_fails", ++ /* Need to define ops.enqueue() with SCX_OPS_ENQ_LAST */ ++ .flags = SCX_OPS_ENQ_LAST, ++ .timeout_ms = 1000U, ++}; +diff --git a/tools/testing/selftests/sched_ext/enq_last_no_enq_fails.c b/tools/testing/selftests/sched_ext/enq_last_no_enq_fails.c +new file mode 100644 +index 000000000000..2a3eda5e2c0b +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/enq_last_no_enq_fails.c +@@ -0,0 +1,60 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++#include ++#include ++#include ++#include ++#include "enq_last_no_enq_fails.bpf.skel.h" ++#include "scx_test.h" ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct enq_last_no_enq_fails *skel; ++ ++ skel = enq_last_no_enq_fails__open_and_load(); ++ if (!skel) { ++ SCX_ERR("Failed to open and load skel"); ++ return SCX_TEST_FAIL; ++ } ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct enq_last_no_enq_fails *skel = ctx; ++ struct bpf_link *link; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.enq_last_no_enq_fails_ops); ++ if (link) { ++ SCX_ERR("Incorrectly succeeded in to attaching scheduler"); ++ return SCX_TEST_FAIL; ++ } ++ ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct enq_last_no_enq_fails *skel = ctx; ++ ++ enq_last_no_enq_fails__destroy(skel); ++} ++ ++struct scx_test enq_last_no_enq_fails = { ++ .name = "enq_last_no_enq_fails", ++ .description = "Verify we fail to load a scheduler if we specify " ++ "the SCX_OPS_ENQ_LAST flag without defining " ++ "ops.enqueue()", ++ .setup = setup, ++ .run = 
run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&enq_last_no_enq_fails) +diff --git a/tools/testing/selftests/sched_ext/enq_select_cpu_fails.bpf.c b/tools/testing/selftests/sched_ext/enq_select_cpu_fails.bpf.c +new file mode 100644 +index 000000000000..b3dfc1033cd6 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/enq_select_cpu_fails.bpf.c +@@ -0,0 +1,43 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++/* Manually specify the signature until the kfunc is added to the scx repo. */ ++s32 scx_bpf_select_cpu_dfl(struct task_struct *p, s32 prev_cpu, u64 wake_flags, ++ bool *found) __ksym; ++ ++s32 BPF_STRUCT_OPS(enq_select_cpu_fails_select_cpu, struct task_struct *p, ++ s32 prev_cpu, u64 wake_flags) ++{ ++ return prev_cpu; ++} ++ ++void BPF_STRUCT_OPS(enq_select_cpu_fails_enqueue, struct task_struct *p, ++ u64 enq_flags) ++{ ++ /* ++ * Need to initialize the variable or the verifier will fail to load. ++ * Improving these semantics is actively being worked on. ++ */ ++ bool found = false; ++ ++ /* Can only call from ops.select_cpu() */ ++ scx_bpf_select_cpu_dfl(p, 0, 0, &found); ++ ++ scx_bpf_dispatch(p, SCX_DSQ_GLOBAL, SCX_SLICE_DFL, enq_flags); ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops enq_select_cpu_fails_ops = { ++ .select_cpu = enq_select_cpu_fails_select_cpu, ++ .enqueue = enq_select_cpu_fails_enqueue, ++ .name = "enq_select_cpu_fails", ++ .timeout_ms = 1000U, ++}; +diff --git a/tools/testing/selftests/sched_ext/enq_select_cpu_fails.c b/tools/testing/selftests/sched_ext/enq_select_cpu_fails.c +new file mode 100644 +index 000000000000..dd1350e5f002 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/enq_select_cpu_fails.c +@@ -0,0 +1,61 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2023 Meta Platforms, Inc. 
and affiliates. ++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++#include ++#include ++#include ++#include ++#include "enq_select_cpu_fails.bpf.skel.h" ++#include "scx_test.h" ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct enq_select_cpu_fails *skel; ++ ++ skel = enq_select_cpu_fails__open_and_load(); ++ if (!skel) { ++ SCX_ERR("Failed to open and load skel"); ++ return SCX_TEST_FAIL; ++ } ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct enq_select_cpu_fails *skel = ctx; ++ struct bpf_link *link; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.enq_select_cpu_fails_ops); ++ if (!link) { ++ SCX_ERR("Failed to attach scheduler"); ++ return SCX_TEST_FAIL; ++ } ++ ++ sleep(1); ++ ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct enq_select_cpu_fails *skel = ctx; ++ ++ enq_select_cpu_fails__destroy(skel); ++} ++ ++struct scx_test enq_select_cpu_fails = { ++ .name = "enq_select_cpu_fails", ++ .description = "Verify we fail to call scx_bpf_select_cpu_dfl() " ++ "from ops.enqueue()", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&enq_select_cpu_fails) +diff --git a/tools/testing/selftests/sched_ext/exit.bpf.c b/tools/testing/selftests/sched_ext/exit.bpf.c +new file mode 100644 +index 000000000000..ae12ddaac921 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/exit.bpf.c +@@ -0,0 +1,84 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 David Vernet ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++#include "exit_test.h" ++ ++const volatile int exit_point; ++UEI_DEFINE(uei); ++ ++#define EXIT_CLEANLY() scx_bpf_exit(exit_point, "%d", exit_point) ++ ++s32 BPF_STRUCT_OPS(exit_select_cpu, struct task_struct *p, ++ s32 prev_cpu, u64 wake_flags) ++{ ++ bool found; ++ ++ if (exit_point == EXIT_SELECT_CPU) ++ EXIT_CLEANLY(); ++ ++ return scx_bpf_select_cpu_dfl(p, prev_cpu, wake_flags, &found); ++} ++ ++void BPF_STRUCT_OPS(exit_enqueue, struct task_struct *p, u64 enq_flags) ++{ ++ if (exit_point == EXIT_ENQUEUE) ++ EXIT_CLEANLY(); ++ ++ scx_bpf_dispatch(p, SCX_DSQ_GLOBAL, SCX_SLICE_DFL, enq_flags); ++} ++ ++void BPF_STRUCT_OPS(exit_dispatch, s32 cpu, struct task_struct *p) ++{ ++ if (exit_point == EXIT_DISPATCH) ++ EXIT_CLEANLY(); ++ ++ scx_bpf_consume(SCX_DSQ_GLOBAL); ++} ++ ++void BPF_STRUCT_OPS(exit_enable, struct task_struct *p) ++{ ++ if (exit_point == EXIT_ENABLE) ++ EXIT_CLEANLY(); ++} ++ ++s32 BPF_STRUCT_OPS(exit_init_task, struct task_struct *p, ++ struct scx_init_task_args *args) ++{ ++ if (exit_point == EXIT_INIT_TASK) ++ EXIT_CLEANLY(); ++ ++ return 0; ++} ++ ++void BPF_STRUCT_OPS(exit_exit, struct scx_exit_info *ei) ++{ ++ UEI_RECORD(uei, ei); ++} ++ ++s32 BPF_STRUCT_OPS_SLEEPABLE(exit_init) ++{ ++ if (exit_point == EXIT_INIT) ++ EXIT_CLEANLY(); ++ ++ return 0; ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops exit_ops = { ++ .select_cpu = exit_select_cpu, ++ .enqueue = exit_enqueue, ++ .dispatch = exit_dispatch, ++ .init_task = exit_init_task, ++ .enable = exit_enable, ++ .exit = exit_exit, ++ .init = exit_init, ++ .name = "exit", ++ .timeout_ms = 1000U, ++}; +diff --git a/tools/testing/selftests/sched_ext/exit.c b/tools/testing/selftests/sched_ext/exit.c +new file mode 100644 +index 000000000000..31bcd06e21cd +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/exit.c +@@ -0,0 +1,55 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * 
Copyright (c) 2024 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2024 David Vernet ++ */ ++#include ++#include ++#include ++#include ++#include ++#include "exit.bpf.skel.h" ++#include "scx_test.h" ++ ++#include "exit_test.h" ++ ++static enum scx_test_status run(void *ctx) ++{ ++ enum exit_test_case tc; ++ ++ for (tc = 0; tc < NUM_EXITS; tc++) { ++ struct exit *skel; ++ struct bpf_link *link; ++ char buf[16]; ++ ++ skel = exit__open(); ++ skel->rodata->exit_point = tc; ++ exit__load(skel); ++ link = bpf_map__attach_struct_ops(skel->maps.exit_ops); ++ if (!link) { ++ SCX_ERR("Failed to attach scheduler"); ++ exit__destroy(skel); ++ return SCX_TEST_FAIL; ++ } ++ ++ /* Assumes uei.kind is written last */ ++ while (skel->data->uei.kind == EXIT_KIND(SCX_EXIT_NONE)) ++ sched_yield(); ++ ++ SCX_EQ(skel->data->uei.kind, EXIT_KIND(SCX_EXIT_UNREG_BPF)); ++ SCX_EQ(skel->data->uei.exit_code, tc); ++ sprintf(buf, "%d", tc); ++ SCX_ASSERT(!strcmp(skel->data->uei.msg, buf)); ++ bpf_link__destroy(link); ++ exit__destroy(skel); ++ } ++ ++ return SCX_TEST_PASS; ++} ++ ++struct scx_test exit_test = { ++ .name = "exit", ++ .description = "Verify we can cleanly exit a scheduler in multiple places", ++ .run = run, ++}; ++REGISTER_SCX_TEST(&exit_test) +diff --git a/tools/testing/selftests/sched_ext/exit_test.h b/tools/testing/selftests/sched_ext/exit_test.h +new file mode 100644 +index 000000000000..94f0268b9cb8 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/exit_test.h +@@ -0,0 +1,20 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 David Vernet ++ */ ++ ++#ifndef __EXIT_TEST_H__ ++#define __EXIT_TEST_H__ ++ ++enum exit_test_case { ++ EXIT_SELECT_CPU, ++ EXIT_ENQUEUE, ++ EXIT_DISPATCH, ++ EXIT_ENABLE, ++ EXIT_INIT_TASK, ++ EXIT_INIT, ++ NUM_EXITS, ++}; ++ ++#endif // # __EXIT_TEST_H__ +diff --git a/tools/testing/selftests/sched_ext/hotplug.bpf.c b/tools/testing/selftests/sched_ext/hotplug.bpf.c +new file mode 100644 +index 000000000000..8f2601db39f3 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/hotplug.bpf.c +@@ -0,0 +1,61 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2024 David Vernet ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++#include "hotplug_test.h" ++ ++UEI_DEFINE(uei); ++ ++void BPF_STRUCT_OPS(hotplug_exit, struct scx_exit_info *ei) ++{ ++ UEI_RECORD(uei, ei); ++} ++ ++static void exit_from_hotplug(s32 cpu, bool onlining) ++{ ++ /* ++ * Ignored, just used to verify that we can invoke blocking kfuncs ++ * from the hotplug path. ++ */ ++ scx_bpf_create_dsq(0, -1); ++ ++ s64 code = SCX_ECODE_ACT_RESTART | HOTPLUG_EXIT_RSN; ++ ++ if (onlining) ++ code |= HOTPLUG_ONLINING; ++ ++ scx_bpf_exit(code, "hotplug event detected (%d going %s)", cpu, ++ onlining ? 
"online" : "offline"); ++} ++ ++void BPF_STRUCT_OPS_SLEEPABLE(hotplug_cpu_online, s32 cpu) ++{ ++ exit_from_hotplug(cpu, true); ++} ++ ++void BPF_STRUCT_OPS_SLEEPABLE(hotplug_cpu_offline, s32 cpu) ++{ ++ exit_from_hotplug(cpu, false); ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops hotplug_cb_ops = { ++ .cpu_online = hotplug_cpu_online, ++ .cpu_offline = hotplug_cpu_offline, ++ .exit = hotplug_exit, ++ .name = "hotplug_cbs", ++ .timeout_ms = 1000U, ++}; ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops hotplug_nocb_ops = { ++ .exit = hotplug_exit, ++ .name = "hotplug_nocbs", ++ .timeout_ms = 1000U, ++}; +diff --git a/tools/testing/selftests/sched_ext/hotplug.c b/tools/testing/selftests/sched_ext/hotplug.c +new file mode 100644 +index 000000000000..87bf220b1bce +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/hotplug.c +@@ -0,0 +1,168 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2024 David Vernet ++ */ ++#include ++#include ++#include ++#include ++#include ++#include ++ ++#include "hotplug_test.h" ++#include "hotplug.bpf.skel.h" ++#include "scx_test.h" ++#include "util.h" ++ ++const char *online_path = "/sys/devices/system/cpu/cpu1/online"; ++ ++static bool is_cpu_online(void) ++{ ++ return file_read_long(online_path) > 0; ++} ++ ++static void toggle_online_status(bool online) ++{ ++ long val = online ? 1 : 0; ++ int ret; ++ ++ ret = file_write_long(online_path, val); ++ if (ret != 0) ++ fprintf(stderr, "Failed to bring CPU %s (%s)", ++ online ? 
"online" : "offline", strerror(errno)); ++} ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ if (!is_cpu_online()) ++ return SCX_TEST_SKIP; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status test_hotplug(bool onlining, bool cbs_defined) ++{ ++ struct hotplug *skel; ++ struct bpf_link *link; ++ long kind, code; ++ ++ SCX_ASSERT(is_cpu_online()); ++ ++ skel = hotplug__open_and_load(); ++ SCX_ASSERT(skel); ++ ++ /* Testing the offline -> online path, so go offline before starting */ ++ if (onlining) ++ toggle_online_status(0); ++ ++ if (cbs_defined) { ++ kind = SCX_KIND_VAL(SCX_EXIT_UNREG_BPF); ++ code = SCX_ECODE_VAL(SCX_ECODE_ACT_RESTART) | HOTPLUG_EXIT_RSN; ++ if (onlining) ++ code |= HOTPLUG_ONLINING; ++ } else { ++ kind = SCX_KIND_VAL(SCX_EXIT_UNREG_KERN); ++ code = SCX_ECODE_VAL(SCX_ECODE_ACT_RESTART) | ++ SCX_ECODE_VAL(SCX_ECODE_RSN_HOTPLUG); ++ } ++ ++ if (cbs_defined) ++ link = bpf_map__attach_struct_ops(skel->maps.hotplug_cb_ops); ++ else ++ link = bpf_map__attach_struct_ops(skel->maps.hotplug_nocb_ops); ++ ++ if (!link) { ++ SCX_ERR("Failed to attach scheduler"); ++ hotplug__destroy(skel); ++ return SCX_TEST_FAIL; ++ } ++ ++ toggle_online_status(onlining ? 
1 : 0); ++ ++ while (!UEI_EXITED(skel, uei)) ++ sched_yield(); ++ ++ SCX_EQ(skel->data->uei.kind, kind); ++ SCX_EQ(UEI_REPORT(skel, uei), code); ++ ++ if (!onlining) ++ toggle_online_status(1); ++ ++ bpf_link__destroy(link); ++ hotplug__destroy(skel); ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status test_hotplug_attach(void) ++{ ++ struct hotplug *skel; ++ struct bpf_link *link; ++ enum scx_test_status status = SCX_TEST_PASS; ++ long kind, code; ++ ++ SCX_ASSERT(is_cpu_online()); ++ SCX_ASSERT(scx_hotplug_seq() > 0); ++ ++ skel = SCX_OPS_OPEN(hotplug_nocb_ops, hotplug); ++ SCX_ASSERT(skel); ++ ++ SCX_OPS_LOAD(skel, hotplug_nocb_ops, hotplug, uei); ++ ++ /* ++ * Take the CPU offline to increment the global hotplug seq, which ++ * should cause attach to fail due to us setting the hotplug seq above ++ */ ++ toggle_online_status(0); ++ link = bpf_map__attach_struct_ops(skel->maps.hotplug_nocb_ops); ++ ++ toggle_online_status(1); ++ ++ SCX_ASSERT(link); ++ while (!UEI_EXITED(skel, uei)) ++ sched_yield(); ++ ++ kind = SCX_KIND_VAL(SCX_EXIT_UNREG_KERN); ++ code = SCX_ECODE_VAL(SCX_ECODE_ACT_RESTART) | ++ SCX_ECODE_VAL(SCX_ECODE_RSN_HOTPLUG); ++ SCX_EQ(skel->data->uei.kind, kind); ++ SCX_EQ(UEI_REPORT(skel, uei), code); ++ ++ bpf_link__destroy(link); ++ hotplug__destroy(skel); ++ ++ return status; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ ++#define HP_TEST(__onlining, __cbs_defined) ({ \ ++ if (test_hotplug(__onlining, __cbs_defined) != SCX_TEST_PASS) \ ++ return SCX_TEST_FAIL; \ ++}) ++ ++ HP_TEST(true, true); ++ HP_TEST(false, true); ++ HP_TEST(true, false); ++ HP_TEST(false, false); ++ ++#undef HP_TEST ++ ++ return test_hotplug_attach(); ++} ++ ++static void cleanup(void *ctx) ++{ ++ toggle_online_status(1); ++} ++ ++struct scx_test hotplug_test = { ++ .name = "hotplug", ++ .description = "Verify hotplug behavior", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&hotplug_test) +diff --git 
a/tools/testing/selftests/sched_ext/hotplug_test.h b/tools/testing/selftests/sched_ext/hotplug_test.h +new file mode 100644 +index 000000000000..73d236f90787 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/hotplug_test.h +@@ -0,0 +1,15 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2024 David Vernet ++ */ ++ ++#ifndef __HOTPLUG_TEST_H__ ++#define __HOTPLUG_TEST_H__ ++ ++enum hotplug_test_flags { ++ HOTPLUG_EXIT_RSN = 1LLU << 0, ++ HOTPLUG_ONLINING = 1LLU << 1, ++}; ++ ++#endif // # __HOTPLUG_TEST_H__ +diff --git a/tools/testing/selftests/sched_ext/init_enable_count.bpf.c b/tools/testing/selftests/sched_ext/init_enable_count.bpf.c +new file mode 100644 +index 000000000000..47ea89a626c3 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/init_enable_count.bpf.c +@@ -0,0 +1,53 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A scheduler that verifies that we do proper counting of init, enable, etc ++ * callbacks. ++ * ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++u64 init_task_cnt, exit_task_cnt, enable_cnt, disable_cnt; ++u64 init_fork_cnt, init_transition_cnt; ++ ++s32 BPF_STRUCT_OPS_SLEEPABLE(cnt_init_task, struct task_struct *p, ++ struct scx_init_task_args *args) ++{ ++ __sync_fetch_and_add(&init_task_cnt, 1); ++ ++ if (args->fork) ++ __sync_fetch_and_add(&init_fork_cnt, 1); ++ else ++ __sync_fetch_and_add(&init_transition_cnt, 1); ++ ++ return 0; ++} ++ ++void BPF_STRUCT_OPS(cnt_exit_task, struct task_struct *p) ++{ ++ __sync_fetch_and_add(&exit_task_cnt, 1); ++} ++ ++void BPF_STRUCT_OPS(cnt_enable, struct task_struct *p) ++{ ++ __sync_fetch_and_add(&enable_cnt, 1); ++} ++ ++void BPF_STRUCT_OPS(cnt_disable, struct task_struct *p) ++{ ++ __sync_fetch_and_add(&disable_cnt, 1); ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops init_enable_count_ops = { ++ .init_task = cnt_init_task, ++ .exit_task = cnt_exit_task, ++ .enable = cnt_enable, ++ .disable = cnt_disable, ++ .name = "init_enable_count", ++}; +diff --git a/tools/testing/selftests/sched_ext/init_enable_count.c b/tools/testing/selftests/sched_ext/init_enable_count.c +new file mode 100644 +index 000000000000..97d45f1e5597 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/init_enable_count.c +@@ -0,0 +1,166 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++#include ++#include ++#include ++#include ++#include ++#include ++#include "scx_test.h" ++#include "init_enable_count.bpf.skel.h" ++ ++#define SCHED_EXT 7 ++ ++static struct init_enable_count * ++open_load_prog(bool global) ++{ ++ struct init_enable_count *skel; ++ ++ skel = init_enable_count__open(); ++ SCX_BUG_ON(!skel, "Failed to open skel"); ++ ++ if (!global) ++ skel->struct_ops.init_enable_count_ops->flags |= SCX_OPS_SWITCH_PARTIAL; ++ ++ SCX_BUG_ON(init_enable_count__load(skel), "Failed to load skel"); ++ ++ return skel; ++} ++ ++static enum scx_test_status run_test(bool global) ++{ ++ struct init_enable_count *skel; ++ struct bpf_link *link; ++ const u32 num_children = 5, num_pre_forks = 1024; ++ int ret, i, status; ++ struct sched_param param = {}; ++ pid_t pids[num_pre_forks]; ++ ++ skel = open_load_prog(global); ++ ++ /* ++ * Fork a bunch of children before we attach the scheduler so that we ++ * ensure (at least in practical terms) that there are more tasks that ++ * transition from SCHED_OTHER -> SCHED_EXT than there are tasks that ++ * take the fork() path either below or in other processes. 
++ */ ++ for (i = 0; i < num_pre_forks; i++) { ++ pids[i] = fork(); ++ SCX_FAIL_IF(pids[i] < 0, "Failed to fork child"); ++ if (pids[i] == 0) { ++ sleep(1); ++ exit(0); ++ } ++ } ++ ++ link = bpf_map__attach_struct_ops(skel->maps.init_enable_count_ops); ++ SCX_FAIL_IF(!link, "Failed to attach struct_ops"); ++ ++ for (i = 0; i < num_pre_forks; i++) { ++ SCX_FAIL_IF(waitpid(pids[i], &status, 0) != pids[i], ++ "Failed to wait for pre-forked child\n"); ++ ++ SCX_FAIL_IF(status != 0, "Pre-forked child %d exited with status %d\n", i, ++ status); ++ } ++ ++ bpf_link__destroy(link); ++ SCX_GE(skel->bss->init_task_cnt, num_pre_forks); ++ SCX_GE(skel->bss->exit_task_cnt, num_pre_forks); ++ ++ link = bpf_map__attach_struct_ops(skel->maps.init_enable_count_ops); ++ SCX_FAIL_IF(!link, "Failed to attach struct_ops"); ++ ++ /* SCHED_EXT children */ ++ for (i = 0; i < num_children; i++) { ++ pids[i] = fork(); ++ SCX_FAIL_IF(pids[i] < 0, "Failed to fork child"); ++ ++ if (pids[i] == 0) { ++ ret = sched_setscheduler(0, SCHED_EXT, &param); ++ SCX_BUG_ON(ret, "Failed to set sched to sched_ext"); ++ ++ /* ++ * Reset to SCHED_OTHER for half of them. Counts for ++ * everything should still be the same regardless, as ++ * ops.disable() is invoked even if a task is still on ++ * SCHED_EXT before it exits. 
++ */ ++ if (i % 2 == 0) { ++ ret = sched_setscheduler(0, SCHED_OTHER, &param); ++ SCX_BUG_ON(ret, "Failed to reset sched to normal"); ++ } ++ exit(0); ++ } ++ } ++ for (i = 0; i < num_children; i++) { ++ SCX_FAIL_IF(waitpid(pids[i], &status, 0) != pids[i], ++ "Failed to wait for SCX child\n"); ++ ++ SCX_FAIL_IF(status != 0, "SCX child %d exited with status %d\n", i, ++ status); ++ } ++ ++ /* SCHED_OTHER children */ ++ for (i = 0; i < num_children; i++) { ++ pids[i] = fork(); ++ if (pids[i] == 0) ++ exit(0); ++ } ++ ++ for (i = 0; i < num_children; i++) { ++ SCX_FAIL_IF(waitpid(pids[i], &status, 0) != pids[i], ++ "Failed to wait for normal child\n"); ++ ++ SCX_FAIL_IF(status != 0, "Normal child %d exited with status %d\n", i, ++ status); ++ } ++ ++ bpf_link__destroy(link); ++ ++ SCX_GE(skel->bss->init_task_cnt, 2 * num_children); ++ SCX_GE(skel->bss->exit_task_cnt, 2 * num_children); ++ ++ if (global) { ++ SCX_GE(skel->bss->enable_cnt, 2 * num_children); ++ SCX_GE(skel->bss->disable_cnt, 2 * num_children); ++ } else { ++ SCX_EQ(skel->bss->enable_cnt, num_children); ++ SCX_EQ(skel->bss->disable_cnt, num_children); ++ } ++ /* ++ * We forked a ton of tasks before we attached the scheduler above, so ++ * this should be fine. Technically it could be flaky if a ton of forks ++ * are happening at the same time in other processes, but that should ++ * be exceedingly unlikely. 
++ */ ++ SCX_GT(skel->bss->init_transition_cnt, skel->bss->init_fork_cnt); ++ SCX_GE(skel->bss->init_fork_cnt, 2 * num_children); ++ ++ init_enable_count__destroy(skel); ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ enum scx_test_status status; ++ ++ status = run_test(true); ++ if (status != SCX_TEST_PASS) ++ return status; ++ ++ return run_test(false); ++} ++ ++struct scx_test init_enable_count = { ++ .name = "init_enable_count", ++ .description = "Verify we do the correct amount of counting of init, " ++ "enable, etc callbacks.", ++ .run = run, ++}; ++REGISTER_SCX_TEST(&init_enable_count) +diff --git a/tools/testing/selftests/sched_ext/maximal.bpf.c b/tools/testing/selftests/sched_ext/maximal.bpf.c +new file mode 100644 +index 000000000000..44612fdaf399 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/maximal.bpf.c +@@ -0,0 +1,132 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A scheduler with every callback defined. ++ * ++ * This scheduler defines every callback. ++ * ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 David Vernet ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++s32 BPF_STRUCT_OPS(maximal_select_cpu, struct task_struct *p, s32 prev_cpu, ++ u64 wake_flags) ++{ ++ return prev_cpu; ++} ++ ++void BPF_STRUCT_OPS(maximal_enqueue, struct task_struct *p, u64 enq_flags) ++{ ++ scx_bpf_dispatch(p, SCX_DSQ_GLOBAL, SCX_SLICE_DFL, enq_flags); ++} ++ ++void BPF_STRUCT_OPS(maximal_dequeue, struct task_struct *p, u64 deq_flags) ++{} ++ ++void BPF_STRUCT_OPS(maximal_dispatch, s32 cpu, struct task_struct *prev) ++{ ++ scx_bpf_consume(SCX_DSQ_GLOBAL); ++} ++ ++void BPF_STRUCT_OPS(maximal_runnable, struct task_struct *p, u64 enq_flags) ++{} ++ ++void BPF_STRUCT_OPS(maximal_running, struct task_struct *p) ++{} ++ ++void BPF_STRUCT_OPS(maximal_stopping, struct task_struct *p, bool runnable) ++{} ++ ++void BPF_STRUCT_OPS(maximal_quiescent, struct task_struct *p, u64 deq_flags) ++{} ++ ++bool BPF_STRUCT_OPS(maximal_yield, struct task_struct *from, ++ struct task_struct *to) ++{ ++ return false; ++} ++ ++bool BPF_STRUCT_OPS(maximal_core_sched_before, struct task_struct *a, ++ struct task_struct *b) ++{ ++ return false; ++} ++ ++void BPF_STRUCT_OPS(maximal_set_weight, struct task_struct *p, u32 weight) ++{} ++ ++void BPF_STRUCT_OPS(maximal_set_cpumask, struct task_struct *p, ++ const struct cpumask *cpumask) ++{} ++ ++void BPF_STRUCT_OPS(maximal_update_idle, s32 cpu, bool idle) ++{} ++ ++void BPF_STRUCT_OPS(maximal_cpu_acquire, s32 cpu, ++ struct scx_cpu_acquire_args *args) ++{} ++ ++void BPF_STRUCT_OPS(maximal_cpu_release, s32 cpu, ++ struct scx_cpu_release_args *args) ++{} ++ ++void BPF_STRUCT_OPS(maximal_cpu_online, s32 cpu) ++{} ++ ++void BPF_STRUCT_OPS(maximal_cpu_offline, s32 cpu) ++{} ++ ++s32 BPF_STRUCT_OPS(maximal_init_task, struct task_struct *p, ++ struct scx_init_task_args *args) ++{ ++ return 0; ++} ++ ++void BPF_STRUCT_OPS(maximal_enable, struct task_struct *p) ++{} ++ ++void BPF_STRUCT_OPS(maximal_exit_task, struct task_struct 
*p, ++ struct scx_exit_task_args *args) ++{} ++ ++void BPF_STRUCT_OPS(maximal_disable, struct task_struct *p) ++{} ++ ++s32 BPF_STRUCT_OPS_SLEEPABLE(maximal_init) ++{ ++ return 0; ++} ++ ++void BPF_STRUCT_OPS(maximal_exit, struct scx_exit_info *info) ++{} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops maximal_ops = { ++ .select_cpu = maximal_select_cpu, ++ .enqueue = maximal_enqueue, ++ .dequeue = maximal_dequeue, ++ .dispatch = maximal_dispatch, ++ .runnable = maximal_runnable, ++ .running = maximal_running, ++ .stopping = maximal_stopping, ++ .quiescent = maximal_quiescent, ++ .yield = maximal_yield, ++ .core_sched_before = maximal_core_sched_before, ++ .set_weight = maximal_set_weight, ++ .set_cpumask = maximal_set_cpumask, ++ .update_idle = maximal_update_idle, ++ .cpu_acquire = maximal_cpu_acquire, ++ .cpu_release = maximal_cpu_release, ++ .cpu_online = maximal_cpu_online, ++ .cpu_offline = maximal_cpu_offline, ++ .init_task = maximal_init_task, ++ .enable = maximal_enable, ++ .exit_task = maximal_exit_task, ++ .disable = maximal_disable, ++ .init = maximal_init, ++ .exit = maximal_exit, ++ .name = "maximal", ++}; +diff --git a/tools/testing/selftests/sched_ext/maximal.c b/tools/testing/selftests/sched_ext/maximal.c +new file mode 100644 +index 000000000000..f38fc973c380 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/maximal.c +@@ -0,0 +1,51 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 David Vernet ++ */ ++#include ++#include ++#include ++#include ++#include "maximal.bpf.skel.h" ++#include "scx_test.h" ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct maximal *skel; ++ ++ skel = maximal__open_and_load(); ++ SCX_FAIL_IF(!skel, "Failed to open and load skel"); ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct maximal *skel = ctx; ++ struct bpf_link *link; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.maximal_ops); ++ SCX_FAIL_IF(!link, "Failed to attach scheduler"); ++ ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct maximal *skel = ctx; ++ ++ maximal__destroy(skel); ++} ++ ++struct scx_test maximal = { ++ .name = "maximal", ++ .description = "Verify we can load a scheduler with every callback defined", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&maximal) +diff --git a/tools/testing/selftests/sched_ext/maybe_null.bpf.c b/tools/testing/selftests/sched_ext/maybe_null.bpf.c +new file mode 100644 +index 000000000000..27d0f386acfb +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/maybe_null.bpf.c +@@ -0,0 +1,36 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++u64 vtime_test; ++ ++void BPF_STRUCT_OPS(maybe_null_running, struct task_struct *p) ++{} ++ ++void BPF_STRUCT_OPS(maybe_null_success_dispatch, s32 cpu, struct task_struct *p) ++{ ++ if (p != NULL) ++ vtime_test = p->scx.dsq_vtime; ++} ++ ++bool BPF_STRUCT_OPS(maybe_null_success_yield, struct task_struct *from, ++ struct task_struct *to) ++{ ++ if (to) ++ bpf_printk("Yielding to %s[%d]", to->comm, to->pid); ++ ++ return false; ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops maybe_null_success = { ++ .dispatch = maybe_null_success_dispatch, ++ .yield = maybe_null_success_yield, ++ .enable = maybe_null_running, ++ .name = "minimal", ++}; +diff --git a/tools/testing/selftests/sched_ext/maybe_null.c b/tools/testing/selftests/sched_ext/maybe_null.c +new file mode 100644 +index 000000000000..31cfafb0cf65 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/maybe_null.c +@@ -0,0 +1,49 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ */ ++#include ++#include ++#include ++#include ++#include "maybe_null.bpf.skel.h" ++#include "maybe_null_fail_dsp.bpf.skel.h" ++#include "maybe_null_fail_yld.bpf.skel.h" ++#include "scx_test.h" ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct maybe_null *skel; ++ struct maybe_null_fail_dsp *fail_dsp; ++ struct maybe_null_fail_yld *fail_yld; ++ ++ skel = maybe_null__open_and_load(); ++ if (!skel) { ++ SCX_ERR("Failed to open and load maybe_null skel"); ++ return SCX_TEST_FAIL; ++ } ++ maybe_null__destroy(skel); ++ ++ fail_dsp = maybe_null_fail_dsp__open_and_load(); ++ if (fail_dsp) { ++ maybe_null_fail_dsp__destroy(fail_dsp); ++ SCX_ERR("Should failed to open and load maybe_null_fail_dsp skel"); ++ return SCX_TEST_FAIL; ++ } ++ ++ fail_yld = maybe_null_fail_yld__open_and_load(); ++ if (fail_yld) { ++ maybe_null_fail_yld__destroy(fail_yld); ++ SCX_ERR("Should failed to open and load maybe_null_fail_yld skel"); ++ return SCX_TEST_FAIL; ++ } ++ ++ return SCX_TEST_PASS; ++} ++ ++struct scx_test maybe_null = { ++ .name = "maybe_null", ++ .description = "Verify if PTR_MAYBE_NULL work for .dispatch", ++ .run = run, ++}; ++REGISTER_SCX_TEST(&maybe_null) +diff --git a/tools/testing/selftests/sched_ext/maybe_null_fail_dsp.bpf.c b/tools/testing/selftests/sched_ext/maybe_null_fail_dsp.bpf.c +new file mode 100644 +index 000000000000..c0641050271d +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/maybe_null_fail_dsp.bpf.c +@@ -0,0 +1,25 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++u64 vtime_test; ++ ++void BPF_STRUCT_OPS(maybe_null_running, struct task_struct *p) ++{} ++ ++void BPF_STRUCT_OPS(maybe_null_fail_dispatch, s32 cpu, struct task_struct *p) ++{ ++ vtime_test = p->scx.dsq_vtime; ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops maybe_null_fail = { ++ .dispatch = maybe_null_fail_dispatch, ++ .enable = maybe_null_running, ++ .name = "maybe_null_fail_dispatch", ++}; +diff --git a/tools/testing/selftests/sched_ext/maybe_null_fail_yld.bpf.c b/tools/testing/selftests/sched_ext/maybe_null_fail_yld.bpf.c +new file mode 100644 +index 000000000000..3c1740028e3b +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/maybe_null_fail_yld.bpf.c +@@ -0,0 +1,28 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++u64 vtime_test; ++ ++void BPF_STRUCT_OPS(maybe_null_running, struct task_struct *p) ++{} ++ ++bool BPF_STRUCT_OPS(maybe_null_fail_yield, struct task_struct *from, ++ struct task_struct *to) ++{ ++ bpf_printk("Yielding to %s[%d]", to->comm, to->pid); ++ ++ return false; ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops maybe_null_fail = { ++ .yield = maybe_null_fail_yield, ++ .enable = maybe_null_running, ++ .name = "maybe_null_fail_yield", ++}; +diff --git a/tools/testing/selftests/sched_ext/minimal.bpf.c b/tools/testing/selftests/sched_ext/minimal.bpf.c +new file mode 100644 +index 000000000000..6a7eccef0104 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/minimal.bpf.c +@@ -0,0 +1,21 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A completely minimal scheduler. ++ * ++ * This scheduler defines the absolute minimal set of struct sched_ext_ops ++ * fields: its name. It should _not_ fail to be loaded, and can be used to ++ * exercise the default scheduling paths in ext.c. ++ * ++ * Copyright (c) 2023 Meta Platforms, Inc. 
and affiliates. ++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops minimal_ops = { ++ .name = "minimal", ++}; +diff --git a/tools/testing/selftests/sched_ext/minimal.c b/tools/testing/selftests/sched_ext/minimal.c +new file mode 100644 +index 000000000000..6c5db8ebbf8a +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/minimal.c +@@ -0,0 +1,58 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++#include ++#include ++#include ++#include ++#include "minimal.bpf.skel.h" ++#include "scx_test.h" ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct minimal *skel; ++ ++ skel = minimal__open_and_load(); ++ if (!skel) { ++ SCX_ERR("Failed to open and load skel"); ++ return SCX_TEST_FAIL; ++ } ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct minimal *skel = ctx; ++ struct bpf_link *link; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.minimal_ops); ++ if (!link) { ++ SCX_ERR("Failed to attach scheduler"); ++ return SCX_TEST_FAIL; ++ } ++ ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct minimal *skel = ctx; ++ ++ minimal__destroy(skel); ++} ++ ++struct scx_test minimal = { ++ .name = "minimal", ++ .description = "Verify we can load a fully minimal scheduler", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&minimal) +diff --git a/tools/testing/selftests/sched_ext/prog_run.bpf.c b/tools/testing/selftests/sched_ext/prog_run.bpf.c +new file mode 100644 +index 000000000000..6a4d7c48e3f2 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/prog_run.bpf.c +@@ -0,0 +1,33 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A scheduler 
that validates that we can invoke sched_ext kfuncs in ++ * BPF_PROG_TYPE_SYSCALL programs. ++ * ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2024 David Vernet ++ */ ++ ++#include ++ ++UEI_DEFINE(uei); ++ ++char _license[] SEC("license") = "GPL"; ++ ++SEC("syscall") ++int BPF_PROG(prog_run_syscall) ++{ ++ scx_bpf_create_dsq(0, -1); ++ scx_bpf_exit(0xdeadbeef, "Exited from PROG_RUN"); ++ return 0; ++} ++ ++void BPF_STRUCT_OPS(prog_run_exit, struct scx_exit_info *ei) ++{ ++ UEI_RECORD(uei, ei); ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops prog_run_ops = { ++ .exit = prog_run_exit, ++ .name = "prog_run", ++}; +diff --git a/tools/testing/selftests/sched_ext/prog_run.c b/tools/testing/selftests/sched_ext/prog_run.c +new file mode 100644 +index 000000000000..3cd57ef8daaa +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/prog_run.c +@@ -0,0 +1,78 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 David Vernet ++ */ ++#include ++#include ++#include ++#include ++#include ++#include "prog_run.bpf.skel.h" ++#include "scx_test.h" ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct prog_run *skel; ++ ++ skel = prog_run__open_and_load(); ++ if (!skel) { ++ SCX_ERR("Failed to open and load skel"); ++ return SCX_TEST_FAIL; ++ } ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct prog_run *skel = ctx; ++ struct bpf_link *link; ++ int prog_fd, err = 0; ++ ++ prog_fd = bpf_program__fd(skel->progs.prog_run_syscall); ++ if (prog_fd < 0) { ++ SCX_ERR("Failed to get BPF_PROG_RUN prog"); ++ return SCX_TEST_FAIL; ++ } ++ ++ LIBBPF_OPTS(bpf_test_run_opts, topts); ++ ++ link = bpf_map__attach_struct_ops(skel->maps.prog_run_ops); ++ if (!link) { ++ SCX_ERR("Failed to attach scheduler"); ++ close(prog_fd); ++ return SCX_TEST_FAIL; ++ } ++ ++ err = bpf_prog_test_run_opts(prog_fd, &topts); ++ SCX_EQ(err, 0); ++ ++ /* Assumes uei.kind is written last */ ++ while (skel->data->uei.kind == EXIT_KIND(SCX_EXIT_NONE)) ++ sched_yield(); ++ ++ SCX_EQ(skel->data->uei.kind, EXIT_KIND(SCX_EXIT_UNREG_BPF)); ++ SCX_EQ(skel->data->uei.exit_code, 0xdeadbeef); ++ close(prog_fd); ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct prog_run *skel = ctx; ++ ++ prog_run__destroy(skel); ++} ++ ++struct scx_test prog_run = { ++ .name = "prog_run", ++ .description = "Verify we can call into a scheduler with BPF_PROG_RUN, and invoke kfuncs", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&prog_run) +diff --git a/tools/testing/selftests/sched_ext/reload_loop.c b/tools/testing/selftests/sched_ext/reload_loop.c +new file mode 100644 +index 000000000000..5cfba2d6e056 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/reload_loop.c +@@ -0,0 +1,75 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 
Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2024 David Vernet ++ */ ++#include ++#include ++#include ++#include ++#include ++#include "maximal.bpf.skel.h" ++#include "scx_test.h" ++ ++static struct maximal *skel; ++static pthread_t threads[2]; ++ ++bool force_exit = false; ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ skel = maximal__open_and_load(); ++ if (!skel) { ++ SCX_ERR("Failed to open and load skel"); ++ return SCX_TEST_FAIL; ++ } ++ ++ return SCX_TEST_PASS; ++} ++ ++static void *do_reload_loop(void *arg) ++{ ++ u32 i; ++ ++ for (i = 0; i < 1024 && !force_exit; i++) { ++ struct bpf_link *link; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.maximal_ops); ++ if (link) ++ bpf_link__destroy(link); ++ } ++ ++ return NULL; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ int err; ++ void *ret; ++ ++ err = pthread_create(&threads[0], NULL, do_reload_loop, NULL); ++ SCX_FAIL_IF(err, "Failed to create thread 0"); ++ ++ err = pthread_create(&threads[1], NULL, do_reload_loop, NULL); ++ SCX_FAIL_IF(err, "Failed to create thread 1"); ++ ++ SCX_FAIL_IF(pthread_join(threads[0], &ret), "thread 0 failed"); ++ SCX_FAIL_IF(pthread_join(threads[1], &ret), "thread 1 failed"); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ force_exit = true; ++ maximal__destroy(skel); ++} ++ ++struct scx_test reload_loop = { ++ .name = "reload_loop", ++ .description = "Stress test loading and unloading schedulers repeatedly in a tight loop", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&reload_loop) +diff --git a/tools/testing/selftests/sched_ext/runner.c b/tools/testing/selftests/sched_ext/runner.c +new file mode 100644 +index 000000000000..eab48c7ff309 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/runner.c +@@ -0,0 +1,201 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 David Vernet ++ * Copyright (c) 2024 Tejun Heo ++ */ ++#include ++#include ++#include ++#include ++#include ++#include "scx_test.h" ++ ++const char help_fmt[] = ++"The runner for sched_ext tests.\n" ++"\n" ++"The runner is statically linked against all testcases, and runs them all serially.\n" ++"It's required for the testcases to be serial, as only a single host-wide sched_ext\n" ++"scheduler may be loaded at any given time." ++"\n" ++"Usage: %s [-t TEST] [-h]\n" ++"\n" ++" -t TEST Only run tests whose name includes this string\n" ++" -s Include print output for skipped tests\n" ++" -q Don't print the test descriptions during run\n" ++" -h Display this help and exit\n"; ++ ++static volatile int exit_req; ++static bool quiet, print_skipped; ++ ++#define MAX_SCX_TESTS 2048 ++ ++static struct scx_test __scx_tests[MAX_SCX_TESTS]; ++static unsigned __scx_num_tests = 0; ++ ++static void sigint_handler(int simple) ++{ ++ exit_req = 1; ++} ++ ++static void print_test_preamble(const struct scx_test *test, bool quiet) ++{ ++ printf("===== START =====\n"); ++ printf("TEST: %s\n", test->name); ++ if (!quiet) ++ printf("DESCRIPTION: %s\n", test->description); ++ printf("OUTPUT:\n"); ++} ++ ++static const char *status_to_result(enum scx_test_status status) ++{ ++ switch (status) { ++ case SCX_TEST_PASS: ++ case SCX_TEST_SKIP: ++ return "ok"; ++ case SCX_TEST_FAIL: ++ return "not ok"; ++ default: ++ return ""; ++ } ++} ++ ++static void print_test_result(const struct scx_test *test, ++ enum scx_test_status status, ++ unsigned int testnum) ++{ ++ const char *result = status_to_result(status); ++ const char *directive = status == SCX_TEST_SKIP ? 
"SKIP " : ""; ++ ++ printf("%s %u %s # %s\n", result, testnum, test->name, directive); ++ printf("===== END =====\n"); ++} ++ ++static bool should_skip_test(const struct scx_test *test, const char * filter) ++{ ++ return !strstr(test->name, filter); ++} ++ ++static enum scx_test_status run_test(const struct scx_test *test) ++{ ++ enum scx_test_status status; ++ void *context = NULL; ++ ++ if (test->setup) { ++ status = test->setup(&context); ++ if (status != SCX_TEST_PASS) ++ return status; ++ } ++ ++ status = test->run(context); ++ ++ if (test->cleanup) ++ test->cleanup(context); ++ ++ return status; ++} ++ ++static bool test_valid(const struct scx_test *test) ++{ ++ if (!test) { ++ fprintf(stderr, "NULL test detected\n"); ++ return false; ++ } ++ ++ if (!test->name) { ++ fprintf(stderr, ++ "Test with no name found. Must specify test name.\n"); ++ return false; ++ } ++ ++ if (!test->description) { ++ fprintf(stderr, "Test %s requires description.\n", test->name); ++ return false; ++ } ++ ++ if (!test->run) { ++ fprintf(stderr, "Test %s has no run() callback\n", test->name); ++ return false; ++ } ++ ++ return true; ++} ++ ++int main(int argc, char **argv) ++{ ++ const char *filter = NULL; ++ unsigned testnum = 0, i; ++ unsigned passed = 0, skipped = 0, failed = 0; ++ int opt; ++ ++ signal(SIGINT, sigint_handler); ++ signal(SIGTERM, sigint_handler); ++ ++ libbpf_set_strict_mode(LIBBPF_STRICT_ALL); ++ ++ while ((opt = getopt(argc, argv, "qst:h")) != -1) { ++ switch (opt) { ++ case 'q': ++ quiet = true; ++ break; ++ case 's': ++ print_skipped = true; ++ break; ++ case 't': ++ filter = optarg; ++ break; ++ default: ++ fprintf(stderr, help_fmt, basename(argv[0])); ++ return opt != 'h'; ++ } ++ } ++ ++ for (i = 0; i < __scx_num_tests; i++) { ++ enum scx_test_status status; ++ struct scx_test *test = &__scx_tests[i]; ++ ++ if (filter && should_skip_test(test, filter)) { ++ /* ++ * Printing the skipped tests and their preambles can ++ * add a lot of noise to the runner 
output. Printing ++ * this is only really useful for CI, so let's skip it ++ * by default. ++ */ ++ if (print_skipped) { ++ print_test_preamble(test, quiet); ++ print_test_result(test, SCX_TEST_SKIP, ++testnum); ++ } ++ continue; ++ } ++ ++ print_test_preamble(test, quiet); ++ status = run_test(test); ++ print_test_result(test, status, ++testnum); ++ switch (status) { ++ case SCX_TEST_PASS: ++ passed++; ++ break; ++ case SCX_TEST_SKIP: ++ skipped++; ++ break; ++ case SCX_TEST_FAIL: ++ failed++; ++ break; ++ } ++ } ++ printf("\n\n=============================\n\n"); ++ printf("RESULTS:\n\n"); ++ printf("PASSED: %u\n", passed); ++ printf("SKIPPED: %u\n", skipped); ++ printf("FAILED: %u\n", failed); ++ ++ return 0; ++} ++ ++void scx_test_register(struct scx_test *test) ++{ ++ SCX_BUG_ON(!test_valid(test), "Invalid test found"); ++ SCX_BUG_ON(__scx_num_tests >= MAX_SCX_TESTS, "Maximum tests exceeded"); ++ ++ __scx_tests[__scx_num_tests++] = *test; ++} +diff --git a/tools/testing/selftests/sched_ext/scx_test.h b/tools/testing/selftests/sched_ext/scx_test.h +new file mode 100644 +index 000000000000..90b8d6915bb7 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/scx_test.h +@@ -0,0 +1,131 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2023 Tejun Heo ++ * Copyright (c) 2023 David Vernet ++ */ ++ ++#ifndef __SCX_TEST_H__ ++#define __SCX_TEST_H__ ++ ++#include ++#include ++#include ++ ++enum scx_test_status { ++ SCX_TEST_PASS = 0, ++ SCX_TEST_SKIP, ++ SCX_TEST_FAIL, ++}; ++ ++#define EXIT_KIND(__ent) __COMPAT_ENUM_OR_ZERO("scx_exit_kind", #__ent) ++ ++struct scx_test { ++ /** ++ * name - The name of the testcase. ++ */ ++ const char *name; ++ ++ /** ++ * description - A description of your testcase: what it tests and is ++ * meant to validate. ++ */ ++ const char *description; ++ ++ /* ++ * setup - Setup the test. 
++ * @ctx: A pointer to a context object that will be passed to run and ++ * cleanup. ++ * ++ * An optional callback that allows a testcase to perform setup for its ++ * run. A test may return SCX_TEST_SKIP to skip the run. ++ */ ++ enum scx_test_status (*setup)(void **ctx); ++ ++ /* ++ * run - Run the test. ++ * @ctx: Context set in the setup() callback. If @ctx was not set in ++ * setup(), it is NULL. ++ * ++ * The main test. Callers should return one of: ++ * ++ * - SCX_TEST_PASS: Test passed ++ * - SCX_TEST_SKIP: Test should be skipped ++ * - SCX_TEST_FAIL: Test failed ++ * ++ * This callback must be defined. ++ */ ++ enum scx_test_status (*run)(void *ctx); ++ ++ /* ++ * cleanup - Perform cleanup following the test ++ * @ctx: Context set in the setup() callback. If @ctx was not set in ++ * setup(), it is NULL. ++ * ++ * An optional callback that allows a test to perform cleanup after ++ * being run. This callback is run even if the run() callback returns ++ * SCX_TEST_SKIP or SCX_TEST_FAIL. It is not run if setup() returns ++ * SCX_TEST_SKIP or SCX_TEST_FAIL. ++ */ ++ void (*cleanup)(void *ctx); ++}; ++ ++void scx_test_register(struct scx_test *test); ++ ++#define REGISTER_SCX_TEST(__test) \ ++ __attribute__((constructor)) \ ++ static void ___scxregister##__LINE__(void) \ ++ { \ ++ scx_test_register(__test); \ ++ } ++ ++#define SCX_ERR(__fmt, ...) \ ++ do { \ ++ fprintf(stderr, "ERR: %s:%d\n", __FILE__, __LINE__); \ ++ fprintf(stderr, __fmt"\n", ##__VA_ARGS__); \ ++ } while (0) ++ ++#define SCX_FAIL(__fmt, ...) \ ++ do { \ ++ SCX_ERR(__fmt, ##__VA_ARGS__); \ ++ return SCX_TEST_FAIL; \ ++ } while (0) ++ ++#define SCX_FAIL_IF(__cond, __fmt, ...) 
\ ++ do { \ ++ if (__cond) \ ++ SCX_FAIL(__fmt, ##__VA_ARGS__); \ ++ } while (0) ++ ++#define SCX_GT(_x, _y) SCX_FAIL_IF((_x) <= (_y), "Expected %s > %s (%lu > %lu)", \ ++ #_x, #_y, (u64)(_x), (u64)(_y)) ++#define SCX_GE(_x, _y) SCX_FAIL_IF((_x) < (_y), "Expected %s >= %s (%lu >= %lu)", \ ++ #_x, #_y, (u64)(_x), (u64)(_y)) ++#define SCX_LT(_x, _y) SCX_FAIL_IF((_x) >= (_y), "Expected %s < %s (%lu < %lu)", \ ++ #_x, #_y, (u64)(_x), (u64)(_y)) ++#define SCX_LE(_x, _y) SCX_FAIL_IF((_x) > (_y), "Expected %s <= %s (%lu <= %lu)", \ ++ #_x, #_y, (u64)(_x), (u64)(_y)) ++#define SCX_EQ(_x, _y) SCX_FAIL_IF((_x) != (_y), "Expected %s == %s (%lu == %lu)", \ ++ #_x, #_y, (u64)(_x), (u64)(_y)) ++#define SCX_ASSERT(_x) SCX_FAIL_IF(!(_x), "Expected %s to be true (%lu)", \ ++ #_x, (u64)(_x)) ++ ++#define SCX_ECODE_VAL(__ecode) ({ \ ++ u64 __val = 0; \ ++ bool __found = false; \ ++ \ ++ __found = __COMPAT_read_enum("scx_exit_code", #__ecode, &__val); \ ++ SCX_ASSERT(__found); \ ++ (s64)__val; \ ++}) ++ ++#define SCX_KIND_VAL(__kind) ({ \ ++ u64 __val = 0; \ ++ bool __found = false; \ ++ \ ++ __found = __COMPAT_read_enum("scx_exit_kind", #__kind, &__val); \ ++ SCX_ASSERT(__found); \ ++ __val; \ ++}) ++ ++#endif // # __SCX_TEST_H__ +diff --git a/tools/testing/selftests/sched_ext/select_cpu_dfl.bpf.c b/tools/testing/selftests/sched_ext/select_cpu_dfl.bpf.c +new file mode 100644 +index 000000000000..2ed2991afafe +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/select_cpu_dfl.bpf.c +@@ -0,0 +1,40 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A scheduler that validates the behavior of direct dispatching with a default ++ * select_cpu implementation. ++ * ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++bool saw_local = false; ++ ++static bool task_is_test(const struct task_struct *p) ++{ ++ return !bpf_strncmp(p->comm, 9, "select_cpu"); ++} ++ ++void BPF_STRUCT_OPS(select_cpu_dfl_enqueue, struct task_struct *p, ++ u64 enq_flags) ++{ ++ const struct cpumask *idle_mask = scx_bpf_get_idle_cpumask(); ++ ++ if (task_is_test(p) && ++ bpf_cpumask_test_cpu(scx_bpf_task_cpu(p), idle_mask)) { ++ saw_local = true; ++ } ++ scx_bpf_put_idle_cpumask(idle_mask); ++ ++ scx_bpf_dispatch(p, SCX_DSQ_GLOBAL, SCX_SLICE_DFL, enq_flags); ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops select_cpu_dfl_ops = { ++ .enqueue = select_cpu_dfl_enqueue, ++ .name = "select_cpu_dfl", ++}; +diff --git a/tools/testing/selftests/sched_ext/select_cpu_dfl.c b/tools/testing/selftests/sched_ext/select_cpu_dfl.c +new file mode 100644 +index 000000000000..a53a40c2d2f0 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/select_cpu_dfl.c +@@ -0,0 +1,72 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++#include ++#include ++#include ++#include ++#include "select_cpu_dfl.bpf.skel.h" ++#include "scx_test.h" ++ ++#define NUM_CHILDREN 1028 ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct select_cpu_dfl *skel; ++ ++ skel = select_cpu_dfl__open_and_load(); ++ SCX_FAIL_IF(!skel, "Failed to open and load skel"); ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct select_cpu_dfl *skel = ctx; ++ struct bpf_link *link; ++ pid_t pids[NUM_CHILDREN]; ++ int i, status; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.select_cpu_dfl_ops); ++ SCX_FAIL_IF(!link, "Failed to attach scheduler"); ++ ++ for (i = 0; i < NUM_CHILDREN; i++) { ++ pids[i] = fork(); ++ if (pids[i] == 0) { ++ sleep(1); ++ exit(0); ++ } ++ } ++ ++ for (i = 0; i < NUM_CHILDREN; i++) { ++ SCX_EQ(waitpid(pids[i], &status, 0), pids[i]); ++ SCX_EQ(status, 0); ++ } ++ ++ SCX_ASSERT(!skel->bss->saw_local); ++ ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct select_cpu_dfl *skel = ctx; ++ ++ select_cpu_dfl__destroy(skel); ++} ++ ++struct scx_test select_cpu_dfl = { ++ .name = "select_cpu_dfl", ++ .description = "Verify the default ops.select_cpu() dispatches tasks " ++ "when idles cores are found, and skips ops.enqueue()", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&select_cpu_dfl) +diff --git a/tools/testing/selftests/sched_ext/select_cpu_dfl_nodispatch.bpf.c b/tools/testing/selftests/sched_ext/select_cpu_dfl_nodispatch.bpf.c +new file mode 100644 +index 000000000000..4bb5abb2d369 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/select_cpu_dfl_nodispatch.bpf.c +@@ -0,0 +1,89 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A scheduler that validates the behavior of direct dispatching with a default ++ * select_cpu implementation, and with the 
SCX_OPS_ENQ_DFL_NO_DISPATCH ops flag ++ * specified. ++ * ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++bool saw_local = false; ++ ++/* Per-task scheduling context */ ++struct task_ctx { ++ bool force_local; /* CPU changed by ops.select_cpu() */ ++}; ++ ++struct { ++ __uint(type, BPF_MAP_TYPE_TASK_STORAGE); ++ __uint(map_flags, BPF_F_NO_PREALLOC); ++ __type(key, int); ++ __type(value, struct task_ctx); ++} task_ctx_stor SEC(".maps"); ++ ++/* Manually specify the signature until the kfunc is added to the scx repo. */ ++s32 scx_bpf_select_cpu_dfl(struct task_struct *p, s32 prev_cpu, u64 wake_flags, ++ bool *found) __ksym; ++ ++s32 BPF_STRUCT_OPS(select_cpu_dfl_nodispatch_select_cpu, struct task_struct *p, ++ s32 prev_cpu, u64 wake_flags) ++{ ++ struct task_ctx *tctx; ++ s32 cpu; ++ ++ tctx = bpf_task_storage_get(&task_ctx_stor, p, 0, 0); ++ if (!tctx) { ++ scx_bpf_error("task_ctx lookup failed"); ++ return -ESRCH; ++ } ++ ++ cpu = scx_bpf_select_cpu_dfl(p, prev_cpu, wake_flags, ++ &tctx->force_local); ++ ++ return cpu; ++} ++ ++void BPF_STRUCT_OPS(select_cpu_dfl_nodispatch_enqueue, struct task_struct *p, ++ u64 enq_flags) ++{ ++ u64 dsq_id = SCX_DSQ_GLOBAL; ++ struct task_ctx *tctx; ++ ++ tctx = bpf_task_storage_get(&task_ctx_stor, p, 0, 0); ++ if (!tctx) { ++ scx_bpf_error("task_ctx lookup failed"); ++ return; ++ } ++ ++ if (tctx->force_local) { ++ dsq_id = SCX_DSQ_LOCAL; ++ tctx->force_local = false; ++ saw_local = true; ++ } ++ ++ scx_bpf_dispatch(p, dsq_id, SCX_SLICE_DFL, enq_flags); ++} ++ ++s32 BPF_STRUCT_OPS(select_cpu_dfl_nodispatch_init_task, ++ struct task_struct *p, struct scx_init_task_args *args) ++{ ++ if (bpf_task_storage_get(&task_ctx_stor, p, 0, ++ BPF_LOCAL_STORAGE_GET_F_CREATE)) ++ return 0; ++ else ++ return -ENOMEM; ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops 
select_cpu_dfl_nodispatch_ops = { ++ .select_cpu = select_cpu_dfl_nodispatch_select_cpu, ++ .enqueue = select_cpu_dfl_nodispatch_enqueue, ++ .init_task = select_cpu_dfl_nodispatch_init_task, ++ .name = "select_cpu_dfl_nodispatch", ++}; +diff --git a/tools/testing/selftests/sched_ext/select_cpu_dfl_nodispatch.c b/tools/testing/selftests/sched_ext/select_cpu_dfl_nodispatch.c +new file mode 100644 +index 000000000000..1d85bf4bf3a3 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/select_cpu_dfl_nodispatch.c +@@ -0,0 +1,72 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++#include ++#include ++#include ++#include ++#include "select_cpu_dfl_nodispatch.bpf.skel.h" ++#include "scx_test.h" ++ ++#define NUM_CHILDREN 1028 ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct select_cpu_dfl_nodispatch *skel; ++ ++ skel = select_cpu_dfl_nodispatch__open_and_load(); ++ SCX_FAIL_IF(!skel, "Failed to open and load skel"); ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct select_cpu_dfl_nodispatch *skel = ctx; ++ struct bpf_link *link; ++ pid_t pids[NUM_CHILDREN]; ++ int i, status; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.select_cpu_dfl_nodispatch_ops); ++ SCX_FAIL_IF(!link, "Failed to attach scheduler"); ++ ++ for (i = 0; i < NUM_CHILDREN; i++) { ++ pids[i] = fork(); ++ if (pids[i] == 0) { ++ sleep(1); ++ exit(0); ++ } ++ } ++ ++ for (i = 0; i < NUM_CHILDREN; i++) { ++ SCX_EQ(waitpid(pids[i], &status, 0), pids[i]); ++ SCX_EQ(status, 0); ++ } ++ ++ SCX_ASSERT(skel->bss->saw_local); ++ ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct select_cpu_dfl_nodispatch *skel = ctx; ++ ++ select_cpu_dfl_nodispatch__destroy(skel); ++} ++ ++struct scx_test select_cpu_dfl_nodispatch = { ++ .name = 
"select_cpu_dfl_nodispatch", ++ .description = "Verify behavior of scx_bpf_select_cpu_dfl() in " ++ "ops.select_cpu()", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&select_cpu_dfl_nodispatch) +diff --git a/tools/testing/selftests/sched_ext/select_cpu_dispatch.bpf.c b/tools/testing/selftests/sched_ext/select_cpu_dispatch.bpf.c +new file mode 100644 +index 000000000000..f0b96a4a04b2 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/select_cpu_dispatch.bpf.c +@@ -0,0 +1,41 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A scheduler that validates the behavior of direct dispatching with a default ++ * select_cpu implementation. ++ * ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++s32 BPF_STRUCT_OPS(select_cpu_dispatch_select_cpu, struct task_struct *p, ++ s32 prev_cpu, u64 wake_flags) ++{ ++ u64 dsq_id = SCX_DSQ_LOCAL; ++ s32 cpu = prev_cpu; ++ ++ if (scx_bpf_test_and_clear_cpu_idle(cpu)) ++ goto dispatch; ++ ++ cpu = scx_bpf_pick_idle_cpu(p->cpus_ptr, 0); ++ if (cpu >= 0) ++ goto dispatch; ++ ++ dsq_id = SCX_DSQ_GLOBAL; ++ cpu = prev_cpu; ++ ++dispatch: ++ scx_bpf_dispatch(p, dsq_id, SCX_SLICE_DFL, 0); ++ return cpu; ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops select_cpu_dispatch_ops = { ++ .select_cpu = select_cpu_dispatch_select_cpu, ++ .name = "select_cpu_dispatch", ++ .timeout_ms = 1000U, ++}; +diff --git a/tools/testing/selftests/sched_ext/select_cpu_dispatch.c b/tools/testing/selftests/sched_ext/select_cpu_dispatch.c +new file mode 100644 +index 000000000000..0309ca8785b3 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/select_cpu_dispatch.c +@@ -0,0 +1,70 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++#include ++#include ++#include ++#include ++#include "select_cpu_dispatch.bpf.skel.h" ++#include "scx_test.h" ++ ++#define NUM_CHILDREN 1028 ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct select_cpu_dispatch *skel; ++ ++ skel = select_cpu_dispatch__open_and_load(); ++ SCX_FAIL_IF(!skel, "Failed to open and load skel"); ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct select_cpu_dispatch *skel = ctx; ++ struct bpf_link *link; ++ pid_t pids[NUM_CHILDREN]; ++ int i, status; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.select_cpu_dispatch_ops); ++ SCX_FAIL_IF(!link, "Failed to attach scheduler"); ++ ++ for (i = 0; i < NUM_CHILDREN; i++) { ++ pids[i] = fork(); ++ if (pids[i] == 0) { ++ sleep(1); ++ exit(0); ++ } ++ } ++ ++ for (i = 0; i < NUM_CHILDREN; i++) { ++ SCX_EQ(waitpid(pids[i], &status, 0), pids[i]); ++ SCX_EQ(status, 0); ++ } ++ ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct select_cpu_dispatch *skel = ctx; ++ ++ select_cpu_dispatch__destroy(skel); ++} ++ ++struct scx_test select_cpu_dispatch = { ++ .name = "select_cpu_dispatch", ++ .description = "Test direct dispatching to built-in DSQs from " ++ "ops.select_cpu()", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&select_cpu_dispatch) +diff --git a/tools/testing/selftests/sched_ext/select_cpu_dispatch_bad_dsq.bpf.c b/tools/testing/selftests/sched_ext/select_cpu_dispatch_bad_dsq.bpf.c +new file mode 100644 +index 000000000000..7b42ddce0f56 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/select_cpu_dispatch_bad_dsq.bpf.c +@@ -0,0 +1,37 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A scheduler that validates the behavior of direct dispatching with a default ++ * select_cpu implementation. ++ * ++ * Copyright (c) 2023 Meta Platforms, Inc. 
and affiliates. ++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++UEI_DEFINE(uei); ++ ++s32 BPF_STRUCT_OPS(select_cpu_dispatch_bad_dsq_select_cpu, struct task_struct *p, ++ s32 prev_cpu, u64 wake_flags) ++{ ++ /* Dispatching to a random DSQ should fail. */ ++ scx_bpf_dispatch(p, 0xcafef00d, SCX_SLICE_DFL, 0); ++ ++ return prev_cpu; ++} ++ ++void BPF_STRUCT_OPS(select_cpu_dispatch_bad_dsq_exit, struct scx_exit_info *ei) ++{ ++ UEI_RECORD(uei, ei); ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops select_cpu_dispatch_bad_dsq_ops = { ++ .select_cpu = select_cpu_dispatch_bad_dsq_select_cpu, ++ .exit = select_cpu_dispatch_bad_dsq_exit, ++ .name = "select_cpu_dispatch_bad_dsq", ++ .timeout_ms = 1000U, ++}; +diff --git a/tools/testing/selftests/sched_ext/select_cpu_dispatch_bad_dsq.c b/tools/testing/selftests/sched_ext/select_cpu_dispatch_bad_dsq.c +new file mode 100644 +index 000000000000..47eb6ed7627d +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/select_cpu_dispatch_bad_dsq.c +@@ -0,0 +1,56 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++#include ++#include ++#include ++#include ++#include "select_cpu_dispatch_bad_dsq.bpf.skel.h" ++#include "scx_test.h" ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct select_cpu_dispatch_bad_dsq *skel; ++ ++ skel = select_cpu_dispatch_bad_dsq__open_and_load(); ++ SCX_FAIL_IF(!skel, "Failed to open and load skel"); ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct select_cpu_dispatch_bad_dsq *skel = ctx; ++ struct bpf_link *link; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.select_cpu_dispatch_bad_dsq_ops); ++ SCX_FAIL_IF(!link, "Failed to attach scheduler"); ++ ++ sleep(1); ++ ++ SCX_EQ(skel->data->uei.kind, EXIT_KIND(SCX_EXIT_ERROR)); ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct select_cpu_dispatch_bad_dsq *skel = ctx; ++ ++ select_cpu_dispatch_bad_dsq__destroy(skel); ++} ++ ++struct scx_test select_cpu_dispatch_bad_dsq = { ++ .name = "select_cpu_dispatch_bad_dsq", ++ .description = "Verify graceful failure if we direct-dispatch to a " ++ "bogus DSQ in ops.select_cpu()", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&select_cpu_dispatch_bad_dsq) +diff --git a/tools/testing/selftests/sched_ext/select_cpu_dispatch_dbl_dsp.bpf.c b/tools/testing/selftests/sched_ext/select_cpu_dispatch_dbl_dsp.bpf.c +new file mode 100644 +index 000000000000..653e3dc0b4dc +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/select_cpu_dispatch_dbl_dsp.bpf.c +@@ -0,0 +1,38 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A scheduler that validates the behavior of direct dispatching with a default ++ * select_cpu implementation. ++ * ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++UEI_DEFINE(uei); ++ ++s32 BPF_STRUCT_OPS(select_cpu_dispatch_dbl_dsp_select_cpu, struct task_struct *p, ++ s32 prev_cpu, u64 wake_flags) ++{ ++ /* Dispatching twice in a row is disallowed. */ ++ scx_bpf_dispatch(p, SCX_DSQ_GLOBAL, SCX_SLICE_DFL, 0); ++ scx_bpf_dispatch(p, SCX_DSQ_GLOBAL, SCX_SLICE_DFL, 0); ++ ++ return prev_cpu; ++} ++ ++void BPF_STRUCT_OPS(select_cpu_dispatch_dbl_dsp_exit, struct scx_exit_info *ei) ++{ ++ UEI_RECORD(uei, ei); ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops select_cpu_dispatch_dbl_dsp_ops = { ++ .select_cpu = select_cpu_dispatch_dbl_dsp_select_cpu, ++ .exit = select_cpu_dispatch_dbl_dsp_exit, ++ .name = "select_cpu_dispatch_dbl_dsp", ++ .timeout_ms = 1000U, ++}; +diff --git a/tools/testing/selftests/sched_ext/select_cpu_dispatch_dbl_dsp.c b/tools/testing/selftests/sched_ext/select_cpu_dispatch_dbl_dsp.c +new file mode 100644 +index 000000000000..48ff028a3c46 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/select_cpu_dispatch_dbl_dsp.c +@@ -0,0 +1,56 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2023 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2023 David Vernet ++ * Copyright (c) 2023 Tejun Heo ++ */ ++#include ++#include ++#include ++#include ++#include "select_cpu_dispatch_dbl_dsp.bpf.skel.h" ++#include "scx_test.h" ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct select_cpu_dispatch_dbl_dsp *skel; ++ ++ skel = select_cpu_dispatch_dbl_dsp__open_and_load(); ++ SCX_FAIL_IF(!skel, "Failed to open and load skel"); ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct select_cpu_dispatch_dbl_dsp *skel = ctx; ++ struct bpf_link *link; ++ ++ link = bpf_map__attach_struct_ops(skel->maps.select_cpu_dispatch_dbl_dsp_ops); ++ SCX_FAIL_IF(!link, "Failed to attach scheduler"); ++ ++ sleep(1); ++ ++ SCX_EQ(skel->data->uei.kind, EXIT_KIND(SCX_EXIT_ERROR)); ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct select_cpu_dispatch_dbl_dsp *skel = ctx; ++ ++ select_cpu_dispatch_dbl_dsp__destroy(skel); ++} ++ ++struct scx_test select_cpu_dispatch_dbl_dsp = { ++ .name = "select_cpu_dispatch_dbl_dsp", ++ .description = "Verify graceful failure if we dispatch twice to a " ++ "DSQ in ops.select_cpu()", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&select_cpu_dispatch_dbl_dsp) +diff --git a/tools/testing/selftests/sched_ext/select_cpu_vtime.bpf.c b/tools/testing/selftests/sched_ext/select_cpu_vtime.bpf.c +new file mode 100644 +index 000000000000..7f3ebf4fc2ea +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/select_cpu_vtime.bpf.c +@@ -0,0 +1,92 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * A scheduler that validates that enqueue flags are properly stored and ++ * applied at dispatch time when a task is directly dispatched from ++ * ops.select_cpu(). We validate this by using scx_bpf_dispatch_vtime(), and ++ * making the test a very basic vtime scheduler. ++ * ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 David Vernet ++ * Copyright (c) 2024 Tejun Heo ++ */ ++ ++#include ++ ++char _license[] SEC("license") = "GPL"; ++ ++volatile bool consumed; ++ ++static u64 vtime_now; ++ ++#define VTIME_DSQ 0 ++ ++static inline bool vtime_before(u64 a, u64 b) ++{ ++ return (s64)(a - b) < 0; ++} ++ ++static inline u64 task_vtime(const struct task_struct *p) ++{ ++ u64 vtime = p->scx.dsq_vtime; ++ ++ if (vtime_before(vtime, vtime_now - SCX_SLICE_DFL)) ++ return vtime_now - SCX_SLICE_DFL; ++ else ++ return vtime; ++} ++ ++s32 BPF_STRUCT_OPS(select_cpu_vtime_select_cpu, struct task_struct *p, ++ s32 prev_cpu, u64 wake_flags) ++{ ++ s32 cpu; ++ ++ cpu = scx_bpf_pick_idle_cpu(p->cpus_ptr, 0); ++ if (cpu >= 0) ++ goto ddsp; ++ ++ cpu = prev_cpu; ++ scx_bpf_test_and_clear_cpu_idle(cpu); ++ddsp: ++ scx_bpf_dispatch_vtime(p, VTIME_DSQ, SCX_SLICE_DFL, task_vtime(p), 0); ++ return cpu; ++} ++ ++void BPF_STRUCT_OPS(select_cpu_vtime_dispatch, s32 cpu, struct task_struct *p) ++{ ++ if (scx_bpf_consume(VTIME_DSQ)) ++ consumed = true; ++} ++ ++void BPF_STRUCT_OPS(select_cpu_vtime_running, struct task_struct *p) ++{ ++ if (vtime_before(vtime_now, p->scx.dsq_vtime)) ++ vtime_now = p->scx.dsq_vtime; ++} ++ ++void BPF_STRUCT_OPS(select_cpu_vtime_stopping, struct task_struct *p, ++ bool runnable) ++{ ++ p->scx.dsq_vtime += (SCX_SLICE_DFL - p->scx.slice) * 100 / p->scx.weight; ++} ++ ++void BPF_STRUCT_OPS(select_cpu_vtime_enable, struct task_struct *p) ++{ ++ p->scx.dsq_vtime = vtime_now; ++} ++ ++s32 BPF_STRUCT_OPS_SLEEPABLE(select_cpu_vtime_init) ++{ ++ return scx_bpf_create_dsq(VTIME_DSQ, -1); ++} ++ ++SEC(".struct_ops.link") ++struct sched_ext_ops select_cpu_vtime_ops = { ++ .select_cpu = select_cpu_vtime_select_cpu, ++ .dispatch = select_cpu_vtime_dispatch, ++ .running = select_cpu_vtime_running, ++ .stopping = select_cpu_vtime_stopping, ++ .enable = select_cpu_vtime_enable, ++ .init = select_cpu_vtime_init, ++ .name = "select_cpu_vtime", ++ .timeout_ms = 1000U, ++}; +diff 
--git a/tools/testing/selftests/sched_ext/select_cpu_vtime.c b/tools/testing/selftests/sched_ext/select_cpu_vtime.c +new file mode 100644 +index 000000000000..b4629c2364f5 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/select_cpu_vtime.c +@@ -0,0 +1,59 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2024 David Vernet ++ * Copyright (c) 2024 Tejun Heo ++ */ ++#include ++#include ++#include ++#include ++#include "select_cpu_vtime.bpf.skel.h" ++#include "scx_test.h" ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ struct select_cpu_vtime *skel; ++ ++ skel = select_cpu_vtime__open_and_load(); ++ SCX_FAIL_IF(!skel, "Failed to open and load skel"); ++ *ctx = skel; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ struct select_cpu_vtime *skel = ctx; ++ struct bpf_link *link; ++ ++ SCX_ASSERT(!skel->bss->consumed); ++ ++ link = bpf_map__attach_struct_ops(skel->maps.select_cpu_vtime_ops); ++ SCX_FAIL_IF(!link, "Failed to attach scheduler"); ++ ++ sleep(1); ++ ++ SCX_ASSERT(skel->bss->consumed); ++ ++ bpf_link__destroy(link); ++ ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup(void *ctx) ++{ ++ struct select_cpu_vtime *skel = ctx; ++ ++ select_cpu_vtime__destroy(skel); ++} ++ ++struct scx_test select_cpu_vtime = { ++ .name = "select_cpu_vtime", ++ .description = "Test doing direct vtime-dispatching from " ++ "ops.select_cpu(), to a non-built-in DSQ", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&select_cpu_vtime) +diff --git a/tools/testing/selftests/sched_ext/test_example.c b/tools/testing/selftests/sched_ext/test_example.c +new file mode 100644 +index 000000000000..ce36cdf03cdc +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/test_example.c +@@ -0,0 +1,49 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. 
++ * Copyright (c) 2024 Tejun Heo ++ * Copyright (c) 2024 David Vernet ++ */ ++#include ++#include ++#include "scx_test.h" ++ ++static bool setup_called = false; ++static bool run_called = false; ++static bool cleanup_called = false; ++ ++static int context = 10; ++ ++static enum scx_test_status setup(void **ctx) ++{ ++ setup_called = true; ++ *ctx = &context; ++ ++ return SCX_TEST_PASS; ++} ++ ++static enum scx_test_status run(void *ctx) ++{ ++ int *arg = ctx; ++ ++ SCX_ASSERT(setup_called); ++ SCX_ASSERT(!run_called && !cleanup_called); ++ SCX_EQ(*arg, context); ++ ++ run_called = true; ++ return SCX_TEST_PASS; ++} ++ ++static void cleanup (void *ctx) ++{ ++ SCX_BUG_ON(!run_called || cleanup_called, "Wrong callbacks invoked"); ++} ++ ++struct scx_test example = { ++ .name = "example", ++ .description = "Validate the basic function of the test suite itself", ++ .setup = setup, ++ .run = run, ++ .cleanup = cleanup, ++}; ++REGISTER_SCX_TEST(&example) +diff --git a/tools/testing/selftests/sched_ext/util.c b/tools/testing/selftests/sched_ext/util.c +new file mode 100644 +index 000000000000..e47769c91918 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/util.c +@@ -0,0 +1,71 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. ++ * Copyright (c) 2024 David Vernet ++ */ ++#include ++#include ++#include ++#include ++#include ++#include ++ ++/* Returns read len on success, or -errno on failure. */ ++static ssize_t read_text(const char *path, char *buf, size_t max_len) ++{ ++ ssize_t len; ++ int fd; ++ ++ fd = open(path, O_RDONLY); ++ if (fd < 0) ++ return -errno; ++ ++ len = read(fd, buf, max_len - 1); ++ ++ if (len >= 0) ++ buf[len] = 0; ++ ++ close(fd); ++ return len < 0 ? -errno : len; ++} ++ ++/* Returns written len on success, or -errno on failure. 
*/ ++static ssize_t write_text(const char *path, char *buf, ssize_t len) ++{ ++ int fd; ++ ssize_t written; ++ ++ fd = open(path, O_WRONLY | O_APPEND); ++ if (fd < 0) ++ return -errno; ++ ++ written = write(fd, buf, len); ++ close(fd); ++ return written < 0 ? -errno : written; ++} ++ ++long file_read_long(const char *path) ++{ ++ char buf[128]; ++ ++ ++ if (read_text(path, buf, sizeof(buf)) <= 0) ++ return -1; ++ ++ return atol(buf); ++} ++ ++int file_write_long(const char *path, long val) ++{ ++ char buf[64]; ++ int ret; ++ ++ ret = sprintf(buf, "%ld", val); /* val is a signed long */ ++ if (ret < 0) ++ return ret; ++ ++ if (write_text(path, buf, ret) <= 0) /* write only the formatted text, not all of buf */ ++ return -1; ++ ++ return 0; ++} +diff --git a/tools/testing/selftests/sched_ext/util.h b/tools/testing/selftests/sched_ext/util.h +new file mode 100644 +index 000000000000..bc13dfec1267 +--- /dev/null ++++ b/tools/testing/selftests/sched_ext/util.h +@@ -0,0 +1,13 @@ ++/* SPDX-License-Identifier: GPL-2.0 */ ++/* ++ * Copyright (c) 2024 Meta Platforms, Inc. and affiliates.
++ * Copyright (c) 2024 David Vernet ++ */ ++ ++#ifndef __SCX_TEST_UTIL_H__ ++#define __SCX_TEST_UTIL_H__ ++ ++long file_read_long(const char *path); ++int file_write_long(const char *path, long val); ++ ++#endif // __SCX_TEST_UTIL_H__ +-- +2.46.0.rc1 diff --git a/patches/0003-bore-cachy-ext.patch b/patches/0003-bore-cachy-ext.patch new file mode 100644 index 0000000..a8ffb44 --- /dev/null +++ b/patches/0003-bore-cachy-ext.patch @@ -0,0 +1,990 @@ +From 18455546cde636eec5bf35cd09962c6ab2938a88 Mon Sep 17 00:00:00 2001 +From: Piotr Gorski +Date: Mon, 12 Aug 2024 13:38:33 +0200 +Subject: [PATCH] bore-cachy-ext + +Signed-off-by: Piotr Gorski +--- + include/linux/sched.h | 10 ++ + init/Kconfig | 17 ++ + kernel/Kconfig.hz | 16 ++ + kernel/sched/core.c | 143 +++++++++++++++ + kernel/sched/debug.c | 60 ++++++- + kernel/sched/fair.c | 375 +++++++++++++++++++++++++++++++++++++--- + kernel/sched/features.h | 20 ++- + kernel/sched/sched.h | 7 + + 8 files changed, 621 insertions(+), 27 deletions(-) + +diff --git a/include/linux/sched.h b/include/linux/sched.h +index 0f3a107bc..247f28536 100644 +--- a/include/linux/sched.h ++++ b/include/linux/sched.h +@@ -549,6 +549,16 @@ struct sched_entity { + u64 sum_exec_runtime; + u64 prev_sum_exec_runtime; + u64 vruntime; ++#ifdef CONFIG_SCHED_BORE ++ u64 burst_time; ++ u8 prev_burst_penalty; ++ u8 curr_burst_penalty; ++ u8 burst_penalty; ++ u8 burst_score; ++ u8 child_burst; ++ u32 child_burst_cnt; ++ u64 child_burst_last_cached; ++#endif // CONFIG_SCHED_BORE + s64 vlag; + u64 slice; + +diff --git a/init/Kconfig b/init/Kconfig +index bfc033b53..2f7a9c534 100644 +--- a/init/Kconfig ++++ b/init/Kconfig +@@ -1303,6 +1303,23 @@ config CHECKPOINT_RESTORE + + If unsure, say N here. + ++config SCHED_BORE ++ bool "Burst-Oriented Response Enhancer" ++ default y ++ help ++ In Desktop and Mobile computing, one might prefer interactive ++ tasks to keep responsive no matter what they run in the background.
++ ++ Enabling this kernel feature modifies the scheduler to discriminate ++ tasks by their burst time (runtime since it last went sleeping or ++ yielding state) and prioritize those that run less bursty. ++ Such tasks usually include window compositor, widgets backend, ++ terminal emulator, video playback, games and so on. ++ With a little impact to scheduling fairness, it may improve ++ responsiveness especially under heavy background workload. ++ ++ If unsure, say Y here. ++ + config SCHED_AUTOGROUP + bool "Automatic process group scheduling" + select CGROUPS +diff --git a/kernel/Kconfig.hz b/kernel/Kconfig.hz +index 0f78364ef..b50189ee5 100644 +--- a/kernel/Kconfig.hz ++++ b/kernel/Kconfig.hz +@@ -79,5 +79,21 @@ config HZ + default 750 if HZ_750 + default 1000 if HZ_1000 + ++config MIN_BASE_SLICE_NS ++ int "Default value for min_base_slice_ns" ++ default 2000000 ++ help ++ The BORE Scheduler automatically calculates the optimal base ++ slice for the configured HZ using the following equation: ++ ++ base_slice_ns = max(min_base_slice_ns, 1000000000/HZ) ++ ++ This option sets the default lower bound limit of the base slice ++ to prevent the loss of task throughput due to overscheduling. ++ ++ Setting this value too high can cause the system to boot with ++ an unnecessarily large base slice, resulting in high scheduling ++ latency and poor system responsiveness. 
++ + config SCHED_HRTICK + def_bool HIGH_RES_TIMERS +diff --git a/kernel/sched/core.c b/kernel/sched/core.c +index fb6276f74..95ead87fa 100644 +--- a/kernel/sched/core.c ++++ b/kernel/sched/core.c +@@ -4543,6 +4543,138 @@ int wake_up_state(struct task_struct *p, unsigned int state) + return try_to_wake_up(p, state, 0); + } + ++#ifdef CONFIG_SCHED_BORE ++extern u8 sched_burst_fork_atavistic; ++extern uint sched_burst_cache_lifetime; ++ ++static void __init sched_init_bore(void) { ++ init_task.se.burst_time = 0; ++ init_task.se.prev_burst_penalty = 0; ++ init_task.se.curr_burst_penalty = 0; ++ init_task.se.burst_penalty = 0; ++ init_task.se.burst_score = 0; ++ init_task.se.child_burst_last_cached = 0; ++} ++ ++inline void sched_fork_bore(struct task_struct *p) { ++ p->se.burst_time = 0; ++ p->se.curr_burst_penalty = 0; ++ p->se.burst_score = 0; ++ p->se.child_burst_last_cached = 0; ++} ++ ++static u32 count_child_tasks(struct task_struct *p) { ++ struct task_struct *child; ++ u32 cnt = 0; ++ list_for_each_entry(child, &p->children, sibling) {cnt++;} ++ return cnt; ++} ++ ++static inline bool task_is_inheritable(struct task_struct *p) { ++ return (p->sched_class == &fair_sched_class); ++} ++ ++static inline bool child_burst_cache_expired(struct task_struct *p, u64 now) { ++ u64 expiration_time = ++ p->se.child_burst_last_cached + sched_burst_cache_lifetime; ++ return ((s64)(expiration_time - now) < 0); ++} ++ ++static void __update_child_burst_cache( ++ struct task_struct *p, u32 cnt, u32 sum, u64 now) { ++ u8 avg = 0; ++ if (cnt) avg = sum / cnt; ++ p->se.child_burst = max(avg, p->se.burst_penalty); ++ p->se.child_burst_cnt = cnt; ++ p->se.child_burst_last_cached = now; ++} ++ ++static inline void update_child_burst_direct(struct task_struct *p, u64 now) { ++ struct task_struct *child; ++ u32 cnt = 0; ++ u32 sum = 0; ++ ++ list_for_each_entry(child, &p->children, sibling) { ++ if (!task_is_inheritable(child)) continue; ++ cnt++; ++ sum += child->se.burst_penalty; ++ 
} ++ ++ __update_child_burst_cache(p, cnt, sum, now); ++} ++ ++static inline u8 __inherit_burst_direct(struct task_struct *p, u64 now) { ++ struct task_struct *parent = p->real_parent; ++ if (child_burst_cache_expired(parent, now)) ++ update_child_burst_direct(parent, now); ++ ++ return parent->se.child_burst; ++} ++ ++static void update_child_burst_topological( ++ struct task_struct *p, u64 now, u32 depth, u32 *acnt, u32 *asum) { ++ struct task_struct *child, *dec; ++ u32 cnt = 0, dcnt = 0; ++ u32 sum = 0; ++ ++ list_for_each_entry(child, &p->children, sibling) { ++ dec = child; ++ while ((dcnt = count_child_tasks(dec)) == 1) ++ dec = list_first_entry(&dec->children, struct task_struct, sibling); ++ ++ if (!dcnt || !depth) { ++ if (!task_is_inheritable(dec)) continue; ++ cnt++; ++ sum += dec->se.burst_penalty; ++ continue; ++ } ++ if (!child_burst_cache_expired(dec, now)) { ++ cnt += dec->se.child_burst_cnt; ++ sum += (u32)dec->se.child_burst * dec->se.child_burst_cnt; ++ continue; ++ } ++ update_child_burst_topological(dec, now, depth - 1, &cnt, &sum); ++ } ++ ++ __update_child_burst_cache(p, cnt, sum, now); ++ *acnt += cnt; ++ *asum += sum; ++} ++ ++static inline u8 __inherit_burst_topological(struct task_struct *p, u64 now) { ++ struct task_struct *anc = p->real_parent; ++ u32 cnt = 0, sum = 0; ++ ++ while (anc->real_parent != anc && count_child_tasks(anc) == 1) ++ anc = anc->real_parent; ++ ++ if (child_burst_cache_expired(anc, now)) ++ update_child_burst_topological( ++ anc, now, sched_burst_fork_atavistic - 1, &cnt, &sum); ++ ++ return anc->se.child_burst; ++} ++ ++static inline void inherit_burst(struct task_struct *p) { ++ u8 burst_cache; ++ u64 now = ktime_get_ns(); ++ ++ read_lock(&tasklist_lock); ++ burst_cache = likely(sched_burst_fork_atavistic)? 
++ __inherit_burst_topological(p, now): ++ __inherit_burst_direct(p, now); ++ read_unlock(&tasklist_lock); ++ ++ p->se.prev_burst_penalty = max(p->se.prev_burst_penalty, burst_cache); ++} ++ ++static void sched_post_fork_bore(struct task_struct *p) { ++ if (p->sched_class == &fair_sched_class) ++ inherit_burst(p); ++ p->se.burst_penalty = p->se.prev_burst_penalty; ++} ++#endif // CONFIG_SCHED_BORE ++ + /* + * Perform scheduler related setup for a newly forked process p. + * p is forked by current. +@@ -4559,6 +4691,9 @@ static void __sched_fork(unsigned long clone_flags, struct task_struct *p) + p->se.prev_sum_exec_runtime = 0; + p->se.nr_migrations = 0; + p->se.vruntime = 0; ++#ifdef CONFIG_SCHED_BORE ++ sched_fork_bore(p); ++#endif // CONFIG_SCHED_BORE + p->se.vlag = 0; + p->se.slice = sysctl_sched_base_slice; + INIT_LIST_HEAD(&p->se.group_node); +@@ -4893,6 +5028,9 @@ void sched_cancel_fork(struct task_struct *p) + + void sched_post_fork(struct task_struct *p) + { ++#ifdef CONFIG_SCHED_BORE ++ sched_post_fork_bore(p); ++#endif // CONFIG_SCHED_BORE + uclamp_post_fork(p); + scx_post_fork(p); + } +@@ -10044,6 +10182,11 @@ void __init sched_init(void) + BUG_ON(!sched_class_above(&ext_sched_class, &idle_sched_class)); + #endif + ++#ifdef CONFIG_SCHED_BORE ++ sched_init_bore(); ++ printk(KERN_INFO "BORE (Burst-Oriented Response Enhancer) CPU Scheduler modification 5.2.10 by Masahito Suzuki"); ++#endif // CONFIG_SCHED_BORE ++ + wait_bit_init(); + + #ifdef CONFIG_FAIR_GROUP_SCHED +diff --git a/kernel/sched/debug.c b/kernel/sched/debug.c +index c057ef46c..3cab39e34 100644 +--- a/kernel/sched/debug.c ++++ b/kernel/sched/debug.c +@@ -167,7 +167,52 @@ static const struct file_operations sched_feat_fops = { + }; + + #ifdef CONFIG_SMP ++#ifdef CONFIG_SCHED_BORE ++static ssize_t sched_min_base_slice_write(struct file *filp, const char __user *ubuf, ++ size_t cnt, loff_t *ppos) ++{ ++ char buf[16]; ++ unsigned int value; ++ ++ if (cnt > 15) ++ cnt = 15; ++ ++ if 
(copy_from_user(&buf, ubuf, cnt)) ++ return -EFAULT; ++ buf[cnt] = '\0'; ++ ++ if (kstrtouint(buf, 10, &value)) ++ return -EINVAL; + ++ if (!value) ++ return -EINVAL; ++ ++ sysctl_sched_min_base_slice = value; ++ sched_update_min_base_slice(); ++ ++ *ppos += cnt; ++ return cnt; ++} ++ ++static int sched_min_base_slice_show(struct seq_file *m, void *v) ++{ ++ seq_printf(m, "%u\n", sysctl_sched_min_base_slice); /* stored from kstrtouint(); print it unsigned */ ++ return 0; ++} ++ ++static int sched_min_base_slice_open(struct inode *inode, struct file *filp) ++{ ++ return single_open(filp, sched_min_base_slice_show, NULL); ++} ++ ++static const struct file_operations sched_min_base_slice_fops = { ++ .open = sched_min_base_slice_open, ++ .write = sched_min_base_slice_write, ++ .read = seq_read, ++ .llseek = seq_lseek, ++ .release = single_release, ++}; ++#else // !CONFIG_SCHED_BORE + static ssize_t sched_scaling_write(struct file *filp, const char __user *ubuf, + size_t cnt, loff_t *ppos) + { +@@ -213,7 +258,7 @@ static const struct file_operations sched_scaling_fops = { + .llseek = seq_lseek, + .release = single_release, + }; +- ++#endif // CONFIG_SCHED_BORE + #endif /* SMP */ + + #ifdef CONFIG_PREEMPT_DYNAMIC +@@ -347,13 +392,20 @@ static __init int sched_init_debug(void) + debugfs_create_file("preempt", 0644, debugfs_sched, NULL, &sched_dynamic_fops); + #endif + ++#ifdef CONFIG_SCHED_BORE ++ debugfs_create_file("min_base_slice_ns", 0644, debugfs_sched, NULL, &sched_min_base_slice_fops); ++ debugfs_create_u32("base_slice_ns", 0400, debugfs_sched, &sysctl_sched_base_slice); ++#else // !CONFIG_SCHED_BORE + debugfs_create_u32("base_slice_ns", 0644, debugfs_sched, &sysctl_sched_base_slice); ++#endif // CONFIG_SCHED_BORE + + debugfs_create_u32("latency_warn_ms", 0644, debugfs_sched, &sysctl_resched_latency_warn_ms); + debugfs_create_u32("latency_warn_once", 0644, debugfs_sched, &sysctl_resched_latency_warn_once); + + #ifdef CONFIG_SMP ++#if !defined(CONFIG_SCHED_BORE) + debugfs_create_file("tunable_scaling", 0644,
debugfs_sched, NULL, &sched_scaling_fops); ++#endif // CONFIG_SCHED_BORE + debugfs_create_u32("migration_cost_ns", 0644, debugfs_sched, &sysctl_sched_migration_cost); + debugfs_create_u32("nr_migrate", 0644, debugfs_sched, &sysctl_sched_nr_migrate); + +@@ -596,6 +648,9 @@ print_task(struct seq_file *m, struct rq *rq, struct task_struct *p) + SPLIT_NS(schedstat_val_or_zero(p->stats.sum_sleep_runtime)), + SPLIT_NS(schedstat_val_or_zero(p->stats.sum_block_runtime))); + ++#ifdef CONFIG_SCHED_BORE ++ SEQ_printf(m, " %2d", p->se.burst_score); ++#endif // CONFIG_SCHED_BORE + #ifdef CONFIG_NUMA_BALANCING + SEQ_printf(m, " %d %d", task_node(p), task_numa_group_id(p)); + #endif +@@ -1069,6 +1124,9 @@ void proc_sched_show_task(struct task_struct *p, struct pid_namespace *ns, + + P(se.load.weight); + #ifdef CONFIG_SMP ++#ifdef CONFIG_SCHED_BORE ++ P(se.burst_score); ++#endif // CONFIG_SCHED_BORE + P(se.avg.load_sum); + P(se.avg.runnable_sum); + P(se.avg.util_sum); +diff --git a/kernel/sched/fair.c b/kernel/sched/fair.c +index 32f68ec1e..197fd68e3 100644 +--- a/kernel/sched/fair.c ++++ b/kernel/sched/fair.c +@@ -19,6 +19,9 @@ + * + * Adaptive scheduling granularity, math enhancements by Peter Zijlstra + * Copyright (C) 2007 Red Hat, Inc., Peter Zijlstra ++ * ++ * Burst-Oriented Response Enhancer (BORE) CPU Scheduler ++ * Copyright (C) 2021-2024 Masahito Suzuki + */ + #include + #include +@@ -64,28 +67,182 @@ + * SCHED_TUNABLESCALING_LOG - scaled logarithmical, *1+ilog(ncpus) + * SCHED_TUNABLESCALING_LINEAR - scaled linear, *ncpus + * +- * (default SCHED_TUNABLESCALING_LOG = *(1+ilog(ncpus)) ++ * (BORE default SCHED_TUNABLESCALING_NONE = *1 constant) ++ * (EEVDF default SCHED_TUNABLESCALING_LOG = *(1+ilog(ncpus)) + */ ++#ifdef CONFIG_SCHED_BORE ++unsigned int sysctl_sched_tunable_scaling = SCHED_TUNABLESCALING_NONE; ++#else // !CONFIG_SCHED_BORE + unsigned int sysctl_sched_tunable_scaling = SCHED_TUNABLESCALING_LOG; ++#endif // CONFIG_SCHED_BORE + + /* + * Minimal preemption 
granularity for CPU-bound tasks: + * +- * (default: 0.75 msec * (1 + ilog(ncpus)), units: nanoseconds) ++ * (BORE default: max(1 sec / HZ, min_base_slice) constant, units: nanoseconds) ++ * (EEVDF default: 0.75 msec * (1 + ilog(ncpus)), units: nanoseconds) + */ +-#ifdef CONFIG_CACHY +-unsigned int sysctl_sched_base_slice = 350000ULL; +-static unsigned int normalized_sysctl_sched_base_slice = 350000ULL; +-#else ++#ifdef CONFIG_SCHED_BORE ++unsigned int sysctl_sched_base_slice = 1000000000ULL / HZ; ++static unsigned int configured_sched_base_slice = 1000000000ULL / HZ; ++unsigned int sysctl_sched_min_base_slice = CONFIG_MIN_BASE_SLICE_NS; ++#else // !CONFIG_SCHED_BORE + unsigned int sysctl_sched_base_slice = 750000ULL; + static unsigned int normalized_sysctl_sched_base_slice = 750000ULL; +-#endif ++#endif // CONFIG_SCHED_BORE + +-#ifdef CONFIG_CACHY +-const_debug unsigned int sysctl_sched_migration_cost = 300000UL; +-#else + const_debug unsigned int sysctl_sched_migration_cost = 500000UL; +-#endif ++ ++#ifdef CONFIG_SCHED_BORE ++u8 __read_mostly sched_bore = 1; ++u8 __read_mostly sched_burst_exclude_kthreads = 1; ++u8 __read_mostly sched_burst_smoothness_long = 1; ++u8 __read_mostly sched_burst_smoothness_short = 0; ++u8 __read_mostly sched_burst_fork_atavistic = 2; ++u8 __read_mostly sched_burst_penalty_offset = 22; ++uint __read_mostly sched_burst_penalty_scale = 1280; ++uint __read_mostly sched_burst_cache_lifetime = 60000000; ++uint __read_mostly sched_deadline_boost_mask = ENQUEUE_INITIAL ++ | ENQUEUE_WAKEUP; ++uint __read_mostly sched_deadline_preserve_mask = ENQUEUE_RESTORE ++ | ENQUEUE_MIGRATED; ++static int __maybe_unused sixty_four = 64; ++static int __maybe_unused maxval_12_bits = 4095; ++ ++#define MAX_BURST_PENALTY (39U <<2) ++ ++static inline u32 log2plus1_u64_u32f8(u64 v) { ++ u32 msb = fls64(v); ++ s32 excess_bits = msb - 9; ++ u8 fractional = (0 <= excess_bits)? 
v >> excess_bits: v << -excess_bits; ++ return msb << 8 | fractional; ++} ++ ++static inline u32 calc_burst_penalty(u64 burst_time) { ++ u32 greed, tolerance, penalty, scaled_penalty; ++ ++ greed = log2plus1_u64_u32f8(burst_time); ++ tolerance = sched_burst_penalty_offset << 8; ++ penalty = max(0, (s32)greed - (s32)tolerance); ++ scaled_penalty = penalty * sched_burst_penalty_scale >> 16; ++ ++ return min(MAX_BURST_PENALTY, scaled_penalty); ++} ++ ++static inline u64 __scale_slice(u64 delta, u8 score) { ++ return mul_u64_u32_shr(delta, sched_prio_to_wmult[score], 22); ++} ++ ++static inline u64 __unscale_slice(u64 delta, u8 score) { ++ return mul_u64_u32_shr(delta, sched_prio_to_weight[score], 10); ++} ++ ++static void reweight_entity( ++ struct cfs_rq *cfs_rq, struct sched_entity *se, unsigned long weight); ++ ++static void reweight_task_by_prio(struct task_struct *p, int prio) ++{ ++ struct sched_entity *se = &p->se; ++ struct cfs_rq *cfs_rq = cfs_rq_of(se); ++ struct load_weight *load = &se->load; ++ unsigned long weight = scale_load(sched_prio_to_weight[prio]); ++ ++ reweight_entity(cfs_rq, se, weight); ++ load->inv_weight = sched_prio_to_wmult[prio]; ++} ++ ++static inline u8 effective_prio(struct task_struct *p) { ++ u8 prio = p->static_prio - MAX_RT_PRIO; ++ ++ if (likely(sched_bore)) ++ prio += p->se.burst_score; ++ return min(39, prio); ++} ++ ++static void update_burst_score(struct sched_entity *se) { ++ if (!entity_is_task(se)) return; ++ struct task_struct *p = task_of(se); ++ u8 prev_prio = effective_prio(p); ++ ++ u8 burst_score = 0; ++ if (!(sched_burst_exclude_kthreads && (p->flags & PF_KTHREAD))) ++ burst_score = se->burst_penalty >> 2; ++ ++ se->burst_score = burst_score; ++ ++ u8 new_prio = effective_prio(p); ++ if (new_prio != prev_prio) ++ reweight_task_by_prio(p, new_prio); ++} ++ ++static void update_burst_penalty(struct sched_entity *se) { ++ se->curr_burst_penalty = calc_burst_penalty(se->burst_time); ++ se->burst_penalty = 
max(se->prev_burst_penalty, se->curr_burst_penalty); ++ update_burst_score(se); ++} ++ ++static inline u32 binary_smooth(u32 new, u32 old) { ++ int increment = new - old; ++ return (0 <= increment)? ++ old + ( increment >> (int)sched_burst_smoothness_long): ++ old - (-increment >> (int)sched_burst_smoothness_short); ++} ++ ++static void restart_burst(struct sched_entity *se) { ++ se->burst_penalty = se->prev_burst_penalty = ++ binary_smooth(se->curr_burst_penalty, se->prev_burst_penalty); ++ se->curr_burst_penalty = 0; ++ se->burst_time = 0; ++ update_burst_score(se); ++} ++ ++static void restart_burst_rescale_deadline(struct sched_entity *se) { ++ s64 vscaled, wremain, vremain = se->deadline - se->vruntime; ++ struct task_struct *p = task_of(se); ++ u8 prev_prio = effective_prio(p); ++ restart_burst(se); ++ u8 new_prio = effective_prio(p); ++ if (prev_prio > new_prio) { ++ wremain = __unscale_slice(abs(vremain), prev_prio); ++ vscaled = __scale_slice(wremain, new_prio); ++ if (unlikely(vremain < 0)) ++ vscaled = -vscaled; ++ se->deadline = se->vruntime + vscaled; ++ } ++} ++ ++static void reset_task_weights_bore(void) { ++ struct task_struct *task; ++ struct rq *rq; ++ struct rq_flags rf; ++ ++ write_lock_irq(&tasklist_lock); ++ ++ for_each_process(task) { ++ rq = task_rq(task); ++ ++ rq_lock_irqsave(rq, &rf); ++ ++ reweight_task_by_prio(task, effective_prio(task)); ++ ++ rq_unlock_irqrestore(rq, &rf); ++ } ++ ++ write_unlock_irq(&tasklist_lock); ++} ++ ++int sched_bore_update_handler(struct ctl_table *table, int write, ++ void __user *buffer, size_t *lenp, loff_t *ppos) ++{ ++ int ret = proc_dou8vec_minmax(table, write, buffer, lenp, ppos); ++ if (ret || !write) ++ return ret; ++ ++ reset_task_weights_bore(); ++ ++ return 0; ++} ++#endif // CONFIG_SCHED_BORE + + static int __init setup_sched_thermal_decay_shift(char *str) + { +@@ -130,12 +287,8 @@ int __weak arch_asym_cpu_priority(int cpu) + * + * (default: 5 msec, units: microseconds) + */ +-#ifdef CONFIG_CACHY 
+-static unsigned int sysctl_sched_cfs_bandwidth_slice = 3000UL; +-#else + static unsigned int sysctl_sched_cfs_bandwidth_slice = 5000UL; + #endif +-#endif + + #ifdef CONFIG_NUMA_BALANCING + /* Restrict the NUMA promotion throughput (MB/s) for each target node. */ +@@ -144,6 +297,92 @@ static unsigned int sysctl_numa_balancing_promote_rate_limit = 65536; + + #ifdef CONFIG_SYSCTL + static struct ctl_table sched_fair_sysctls[] = { ++#ifdef CONFIG_SCHED_BORE ++ { ++ .procname = "sched_bore", ++ .data = &sched_bore, ++ .maxlen = sizeof(u8), ++ .mode = 0644, ++ .proc_handler = sched_bore_update_handler, ++ .extra1 = SYSCTL_ZERO, ++ .extra2 = SYSCTL_ONE, ++ }, ++ { ++ .procname = "sched_burst_exclude_kthreads", ++ .data = &sched_burst_exclude_kthreads, ++ .maxlen = sizeof(u8), ++ .mode = 0644, ++ .proc_handler = proc_dou8vec_minmax, ++ .extra1 = SYSCTL_ZERO, ++ .extra2 = SYSCTL_ONE, ++ }, ++ { ++ .procname = "sched_burst_smoothness_long", ++ .data = &sched_burst_smoothness_long, ++ .maxlen = sizeof(u8), ++ .mode = 0644, ++ .proc_handler = proc_dou8vec_minmax, ++ .extra1 = SYSCTL_ZERO, ++ .extra2 = SYSCTL_ONE, ++ }, ++ { ++ .procname = "sched_burst_smoothness_short", ++ .data = &sched_burst_smoothness_short, ++ .maxlen = sizeof(u8), ++ .mode = 0644, ++ .proc_handler = proc_dou8vec_minmax, ++ .extra1 = SYSCTL_ZERO, ++ .extra2 = SYSCTL_ONE, ++ }, ++ { ++ .procname = "sched_burst_fork_atavistic", ++ .data = &sched_burst_fork_atavistic, ++ .maxlen = sizeof(u8), ++ .mode = 0644, ++ .proc_handler = proc_dou8vec_minmax, ++ .extra1 = SYSCTL_ZERO, ++ .extra2 = SYSCTL_THREE, ++ }, ++ { ++ .procname = "sched_burst_penalty_offset", ++ .data = &sched_burst_penalty_offset, ++ .maxlen = sizeof(u8), ++ .mode = 0644, ++ .proc_handler = proc_dou8vec_minmax, ++ .extra1 = SYSCTL_ZERO, ++ .extra2 = &sixty_four, ++ }, ++ { ++ .procname = "sched_burst_penalty_scale", ++ .data = &sched_burst_penalty_scale, ++ .maxlen = sizeof(uint), ++ .mode = 0644, ++ .proc_handler = proc_douintvec_minmax, ++ 
.extra1 = SYSCTL_ZERO, ++ .extra2 = &maxval_12_bits, ++ }, ++ { ++ .procname = "sched_burst_cache_lifetime", ++ .data = &sched_burst_cache_lifetime, ++ .maxlen = sizeof(uint), ++ .mode = 0644, ++ .proc_handler = proc_douintvec, ++ }, ++ { ++ .procname = "sched_deadline_boost_mask", ++ .data = &sched_deadline_boost_mask, ++ .maxlen = sizeof(uint), ++ .mode = 0644, ++ .proc_handler = proc_douintvec, ++ }, ++ { ++ .procname = "sched_deadline_preserve_mask", ++ .data = &sched_deadline_preserve_mask, ++ .maxlen = sizeof(uint), ++ .mode = 0644, ++ .proc_handler = proc_douintvec, ++ }, ++#endif // CONFIG_SCHED_BORE + #ifdef CONFIG_CFS_BANDWIDTH + { + .procname = "sched_cfs_bandwidth_slice_us", +@@ -201,6 +440,13 @@ static inline void update_load_set(struct load_weight *lw, unsigned long w) + * + * This idea comes from the SD scheduler of Con Kolivas: + */ ++#ifdef CONFIG_SCHED_BORE ++static void update_sysctl(void) { ++ sysctl_sched_base_slice = ++ max(sysctl_sched_min_base_slice, configured_sched_base_slice); ++} ++void sched_update_min_base_slice(void) { update_sysctl(); } ++#else // !CONFIG_SCHED_BORE + static unsigned int get_update_sysctl_factor(void) + { + unsigned int cpus = min_t(unsigned int, num_online_cpus(), 8); +@@ -231,6 +477,7 @@ static void update_sysctl(void) + SET_SYSCTL(sched_base_slice); + #undef SET_SYSCTL + } ++#endif // CONFIG_SCHED_BORE + + void __init sched_init_granularity(void) + { +@@ -708,6 +955,10 @@ static s64 entity_lag(u64 avruntime, struct sched_entity *se) + + vlag = avruntime - se->vruntime; + limit = calc_delta_fair(max_t(u64, 2*se->slice, TICK_NSEC), se); ++#ifdef CONFIG_SCHED_BORE ++ if (likely(sched_bore)) ++ limit >>= 1; ++#endif // CONFIG_SCHED_BORE + + return clamp(vlag, -limit, limit); + } +@@ -868,6 +1119,39 @@ struct sched_entity *__pick_first_entity(struct cfs_rq *cfs_rq) + return __node_2_se(left); + } + ++static inline bool pick_curr(struct cfs_rq *cfs_rq, ++ struct sched_entity *curr, struct sched_entity *wakee) ++{ ++ /* 
++ * Nothing to preserve... ++ */ ++ if (!curr || !sched_feat(RESPECT_SLICE)) ++ return false; ++ ++ /* ++ * Allow preemption at the 0-lag point -- even if not all of the slice ++ * is consumed. Note: placement of positive lag can push V left and render ++ * @curr instantly ineligible irrespective the time on-cpu. ++ */ ++ if (sched_feat(RUN_TO_PARITY) && !entity_eligible(cfs_rq, curr)) ++ return false; ++ ++ /* ++ * Don't preserve @curr when the @wakee has a shorter slice and earlier ++ * deadline. IOW, explicitly allow preemption. ++ */ ++ if (sched_feat(PREEMPT_SHORT) && wakee && ++ wakee->slice < curr->slice && ++ (s64)(wakee->deadline - curr->deadline) < 0) ++ return false; ++ ++ /* ++ * Preserve @curr to allow it to finish its first slice. ++ * See the HACK in set_next_entity(). ++ */ ++ return curr->vlag == curr->deadline; ++} ++ + /* + * Earliest Eligible Virtual Deadline First + * +@@ -887,28 +1171,27 @@ struct sched_entity *__pick_first_entity(struct cfs_rq *cfs_rq) + * + * Which allows tree pruning through eligibility. + */ +-static struct sched_entity *pick_eevdf(struct cfs_rq *cfs_rq) ++static struct sched_entity *pick_eevdf(struct cfs_rq *cfs_rq, struct sched_entity *wakee) + { + struct rb_node *node = cfs_rq->tasks_timeline.rb_root.rb_node; + struct sched_entity *se = __pick_first_entity(cfs_rq); + struct sched_entity *curr = cfs_rq->curr; + struct sched_entity *best = NULL; + ++ if (curr && !curr->on_rq) ++ curr = NULL; ++ + /* + * We can safely skip eligibility check if there is only one entity + * in this cfs_rq, saving some cycles. + */ + if (cfs_rq->nr_running == 1) +- return curr && curr->on_rq ? curr : se; +- +- if (curr && (!curr->on_rq || !entity_eligible(cfs_rq, curr))) +- curr = NULL; ++ return curr ?: se; + + /* +- * Once selected, run a task until it either becomes non-eligible or +- * until it gets a new slice. See the HACK in set_next_entity(). ++ * Preserve @curr to let it finish its slice. 
+ */ +- if (sched_feat(RUN_TO_PARITY) && curr && curr->vlag == curr->deadline) ++ if (pick_curr(cfs_rq, curr, wakee)) + return curr; + + /* Pick the leftmost entity if it's eligible */ +@@ -967,6 +1250,7 @@ struct sched_entity *__pick_last_entity(struct cfs_rq *cfs_rq) + * Scheduling class statistics methods: + */ + #ifdef CONFIG_SMP ++#if !defined(CONFIG_SCHED_BORE) + int sched_update_scaling(void) + { + unsigned int factor = get_update_sysctl_factor(); +@@ -978,6 +1262,7 @@ int sched_update_scaling(void) + + return 0; + } ++#endif // CONFIG_SCHED_BORE + #endif + #endif + +@@ -1178,6 +1463,10 @@ static void update_curr(struct cfs_rq *cfs_rq) + if (unlikely(delta_exec <= 0)) + return; + ++#ifdef CONFIG_SCHED_BORE ++ curr->burst_time += delta_exec; ++ update_burst_penalty(curr); ++#endif // CONFIG_SCHED_BORE + curr->vruntime += calc_delta_fair(delta_exec, curr); + update_deadline(cfs_rq, curr); + update_min_vruntime(cfs_rq); +@@ -5193,6 +5482,12 @@ place_entity(struct cfs_rq *cfs_rq, struct sched_entity *se, int flags) + s64 lag = 0; + + se->slice = sysctl_sched_base_slice; ++#ifdef CONFIG_SCHED_BORE ++ if (likely(sched_bore) && ++ (flags & ~sched_deadline_boost_mask & sched_deadline_preserve_mask)) ++ vslice = se->deadline - se->vruntime; ++ else ++#endif // CONFIG_SCHED_BORE + vslice = calc_delta_fair(se->slice, se); + + /* +@@ -5203,6 +5498,9 @@ place_entity(struct cfs_rq *cfs_rq, struct sched_entity *se, int flags) + * + * EEVDF: placement strategy #1 / #2 + */ ++#ifdef CONFIG_SCHED_BORE ++ if (se->vlag) ++#endif // CONFIG_SCHED_BORE + if (sched_feat(PLACE_LAG) && cfs_rq->nr_running) { + struct sched_entity *curr = cfs_rq->curr; + unsigned long load; +@@ -5278,6 +5576,13 @@ place_entity(struct cfs_rq *cfs_rq, struct sched_entity *se, int flags) + * on average, halfway through their slice, as such start tasks + * off with half a slice to ease into the competition. 
+ */ ++#ifdef CONFIG_SCHED_BORE ++ if (likely(sched_bore)) { ++ if (flags & sched_deadline_boost_mask) ++ vslice /= 2; ++ } ++ else ++#endif // CONFIG_SCHED_BORE + if (sched_feat(PLACE_DEADLINE_INITIAL) && (flags & ENQUEUE_INITIAL)) + vslice /= 2; + +@@ -5492,7 +5797,7 @@ pick_next_entity(struct cfs_rq *cfs_rq) + cfs_rq->next && entity_eligible(cfs_rq, cfs_rq->next)) + return cfs_rq->next; + +- return pick_eevdf(cfs_rq); ++ return pick_eevdf(cfs_rq, NULL); + } + + static bool check_cfs_rq_runtime(struct cfs_rq *cfs_rq); +@@ -6860,6 +7165,14 @@ static void dequeue_task_fair(struct rq *rq, struct task_struct *p, int flags) + bool was_sched_idle = sched_idle_rq(rq); + + util_est_dequeue(&rq->cfs, p); ++#ifdef CONFIG_SCHED_BORE ++ if (task_sleep) { ++ cfs_rq = cfs_rq_of(se); ++ if (cfs_rq->curr == se) ++ update_curr(cfs_rq); ++ restart_burst(se); ++ } ++#endif // CONFIG_SCHED_BORE + + for_each_sched_entity(se) { + cfs_rq = cfs_rq_of(se); +@@ -8428,7 +8741,7 @@ static void check_preempt_wakeup_fair(struct rq *rq, struct task_struct *p, int + /* + * XXX pick_eevdf(cfs_rq) != se ? + */ +- if (pick_eevdf(cfs_rq) == pse) ++ if (pick_eevdf(cfs_rq, pse) == pse) + goto preempt; + + return; +@@ -8646,16 +8959,25 @@ static void yield_task_fair(struct rq *rq) + /* + * Are we the only task in the tree? + */ ++#if !defined(CONFIG_SCHED_BORE) + if (unlikely(rq->nr_running == 1)) + return; + + clear_buddies(cfs_rq, se); ++#endif // CONFIG_SCHED_BORE + + update_rq_clock(rq); + /* + * Update run-time statistics of the 'current'. 
+ */ + update_curr(cfs_rq); ++#ifdef CONFIG_SCHED_BORE ++ restart_burst_rescale_deadline(se); ++ if (unlikely(rq->nr_running == 1)) ++ return; ++ ++ clear_buddies(cfs_rq, se); ++#endif // CONFIG_SCHED_BORE + /* + * Tell update_rq_clock() that we've just updated, + * so we don't do microscopic update in schedule() +@@ -12713,6 +13035,9 @@ static void task_fork_fair(struct task_struct *p) + curr = cfs_rq->curr; + if (curr) + update_curr(cfs_rq); ++#ifdef CONFIG_SCHED_BORE ++ update_burst_score(se); ++#endif // CONFIG_SCHED_BORE + place_entity(cfs_rq, se, ENQUEUE_INITIAL); + rq_unlock(rq, &rf); + } +diff --git a/kernel/sched/features.h b/kernel/sched/features.h +index 143f55df8..bfeb9f653 100644 +--- a/kernel/sched/features.h ++++ b/kernel/sched/features.h +@@ -5,8 +5,26 @@ + * sleep+wake cycles. EEVDF placement strategy #1, #2 if disabled. + */ + SCHED_FEAT(PLACE_LAG, true) ++/* ++ * Give new tasks half a slice to ease into the competition. ++ */ + SCHED_FEAT(PLACE_DEADLINE_INITIAL, true) +-SCHED_FEAT(RUN_TO_PARITY, true) ++/* ++ * Inhibit (wakeup) preemption until the current task has exhausted its slice. ++ */ ++#ifdef CONFIG_SCHED_BORE ++SCHED_FEAT(RESPECT_SLICE, false) ++#else // !CONFIG_SCHED_BORE ++SCHED_FEAT(RESPECT_SLICE, true) ++#endif // CONFIG_SCHED_BORE ++/* ++ * Relax RESPECT_SLICE to allow preemption once current has reached 0-lag. 
++ */ ++SCHED_FEAT(RUN_TO_PARITY, false) ++/* ++ * Allow tasks with a shorter slice to disregard RESPECT_SLICE ++ */ ++SCHED_FEAT(PREEMPT_SHORT, true) + + /* + * Prefer to schedule the task we woke last (assuming it failed +diff --git a/kernel/sched/sched.h b/kernel/sched/sched.h +index fbcd2ddbf..e187aaa59 100644 +--- a/kernel/sched/sched.h ++++ b/kernel/sched/sched.h +@@ -2041,7 +2041,11 @@ static inline void dirty_sched_domain_sysctl(int cpu) + } + #endif + ++#ifdef CONFIG_SCHED_BORE ++extern void sched_update_min_base_slice(void); ++#else // !CONFIG_SCHED_BORE + extern int sched_update_scaling(void); ++#endif // CONFIG_SCHED_BORE + + static inline const struct cpumask *task_user_cpus(struct task_struct *p) + { +@@ -2672,6 +2676,9 @@ extern const_debug unsigned int sysctl_sched_nr_migrate; + extern const_debug unsigned int sysctl_sched_migration_cost; + + extern unsigned int sysctl_sched_base_slice; ++#ifdef CONFIG_SCHED_BORE ++extern unsigned int sysctl_sched_min_base_slice; ++#endif // CONFIG_SCHED_BORE + + #ifdef CONFIG_SCHED_DEBUG + extern int sysctl_resched_latency_warn_ms; +-- +2.45.2.606.g9005149a4a diff --git a/patches/series b/patches/series new file mode 100644 index 0000000..6fc82b3 --- /dev/null +++ b/patches/series @@ -0,0 +1,3 @@ +0001-cachyos-base-all.patch +0002-sched-ext.patch +0003-bore-cachy-ext.patch \ No newline at end of file