1901 - Median Heaps
De la Universitas MediaWiki
Cerința
Se dă un vector de N
numere naturale nenule, indexat de la 1
.
Se cere să se raspundă la Q
interogări de tipul:
- pentru un interval
[l, r]
din vector, aflați costul total mimin, al egalizării tuturor elementelor din interval.
Într-un interval [l, r]
, puteți crește sau micșora fiecare element cu costul x
unde x
este diferența dintre valoarea nouă și valoarea inițială. Costul total este suma acestor costuri.
Date de intrare
- pe prima linie numărul
N
. - pe a doua linie
N
numere naturale nenule : , , … . - pe a treia linie numărul
Q
. - pe următoarele
Q
linii2
numere:l, r
.
Date de ieșire
Se vor afișa Q
numere pe fiecare linie, reprezentând constul total minim, al fiecărui interval l r
.
Restricții și precizări
1 ≤ N, Q ≤ 100.000
1 ≤ l ≤ r ≤ 100.000
1 ≤
≤ 1.000.000.000
Exemplu:
Intrare
9 3 2 16 15 12 6 20 4 5 3 2 7 2 2 3 8
Ieșire
31 0 29
Explicație
Pentru intervalul [2, 7]
, costul total minim este 31
, deoarece egalizăm fiecare număr din interval cu 12
.
Pentru intervalul [2, 2]
, costul total minim este 0
, deoarece avem un singur element în interval.
Pentru intervalul [3, 8]
, costul total minim este 29
, deoarece egalizăm fiecare număr din interval cu 12
Rezolvare
import math
Nmax = 100001
class Query:
def __init__(self, st, dr, pos):
self.st = st
self.dr = dr
self.pos = pos
def update(aib, aibs, pos, val1, val2):
while pos < Nmax:
aib[pos] += val1
aibs[pos] += val2
pos += pos & -pos
def query(aib, aibs, pos, c):
sum_val = 0
s = 0
while pos > 0:
sum_val += aib[pos]
if aibs is not None:
s += aibs[pos]
pos -= pos & -pos
return sum_val if c == 1 else s
def add(k, N, V, S, aib, aibs):
update(aib, aibs, N[k], 1, V[S[N[k]]])
def delete(k, N, V, S, aib, aibs):
update(aib, aibs, N[k], -1, -V[S[N[k]]])
def binary_search(pos, n, aib):
st, dr = 1, n
ans = -1
while st <= dr:
mid = (st + dr) // 2
q = query(aib, None, mid, 1)
if q >= pos:
dr = mid - 1
ans = mid
else:
st = mid + 1
return ans
def verify_restrictions(n, q, ranges):
if not (1 <= n <= 100000 and 1 <= q <= 100000):
return False
for l, r in ranges:
if not (1 <= l <= r <= 100000):
return False
return True
def main():
n = int(input())
V = [0] * (n + 1)
values = list(map(int, input().split()))
for i in range(1, n + 1):
V[i] = values[i-1]
q = int(input())
ranges = []
Q = []
for i in range(1, q + 1):
st, dr = map(int, input().split())
ranges.append((st, dr))
Q.append(Query(st, dr, i))
if not verify_restrictions(n, q, ranges):
print("Datele nu corespund restrictiilor impuse")
return
block = int(math.sqrt(n))
S = list(range(n + 1))
S[1:] = sorted(S[1:], key=lambda x: V[x])
N = [0] * (n + 1)
for i in range(1, n + 1):
N[S[i]] = i
Q.sort(key=lambda x: (x.dr // block, x.st if x.dr // block % 2 == 0 else -x.st))
aib = [0] * Nmax
aibs = [0] * Nmax
Rez = [0] * (q + 1)
st, dr = 1, 0
for i in range(1, q + 1):
s, d = Q[i-1].st, Q[i-1].dr
while st < s:
delete(st, N, V, S, aib, aibs)
st += 1
while st > s:
st -= 1
add(st, N, V, S, aib, aibs)
while dr < d:
dr += 1
add(dr, N, V, S, aib, aibs)
while dr > d:
delete(dr, N, V, S, aib, aibs)
dr -= 1
poss = (d - s + 2) // 2
ans = binary_search(poss, n, aib)
s1 = query(aib, aibs, n, 2)
h1 = query(aib, None, n, 1)
s2 = query(aib, aibs, ans, 2)
h2 = query(aib, None, ans, 1)
s1 -= s2
h1 -= h2
Rez[Q[i-1].pos] = s1 - h1 * V[S[ans]] + (-s2 + h2 * V[S[ans]])
for i in range(1, q + 1):
print(Rez[i])
if __name__ == "__main__":
main()