1901 - Median Heaps

From Bitnami MediaWiki

Cerința[edit | edit source]

Se dă un vector de N numere naturale nenule, indexat de la 1.

Se cere să se raspundă la Q interogări de tipul:

  • pentru un interval [l, r] din vector, aflați costul total mimin, al egalizării tuturor elementelor din interval.

Într-un interval [l, r], puteți crește sau micșora fiecare element cu costul x unde x este diferența dintre valoarea nouă și valoarea inițială. Costul total este suma acestor costuri.

Date de intrare[edit | edit source]

  • pe prima linie numărul N.
  • pe a doua linie N numere naturale nenule : , , … .
  • pe a treia linie numărul Q.
  • pe următoarele Q linii 2 numere: l, r.

Date de ieșire[edit | edit source]

Se vor afișa Q numere pe fiecare linie, reprezentând constul total minim, al fiecărui interval l r.

Restricții și precizări[edit | edit source]

  • 1 ≤ N, Q ≤ 100.000
  • 1 ≤ l ≤ r ≤ 100.000
  • 1 ≤  ≤ 1.000.000.000

Exemplu:[edit | edit source]

Intrare

9
3 2 16 15 12 6 20 4 5
3
2 7
2 2
3 8

Ieșire

31
0
29

Explicație[edit | edit source]

Pentru intervalul [2, 7], costul total minim este 31, deoarece egalizăm fiecare număr din interval cu 12.

Pentru intervalul [2, 2], costul total minim este 0, deoarece avem un singur element în interval.

Pentru intervalul [3, 8], costul total minim este 29, deoarece egalizăm fiecare număr din interval cu 12

Rezolvare[edit | edit source]

<syntaxhighlight lang="python" line="1"> import math

Nmax = 100001

class Query:

   def __init__(self, st, dr, pos):
       self.st = st
       self.dr = dr
       self.pos = pos

def update(aib, aibs, pos, val1, val2):

   while pos < Nmax:
       aib[pos] += val1
       aibs[pos] += val2
       pos += pos & -pos

def query(aib, aibs, pos, c):

   sum_val = 0
   s = 0
   while pos > 0:
       sum_val += aib[pos]
       if aibs is not None:
           s += aibs[pos]
       pos -= pos & -pos
   return sum_val if c == 1 else s

def add(k, N, V, S, aib, aibs):

   update(aib, aibs, N[k], 1, V[S[N[k]]])

def delete(k, N, V, S, aib, aibs):

   update(aib, aibs, N[k], -1, -V[S[N[k]]])

def binary_search(pos, n, aib):

   st, dr = 1, n
   ans = -1
   while st <= dr:
       mid = (st + dr) // 2
       q = query(aib, None, mid, 1)
       if q >= pos:
           dr = mid - 1
           ans = mid
       else:
           st = mid + 1
   return ans

def verify_restrictions(n, q, ranges):

   if not (1 <= n <= 100000 and 1 <= q <= 100000):
       return False
   for l, r in ranges:
       if not (1 <= l <= r <= 100000):
           return False
   return True

def main():

   n = int(input())
   V = [0] * (n + 1)
   
   values = list(map(int, input().split()))
   for i in range(1, n + 1):
       V[i] = values[i-1]
   
   q = int(input())
   ranges = []
   Q = []
   for i in range(1, q + 1):
       st, dr = map(int, input().split())
       ranges.append((st, dr))
       Q.append(Query(st, dr, i))
   if not verify_restrictions(n, q, ranges):
       print("Datele nu corespund restrictiilor impuse")
       return
   block = int(math.sqrt(n))
   S = list(range(n + 1))
   S[1:] = sorted(S[1:], key=lambda x: V[x])
   N = [0] * (n + 1)
   for i in range(1, n + 1):
       N[S[i]] = i
   Q.sort(key=lambda x: (x.dr // block, x.st if x.dr // block % 2 == 0 else -x.st))
   aib = [0] * Nmax
   aibs = [0] * Nmax
   Rez = [0] * (q + 1)
   
   st, dr = 1, 0
   for i in range(1, q + 1):
       s, d = Q[i-1].st, Q[i-1].dr
       while st < s:
           delete(st, N, V, S, aib, aibs)
           st += 1
       while st > s:
           st -= 1
           add(st, N, V, S, aib, aibs)
       while dr < d:
           dr += 1
           add(dr, N, V, S, aib, aibs)
       while dr > d:
           delete(dr, N, V, S, aib, aibs)
           dr -= 1
       poss = (d - s + 2) // 2
       ans = binary_search(poss, n, aib)
       s1 = query(aib, aibs, n, 2)
       h1 = query(aib, None, n, 1)
       s2 = query(aib, aibs, ans, 2)
       h2 = query(aib, None, ans, 1)
       s1 -= s2
       h1 -= h2
       Rez[Q[i-1].pos] = s1 - h1 * V[S[ans]] + (-s2 + h2 * V[S[ans]])
   for i in range(1, q + 1):
       print(Rez[i])

if __name__ == "__main__":

   main()

</syntaxhighlight>