合并石子
题目描述
在桌面从左至右横向摆放着 N 堆石子。每一堆石子都有着相同的颜色,颜色可能是颜色 0,颜色 1 或者颜色 2 中的其中一种。
现在要对石子进行合并,规定每次只能选择位置相邻并且颜色相同的两堆石子进行合并。合并后新堆的相对位置保持不变,新堆的石子数目为所选择的两堆石子数目之和,并且新堆石子的颜色也会发生循环式的变化。具体来说:两堆颜色 0 的石子合并后的石子堆为颜色 1,两堆颜色 1 的石子合并后的石子堆为颜色 2,两堆颜色 2 的石子合并后的石子堆为颜色 0。本次合并的花费为所选择的两堆石子的数目之和。
给出 N 堆石子以及他们的初始颜色,请问最少可以将它们合并为多少堆石子?如果有多种答案,选择其中合并总花费最小的一种,合并总花费指的是在所有的合并操作中产生的合并花费的总和。
输入格式
第一行一个正整数 N 表示石子堆数。
第二行包含 N 个用空格分隔的正整数,表示从左至右每一堆石子的数目。
第三行包含 N 个值为 0 或 1 或 2 的整数表示每堆石头的颜色。
输出格式
一行包含两个整数,用空格分隔。其中第一个整数表示合并后数目最少的石头堆数,第二个整数表示对应的最小花费。
样例输入
5
5 10 1 8 6
1 1 0 2 2
样例输出
2 44
题解
本题是一道明显的区间 DP 题目,因为他是一道经典题目基础上改的,增加了相同的颜色才可以合并,我们设f[i][j][c]表示区间i~j是C 颜色的石子时的最小花费,那么我们可以得出以下状态转移方程式:
f[l][r][(c+1)%3] = min(f[l][中间点][c] + f[中间点+1][r][c])
我们再多加一层这个中间点就可以了
另外在统计答案时,因为所有石子可能无法合并成一堆,所以f[1][n][c]并不是答案,我们设num[l][r] 代表 l~r 之间有多少堆,cnt[l][r] 代表l~r之间最小化费是多少, 然后最后直接合并区间就行,详细见代码
用记忆化写的
import java.util.*;
import java.io.*;
public class Main {
static final int maxn = 305;
static int n;
static long inf = Integer.MAX_VALUE;
static long[][][] f = new long[maxn][maxn][5];
static long[][] num = new long[maxn][maxn], cnt = new long[maxn][maxn]; // num[l][r] 代表 l~r 之间有多少堆,cnt[l][r] 代表l~r之间最小化费是多少
static int[] col = new int[maxn], a = new int[maxn], sum = new int[maxn];
static public void main(String[] args) throws Exception {
Read sc = new Read();
n = sc.nextInt();
for (int i = 1; i <= n; i++) {
a[i] = sc.nextInt();
sum[i] = sum[i - 1] + a[i]; // 前缀和方便计算区间
}
for (int i = 1; i <= n; i++)
col[i] = sc.nextInt();
for (int i = 1; i <= n; i++)
for (int j = i; j <= n; j++) {
cnt[i][j] = inf;
num[i][j] = j - i + 1; // 默认不能进行合并
for (int k = 0; k < 3; k++)
f[i][j][k] = -1;
}
for (int i = 1; i <= n; i++) {
cnt[i][i] = 0;
f[i][i][col[i]] = 0;
}
for (int i = 0; i < 3; i++) {
dfs(1, n, i);
}
for (int l = 1; l <= n; l++) { // 统计答案
for (int r = l; r <= n; r++) {
for (int i = l; i < r; i++) {
if (num[l][r] > num[l][i] + num[i + 1][r]) { // 要合并堆数少的
cnt[l][r] = cnt[l][i] + cnt[i + 1][r];
num[l][r] = num[l][i] + num[i + 1][r];
} else if (num[l][r] == num[l][i] + num[i + 1][r]) { // 一样的话就选花费少的
cnt[l][r] = Math.min(cnt[l][r], cnt[l][i] + cnt[i + 1][r]);
}
}
}
}
System.out.println(num[1][n] + " " + cnt[1][n]);
}
static long dfs(int l, int r, int c) {
if (f[l][r][c] != -1)
return f[l][r][c];
long ans = inf;
for (int i = l; i < r; i++) {
ans = Math.min(dfs(l, i, (c == 0 ? 2 : c - 1)) + dfs(i + 1, r, (c == 0 ? 2 : c - 1)) + sum[r] - sum[l - 1],
ans);
}
if (ans != inf)
num[l][r] = 1;
cnt[l][r] = Math.min(ans, cnt[l][r]);
f[l][r][c] = ans;
return ans;
}
}
class Read {
StreamTokenizer st = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
public int nextInt() throws Exception {
st.nextToken();
return (int) st.nval;
}
public String readLine() throws Exception {
st.nextToken();
return st.sval;
}
}
Comments NOTHING