1901 - Median Heaps

De la Universitas MediaWiki
Versiunea pentru tipărire nu mai este suportată și poate avea erori de randare. Vă rugăm să vă actualizați bookmarkurile browserului și să folosiți funcția implicită de tipărire a browserului.

Cerința

Se dă un vector de N numere naturale nenule, indexat de la 1.

Se cere să se raspundă la Q interogări de tipul:

  • pentru un interval [l, r] din vector, aflați costul total mimin, al egalizării tuturor elementelor din interval.

Într-un interval [l, r], puteți crește sau micșora fiecare element cu costul x unde x este diferența dintre valoarea nouă și valoarea inițială. Costul total este suma acestor costuri.

Date de intrare

  • pe prima linie numărul N.
  • pe a doua linie N numere naturale nenule : , , … .
  • pe a treia linie numărul Q.
  • pe următoarele Q linii 2 numere: l, r.

Date de ieșire

Se vor afișa Q numere pe fiecare linie, reprezentând constul total minim, al fiecărui interval l r.

Restricții și precizări

  • 1 ≤ N, Q ≤ 100.000
  • 1 ≤ l ≤ r ≤ 100.000
  • 1 ≤  ≤ 1.000.000.000

Exemplu:

Intrare

9
3 2 16 15 12 6 20 4 5
3
2 7
2 2
3 8

Ieșire

31
0
29

Explicație

Pentru intervalul [2, 7], costul total minim este 31, deoarece egalizăm fiecare număr din interval cu 12.

Pentru intervalul [2, 2], costul total minim este 0, deoarece avem un singur element în interval.

Pentru intervalul [3, 8], costul total minim este 29, deoarece egalizăm fiecare număr din interval cu 12

Rezolvare

import math

Nmax = 100001

class Query:
    def __init__(self, st, dr, pos):
        self.st = st
        self.dr = dr
        self.pos = pos

def update(aib, aibs, pos, val1, val2):
    while pos < Nmax:
        aib[pos] += val1
        aibs[pos] += val2
        pos += pos & -pos

def query(aib, aibs, pos, c):
    sum_val = 0
    s = 0
    while pos > 0:
        sum_val += aib[pos]
        if aibs is not None:
            s += aibs[pos]
        pos -= pos & -pos
    return sum_val if c == 1 else s

def add(k, N, V, S, aib, aibs):
    update(aib, aibs, N[k], 1, V[S[N[k]]])

def delete(k, N, V, S, aib, aibs):
    update(aib, aibs, N[k], -1, -V[S[N[k]]])

def binary_search(pos, n, aib):
    st, dr = 1, n
    ans = -1
    while st <= dr:
        mid = (st + dr) // 2
        q = query(aib, None, mid, 1)
        if q >= pos:
            dr = mid - 1
            ans = mid
        else:
            st = mid + 1
    return ans

def verify_restrictions(n, q, ranges):
    if not (1 <= n <= 100000 and 1 <= q <= 100000):
        return False
    for l, r in ranges:
        if not (1 <= l <= r <= 100000):
            return False
    return True

def main():
    n = int(input())
    V = [0] * (n + 1)
    
    values = list(map(int, input().split()))
    for i in range(1, n + 1):
        V[i] = values[i-1]
    
    q = int(input())
    ranges = []
    Q = []
    for i in range(1, q + 1):
        st, dr = map(int, input().split())
        ranges.append((st, dr))
        Q.append(Query(st, dr, i))

    if not verify_restrictions(n, q, ranges):
        print("Datele nu corespund restrictiilor impuse")
        return

    block = int(math.sqrt(n))
    S = list(range(n + 1))
    S[1:] = sorted(S[1:], key=lambda x: V[x])
    N = [0] * (n + 1)
    for i in range(1, n + 1):
        N[S[i]] = i

    Q.sort(key=lambda x: (x.dr // block, x.st if x.dr // block % 2 == 0 else -x.st))
    aib = [0] * Nmax
    aibs = [0] * Nmax
    Rez = [0] * (q + 1)
    
    st, dr = 1, 0
    for i in range(1, q + 1):
        s, d = Q[i-1].st, Q[i-1].dr
        while st < s:
            delete(st, N, V, S, aib, aibs)
            st += 1
        while st > s:
            st -= 1
            add(st, N, V, S, aib, aibs)
        while dr < d:
            dr += 1
            add(dr, N, V, S, aib, aibs)
        while dr > d:
            delete(dr, N, V, S, aib, aibs)
            dr -= 1

        poss = (d - s + 2) // 2
        ans = binary_search(poss, n, aib)
        s1 = query(aib, aibs, n, 2)
        h1 = query(aib, None, n, 1)
        s2 = query(aib, aibs, ans, 2)
        h2 = query(aib, None, ans, 1)
        s1 -= s2
        h1 -= h2

        Rez[Q[i-1].pos] = s1 - h1 * V[S[ans]] + (-s2 + h2 * V[S[ans]])

    for i in range(1, q + 1):
        print(Rez[i])

if __name__ == "__main__":
    main()