给定 M
个长度为 N
的序列,从 每个序列 中 任取一个数 求和,可以构成 N ^ M
个和,求其中 最小的 N
个和。
(N ≤ 2000,M ≤ 1000
)
- 思路:
总思想:
根据数学归纳法,我们可以先求出 前 2
个序列中 任取一个数相加构成的 前 n
小和,把这 n
个和作为一个序列,再与 第 3
个序列求新的前 n
小和,依此类推,最终得到 m
个序列任取一个数相加构成的前 n
小和。
(将 m
个序列进行合并成 1
个序列。n
路归并)
我们可以 分为两步进行处理。
对于 第一步,我们发现,如果直接将 m
个序列合并成 1
个序列 不太好做,一个 常用的思想 是:每次将两个序列合并成一个,这样一来,合并 m - 1
次即可。
先处理前两个序列,我们 从每个序列都挑一个数,一共有 n ^ 2
种可能产生的和,我们 只保留前 n
小的和(因为我们最终要求的是 最小的 n
个和,因此 之后的都没必要保留下来)。
这样,每次都可以减少一个序列,只要做 m - 1
次,最终就只剩下一个序列。
对于 第二步,是本题最为关键的部分,在两个序列的情况下,我们已经得出 n ^ 2
个和,那么要取 前 n
个和,该如何实现?
当然,如果暴力求解,将 n ^ 2
个和全部求出,之后排个序取前 n
个显然是不可行的。
接下来考虑如何优化这一步。
做法:分组法(编程中一个很重要的思想)
先将 a1 ~ an
、b1 ~ bn
从小到大排序,我们知道从两序列中一共可以得到 n ^ 2
个和,我们 将这些和分为 n
组,每组中含 n
个元素。
第 1
组:a1 + b1、a1 + b2、...、a1 + bn
第 2
组:a2 + b1、a2 + b2、...、a2 + bn
…
第 n
组:an + b1、an + b2、...、an + bn
好处:对于 第 1
组 而言,因为 a1 ~ an
是有序的,所以 组中所有元素也都是有序的,直到 第 n
组 也是一样的有序(重要性质:每组内部有序),所以,每一组的第一个元素一定是该组的最小值。
有了上面的性质就好办了,我们若要 找 n ^ 2
个和 中 最小值,其实就是在 第 1
列 中(a1 + b1、a2 + b1、...、an + b1
)找到 最小值。假设我们找到了 第 2
组 中的 a2 + b1
作为最小值,我们将其 删去,此时 第 2
组中的最小值 就成了 a2 + b2
(因为 每组内部是有序的)。接下来找 下一个最小值(当前 第 2
小数),就是从这 n
个元素 中(a1 + b1、a2 + b2、...、an + b1
)中找到 最小值 即可。
以此类推,每次都是 从 所有组中最小值组成的集合中 找到 属于某一组 的 最小值,并 删掉,之后这组中的最小值就更新为 下一个元素,之后 将其存入备选集合中(备选集合永远只包含 n
个元素,即各个组中的最小值)执行 n
次,即可找到当前两个序列中 最小的 n
个和。
那么我们的备选集合适合用什么 数据结构 维护呢,由于我们需要 三个 *** 作:取得最小值,删掉最小值,加入新的数,自然想到可以用 小根堆 进行维护。
- 时间复杂度:
每次 找最小值 O(logn)
,执行 n
次,即 O(nlogn)
,乘上 m - 1
次合并,本题 总时间复杂度为 O(mnlogn)
- 代码:
#define _CRT_SECURE_NO_WARNINGS 1
#include <bits/stdc++.h>
using namespace std;
typedef vector<int> vi;
int m, n;
struct node
{
int sum;
int idxa;
int idxb;
bool operator< (const node& x) const {
return sum > x.sum;
}
};
vi merge(vi a, vi b)
{
vi c;
priority_queue<node> heap;
for (int i = 0; i < n; ++i)
{
int tmp = b[0] + a[i];
heap.push({ tmp, i, 0 });
}
for (int i = 0; i < n; ++i)
{
auto t = heap.top();
heap.pop();
c.push_back(t.sum);
heap.push({ a[t.idxa] + b[t.idxb + 1], t.idxa, t.idxb + 1 });
}
return c;
}
int main()
{
int T = 1; cin >> T;
while (T--)
{
vi a;
cin >> m >> n;
for (int i = 0; i < n; ++i)
{
int val; scanf("%d", &val);
a.push_back(val);
}
sort(a.begin(), a.end());
for (int i = 0; i < m - 1; ++i)
{
vi b;
for (int j = 0; j < n; ++j)
{
int vb; scanf("%d", &vb);
b.push_back(vb);
}
sort(b.begin(), b.end());
a = merge(a, b);
}
for (auto v : a) printf("%d ", v);
puts("");
}
return 0;
}
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)