题意
给定一个 \(3\times n\) 的网格,每个格子 \((i,j)\) 内有一个数 \(a_{i,j}\)。一个人初始分数为 \(0\),在位置 \((1,1)\) 处,每次可以向右或向下走一格,目标是到达 \((3,n)\)。当走到格子 \((i,j)\) 时,这个人的分数会增加 \(a_{i,j}\)。
起初,只有第一、三行的格子是可用的。现在给出了 \(q\) 个区间 \([l_i,r_i]\) 和 \(k_i\),表示这个人可以减少 \(k_i\) 的得分让第二行第 \(l_i\) 到 \(r_i\) 个格子变得可用。求达成目标的最大得分。\(1\leq n,q\leq 5\times 10^5\),\(-10^9\leq a_{i,j}\leq 10^9\),\(1\leq k_i\leq 10^9\)。
题解
下文中令 \(s_{i,j}=\sum_{k=1}^ja_{i,k}\)。
考虑 DP。令 \(f_i\) 表示从 \((1,1)\) 走到 \((2,i)\),钦定 \(i\) 是某个区间的右端点的最大得分。我们从小到大枚举 \(i\),遍历所有右端点为 \(i\) 的区间 \([l,i]\),考虑其转移:
- 从前面某个区间的右端点对应的 \((2,j)\) 向右走到 \((2,i)\):\[f_i\leftarrow \max_{l-1\leq j\leq i-1}\{f_j-s_{2,j}\}+s_{2,i}-k \]
- 从 \((1,j)\) 向下走到区间内的 \((2,j)\),然后向右走到 \((2,i)\)\[f_i\leftarrow \max_{l\leq j\leq i}\{s_{1,j}-s_{2,j-1}\}+s_{2,i}-k \]
后者显然是 RMQ,可以使用 ST 表维护。前者也可以用线段树维护 \(f_j-s_{2,j}\) 的区间 \(\max\),不过我们注意到 \(>i\) 的部分是没有有效值的,所以实际上是后缀询问,可以使用 BIT 维护。
考虑如何统计答案。问题在于,从第二行走到第三行的拐点,不一定在某个区间的右端点上。但是我们可以发现,在第二行解锁的区间中,最多只有一个区间的右端点没有被走到。因为如果有多个没走满的区间,我们显然可以删掉左端点最大的那个。于是我们枚举没走满的区间 \([l,r]\),做和前面类似的分类讨论:
- 从前面某个区间的右端点对应的 \((2,i)\) 走到 \((2,j)\):\[\begin{align*} ans&\leftarrow \max_{l-1\leq i\leq j\leq r}\{f_i+s_{2,j}-s_{2,i}+s_{3,n}-s_{3,j-1}\}\\ &=\max_{l-1\leq i\leq j\leq r}\{(f_i-s_{2,i})+(s_{2,j}-s_{3,j-1})\}+s_{3,n} \end{align*} \]
- 从 \((1,i)\) 向下走到 \((2,i)\),再向右走到 \((2,j)\):\[\begin{align*} ans&\leftarrow \max_{l\leq i\leq j\leq r}\{s_{1,i}+s_{2,j}-s_{2,i-1}+s_{3,n}-s_{3,j-1}\}\\ &=\max_{l-1\leq i\leq j\leq r}\{(s_{1,i}-s_{2,i-1})+(s_{2,j}-s_{3,j-1})\}+s_{3,n} \end{align*} \]
两者都是形如区间查询 \(\max_{l\leq i\leq j\leq r}\{A(i)+B(j)\}\) 的形式,可以使用线段树维护,在节点上维护区间的 \(A,B\) 最大值和 \(\max_{l\leq i\leq j\leq r}\{A(i)+B(j)\}\) 即可。
时间复杂度 \(\mathcal{O}((n+q)\log{n})\)。
代码
#include <bits/stdc++.h>using namespace std;#define lowbit(x) ((x) & -(x))
typedef long long ll;
typedef pair<int, int> pii;
const int N = 5e5 + 5;
const ll INF = 1e18;namespace IO {const int S = 1 << 24, lm = 1 << 23;char bi[S + 5], *p1 = bi, *p2 = bi, ch;int s;#define gc() (p1 == p2 && (p2 = (p1 = bi) + fread(bi, 1, 1 << 23, stdin), p1 == p2) ? EOF : *p1++)inline ll rd() {s = 1; char ch;while (ch = gc(), (ch < '0')) if (ch == '-') s = -1;ll x = ch ^ 48;while (ch = gc(), (ch >= '0')) x = (x << 3) + (x << 1) + (ch ^ 48);return s == 1 ? x : -x;}
}
using IO::rd;template<typename T> inline void chk_min(T &x, T y) { x = min(x, y); }
template<typename T> inline void chk_max(T &x, T y) { x = max(x, y); }int n, m, a[3][N];
ll val1[N], val2[N];
ll ans = -INF, pre[3][N], f[N];
struct Range { int l, r, k; } rg[N];
basic_string<Range> vec[N];struct BIT {ll c[N];inline void init() { fill(c + 1, c + n + 1, -INF); }inline ll query(int x) {ll res = -INF;for (; x <= n; x += lowbit(x)) chk_max(res, c[x]);return res;}inline void upd(int x, ll v) { for (; x; x -= lowbit(x)) chk_max(c[x], v); }
} ft1, ft2;
struct SegTree2 {
#define ls(p) (p << 1)
#define rs(p) (p << 1 | 1)struct Node {ll mx1, mx2, dat;Node() : mx1(-INF), mx2(-INF), dat(-INF) {}Node(ll mx1, ll mx2, ll dat) : mx1(mx1), mx2(mx2), dat(dat) {}Node operator+(const Node &x) const { return {max(mx1, x.mx1), max(mx2, x.mx2), max({dat, x.dat, mx1 + x.mx2})}; }} nodes[N << 2];inline void push_up(int p) { nodes[p] = nodes[ls(p)] + nodes[rs(p)]; }inline void build(int p, int l, int r) {if (l == r) return nodes[p] = {val1[l], val2[l], val1[l] + val2[l]}, void();int mid = l + r >> 1;build(ls(p), l, mid), build(rs(p), mid + 1, r);push_up(p);}inline Node query(int p, int l, int r, int x, int y) {if (x <= l && y >= r) return nodes[p];int mid = l + r >> 1; Node res;if (x <= mid) res = query(ls(p), l, mid, x, y);if (y > mid) res = res + query(rs(p), mid + 1, r, x, y);return res;}
#undef ls
#undef rs
} sgt3;int main() {ios::sync_with_stdio(0), cin.tie(0);n = rd(), m = rd();for (int i = 0; i < 3; ++i) for (int j = 1; j <= n; ++j) a[i][j] = rd(), pre[i][j] = pre[i][j - 1] + a[i][j];for (int i = 1, l, r, k; i <= m; ++i) l = rd(), r = rd(), k = rd(), vec[r] += rg[i] = {l, r, k};fill(f + 1, f + n + 1, -INF), ft1.init(), ft2.init();for (int i = 1; i <= n; ++i) {ft2.upd(i, pre[0][i] - pre[1][i - 1]);for (Range it : vec[i]) {int l = it.l, k = it.k;if (i > 1) chk_max(f[i], ft1.query(max(l - 1, 1)) + pre[1][i] - k);chk_max(f[i], ft2.query(l) + pre[1][i] - k);}ft1.upd(i, f[i] - pre[1][i]);}for (int i = 1; i <= n; ++i) val1[i] = f[i] - pre[1][i], val2[i] = pre[1][i] - pre[2][i - 1];sgt3.build(1, 1, n);for (int i = 1; i <= m; ++i) chk_max(ans, sgt3.query(1, 1, n, max(rg[i].l - 1, 1), rg[i].r).dat + pre[2][n] - rg[i].k);for (int i = 1; i <= n; ++i) val1[i] = pre[0][i] - pre[1][i - 1];sgt3.build(1, 1, n);for (int i = 1; i <= m; ++i) chk_max(ans, sgt3.query(1, 1, n, rg[i].l, rg[i].r).dat + pre[2][n] - rg[i].k);cout << ans;return 0;
}