根据贪心,不难想到每次会把最长队伍末尾的那辆车移动到最短队伍的末尾。但由于 k k k 的存在,会导致一些冗余移动的存在。设需要挪动 C C C 辆车,则怒气值可以表示为 f ( C ) + k C f(C) + kC f(C)+kC,其中 f ( C ) f(C) f(C) 是排队所产生的怒气值, k C kC kC 为变道产生的额外怒气值。仔细分析以后,可以发现这是一个凸函数,因此考虑三分答案。
一开始想要三分需要挪车的最短长度 y y y,但是不能忽略 k k k 的影响,有些队伍的长度虽然 > y > y >y,但挪动不移动会更优。于是三分挪动车辆的数量才是最优的。
具体来说,可以枚举哪些队伍的车辆会减少/增加。若现在考虑会减少的队伍的车辆,给 a i a_i ai 排序后,设当前最长队伍的车辆数为 x x x,次长的为 y y y ( x ≠ y x \neq y x=y),然后长度为 x , y x,y x,y 的队伍的数量分别为 f x , f y f_x,f_y fx,fy。若共需要移动 C C C 辆车,则有两种情况:
-
C ≥ ( x − y ) × f x C \ge (x - y) \times f_x C≥(x−y)×fx,也就是说长度为 x x x 的车可以直接变为 y y y, C ← C − ( x − y ) × f x ; f y ← f x + f y ; f x ← 0 C \leftarrow C - (x - y) \times f_x;\ f_y \leftarrow f_x + f_y;\ f_x \leftarrow 0 C←C−(x−y)×fx; fy←fx+fy; fx←0。
-
C < ( x − y ) × f x C < (x - y) \times f_x C<(x−y)×fx,此时会产生新的队伍长度,也就是 C ← 0 ; f x − ⌊ C f x ⌋ − 1 ← f x − ⌊ C f x ⌋ − 1 + C m o d f x ; ← f x − ⌊ C f x ⌋ + ( f x − C m o d f x ) C \leftarrow 0;\ f_{x - \lfloor\frac{C}{f_x}\rfloor - 1} \leftarrow f_{x - \lfloor\frac{C}{f_x}\rfloor - 1} + C \bmod f_x;\ \leftarrow f_{x - \lfloor\frac{C}{f_x}\rfloor} + (f_x - C \bmod f_x) C←0; fx−⌊fxC⌋−1←fx−⌊fxC⌋−1+Cmodfx; ←fx−⌊fxC⌋+(fx−Cmodfx)。
可以发现最后队伍长度的种类数不会超过 n + 2 n + 2 n+2,因此这是 O ( n ) O(n) O(n) 的。考虑增加的队伍的车辆同理,用 STL 来写会简单一点。但是由于多了一支 log \log log,实测会超时:
ll tot = sum * k,res = sum,number = sum;
set <int> s;map <int,int> bg,sm;
s.insert (-1e9);
for (int i = 1;i <= n;++i) s.insert (a[i]),++bg[a[i]];
while (sum)
{int x = *(--s.end ()),num = bg[x];s.erase (x);int y = *(--s.end ());if (sum >= 1ll * (x - y) * num){sum -= 1ll * (x - y) * num;bg[y] += num;bg[x] = 0;}else {bg[x] = 0;int tmp = sum % num;if (tmp) bg[x - sum / num - 1] += tmp;bg[x - sum / num] += num - tmp;sum = 0;}
}
s.clear ();
for (auto [x,num] : bg)if (num) s.insert (x),sm[x] = num;
s.insert (1e9);
while (res)
{int x = *s.begin (),num = sm[x];s.erase (x);int y = *s.begin ();if (res >= 1ll * (y - x) * num){res -= 1ll * (y - x) * num;sm[y] += num;sm[x] = 0;}else{sm[x] = 0;int tmp = res % num;if (tmp) sm[x + res / num + 1] += tmp;sm[x + res / num] += num - tmp;res = 0;}
}
for (auto [x,num] : sm) tot += 1ll * x * (x + 1) / 2 * num;
return tot;
};
再次思考可以发现 STL 的 log \log log 完全是多余的,可以通过数组来替代,但需要小心清空与去重的问题。最后的 AC 代码如下,时间复杂度 O ( n log n ) O(n \log n) O(nlogn):
#include <bits/stdc++.h>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 2e18
#define pii pair <int,int>
using namespace std;
const int MAX = 2e5 + 5;
const int MOD = 1e9 + 7;
inline int read ();
int a[MAX],b[MAX];
vector <int> bg (1000001,0),sm (1000001,0);
void solve ()
{int n = read (),k = read ();ll ans = INF;for (int i = 1;i <= n;++i) a[i] = read ();sort (a + 1,a + 1 + n);auto check = [&] (ll sum) -> ll{ll tot = sum * k,res = sum;int cnt = 0;vector <int> p;for (int i = 1;i <= n;++i) p.push_back (a[i]);for (int i = 1;i <= n;++i) {if (!bg[a[i]]) b[++cnt] = a[i];++bg[a[i]];}b[0] = -1e9;while (sum > 0){int x = b[cnt--],num = bg[x];int y = b[cnt];if (sum >= 1ll * (x - y) * num){sum -= 1ll * (x - y) * num;bg[y] += num;bg[x] = 0;}else {bg[x] = 0;int tmp = sum % num;bg[x - sum / num] += num - tmp,p.push_back (x - sum / num);if (tmp) bg[x - sum / num - 1] += tmp,p.push_back (x - sum / num - 1);sum = 0;}}cnt = 0;for (auto v : p)if (bg[v]) b[++cnt] = v,sm[v] = bg[v],bg[v] = 0;p.clear ();for (int i = 1;i <= cnt;++i) p.push_back (b[i]);b[++cnt] = 1e9;cnt = 1;while (res > 0){int x = b[cnt++],num = sm[x];int y = b[cnt];if (res >= 1ll * (y - x) * num){res -= 1ll * (y - x) * num;sm[y] += num;sm[x] = 0;}else{sm[x] = 0;int tmp = res % num;if (tmp) sm[x + res / num + 1] += tmp,p.push_back (x + res / num + 1);sm[x + res / num] += num - tmp,p.push_back (x + res / num);res = 0;}}for (auto v : p) tot += 1ll * v * (v + 1) / 2 * sm[v],sm[v] = 0;return tot;};ll l = 0,r = accumulate (a + 1,a + n + 1,0ll);while (l < r){ll midl = l + (r - l) / 3,midr = r - (r - l) / 3;ll v1 = check (midl),v2 = check (midr);ans = min (ans,min (v1,v2));if (v1 <= v2) r = midr - 1;else l = midl + 1;}printf ("%lld\n",ans);
}
int main ()
{int t = read ();while (t--) solve ();return 0;
}
inline int read ()
{int s = 0;int f = 1;char ch = getchar ();while ((ch < '0' || ch > '9') && ch != EOF){if (ch == '-') f = -1;ch = getchar ();}while (ch >= '0' && ch <= '9'){s = s * 10 + ch - '0';ch = getchar ();}return s * f;
}