题目描述
给定一个长度为 n 的序列 a,要求支持如下三个操作:
- 给定区间 [l,r],将区间内每个数都修改为 x。
- 给定区间 [l,r],将区间内每个数都加上 x。
- 给定区间 [l,r],求区间内的最大值。
输入格式
第一行是两个整数,依次表示序列的长度 n 和操作的个数 q。
第二行有 n 个整数,第 i 个整数表示序列中的第 i 个数 ai。
接下来 q 行,每行表示一个操作。每行首先有一个整数 op,表示操作的类型。
- 若 op=1,则接下来有三个整数 l,r,x,表示将区间 [l,r] 内的每个数都修改为 x。
- 若 op=2,则接下来有三个整数 l,r,x,表示将区间 [l,r] 内的每个数都加上 x。
- 若 op=3,则接下来有两个整数 l,r,表示查询区间 [l,r] 内的最大值。
输出格式
对于每个 op=3 的操作,输出一行一个整数表示答案。
输入输出样例
输入 #1复制
6 6 1 1 4 5 1 4 1 1 2 6 2 3 4 2 3 1 4 3 2 3 1 1 6 -1 3 1 6
输出 #1复制
7 6 -1
输入 #2复制
4 4 10 4 -3 -7 1 1 3 0 2 3 4 -4 1 2 4 -9 3 1 4
输出 #2复制
0
说明/提示
数据规模与约定
- 对于 10% 的数据,n=q=1。
- 对于 40% 的数据,n,q≤103。
- 对于 50% 的数据,0≤ai,x≤104。
- 对于 60% 的数据,op=1。
- 对于 90% 的数据,n,q≤105。
- 对于 100% 的数据,1≤n,q≤106,1≤l,r≤n,op∈{1,2,3},∣ai∣,∣x∣≤109。
提示
请注意大量数据读入对程序效率造成的影响。
代码实现:
#include <iostream>
#include <vector>
#include <algorithm>
#include <climits>
using namespace std;
typedef long long ll;
struct SegmentTreeNode {
ll max_val;
ll add;
ll set;
bool has_set;
};
class SegmentTree {
private:
vector<SegmentTreeNode> tree;
int n;
vector<ll> arr;
void build(int node, int start, int end) {
tree[node].add = 0;
tree[node].has_set = false;
if (start == end) {
tree[node].max_val = arr[start - 1];
tree[node].set = arr[start - 1];
} else {
int mid = (start + end) / 2;
build(2 * node, start, mid);
build(2 * node + 1, mid + 1, end);
tree[node].max_val = max(tree[2 * node].max_val, tree[2 * node + 1].max_val);
}
}
void push_down(int node, int start, int end) {
int mid = (start + end) / 2;
int left_node = 2 * node;
int right_node = 2 * node + 1;
if (tree[node].has_set) {
tree[left_node].max_val = tree[node].set;
tree[right_node].max_val = tree[node].set;
tree[left_node].set = tree[node].set;
tree[right_node].set = tree[node].set;
tree[left_node].add = 0;
tree[right_node].add = 0;
tree[left_node].has_set = true;
tree[right_node].has_set = true;
tree[node].has_set = false;
}
if (tree[node].add != 0) {
tree[left_node].max_val += tree[node].add;
tree[right_node].max_val += tree[node].add;
if (tree[left_node].has_set) {
tree[left_node].set += tree[node].add;
} else {
tree[left_node].add += tree[node].add;
}
if (tree[right_node].has_set) {
tree[right_node].set += tree[node].add;
} else {
tree[right_node].add += tree[node].add;
}
tree[node].add = 0;
}
}
void update_set(int node, int start, int end, int l, int r, ll x) {
if (r < start || l > end) {
return;
}
if (l <= start && end <= r) {
tree[node].max_val = x;
tree[node].set = x;
tree[node].add = 0;
tree[node].has_set = true;
return;
}
push_down(node, start, end);
int mid = (start + end) / 2;
update_set(2 * node, start, mid, l, r, x);
update_set(2 * node + 1, mid + 1, end, l, r, x);
tree[node].max_val = max(tree[2 * node].max_val, tree[2 * node + 1].max_val);
}
void update_add(int node, int start, int end, int l, int r, ll x) {
if (r < start || l > end) {
return;
}
if (l <= start && end <= r) {
tree[node].max_val += x;
if (tree[node].has_set) {
tree[node].set += x;
} else {
tree[node].add += x;
}
return;
}
push_down(node, start, end);
int mid = (start + end) / 2;
update_add(2 * node, start, mid, l, r, x);
update_add(2 * node + 1, mid + 1, end, l, r, x);
tree[node].max_val = max(tree[2 * node].max_val, tree[2 * node + 1].max_val);
}
ll query_max(int node, int start, int end, int l, int r) {
if (r < start || l > end) {
return LLONG_MIN;
}
if (l <= start && end <= r) {
return tree[node].max_val;
}
push_down(node, start, end);
int mid = (start + end) / 2;
ll left_max = query_max(2 * node, start, mid, l, r);
ll right_max = query_max(2 * node + 1, mid + 1, end, l, r);
return max(left_max, right_max);
}
public:
SegmentTree(const vector<ll>& a) {
arr = a;
n = arr.size();
tree.resize(4 * n);
build(1, 1, n);
}
void set_range(int l, int r, ll x) {
update_set(1, 1, n, l, r, x);
}
void add_range(int l, int r, ll x) {
update_add(1, 1, n, l, r, x);
}
ll get_max(int l, int r) {
return query_max(1, 1, n, l, r);
}
};
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n, q;
cin >> n >> q;
vector<ll> a(n);
for (int i = 0; i < n; ++i) {
cin >> a[i];
}
SegmentTree st(a);
for (int i = 0; i < q; ++i) {
int op;
cin >> op;
if (op == 1) {
int l, r;
ll x;
cin >> l >> r >> x;
st.set_range(l, r, x);
} else if (op == 2) {
int l, r;
ll x;
cin >> l >> r >> x;
st.add_range(l, r, x);
} else if (op == 3) {
int l, r;
cin >> l >> r;
cout << st.get_max(l, r) << endl;
}
}
return 0;
}