找到第N大的数,或者第N小的数。
先排序再nums[n]就是答案,不过nlgn
要求必须n
就得用quick select
其实是quick sort的变种。
quick sort每次选PIVOT,然后分成2边,再分别SORT两边。
这里就没必要了,我们要找第N大的,每次SORT完,看N在PIVOT左边还是右边,然后SORT相应的一边就行。
目的不在于真的SORT这个数组,而是希望PIVOT正好是要找的数。
然后为了避免quicksort中最倒霉的情况,PIVOT必须好好选,可以选L R M 3个值中的中间值。。不能任意选L或者R。
假设已经有P值了,先把P和R交换。
然后记录坑的位置,TEMP = L
从L 遍历到 R-1
nums[L] < pivot就把当前值放坑里,坑的索引++
最后的结果就是坑的索引之前的数,小于pivot
此时nums[R]放的是PIVOT,和坑再交换一下
此时数列按照nums[temp]左右分了
如果temp == nums.length - k 第N大
就算找到了 return nums[temp]
temp < nums.length - k 说明要找的在TEMP右边
那么 quicksort(temp+1,R)
否则 quicksort(L,temp-1)
最后就找到了
public int findKthLargest(int[] nums, int k)
{
if(nums.length == 0) return -1;
if(nums.length == 1) return nums[0];
return quickSelect(nums,0,nums.length-1,k);
}
public int quickSelect(int[] nums, int L, int R, int K)
{
int pivotIndex = generatePivot(nums,L,R);
int P = nums[pivotIndex];
swap(nums,pivotIndex,R);
int temp = L;
for(int n = L;n < R;n++)
{
if(nums[n] < P)
{
swap(nums,n,temp);
temp++;
}
}
swap(nums,R,temp);
if(temp == nums.length-K) return nums[temp];
else if( temp < nums.length-K) return quickSelect(nums,temp+1,R,K);
else return quickSelect(nums,L,temp-1,K);
}
public void swap(int[] nums, int L, int R)
{
//System.out.println(L);
int temp = nums[L];
nums[L] = nums[R];
nums[R] = temp;
}
public int generatePivot(int[] nums,int L, int R)
{
int M = (L+R)/2;
if(nums[L] < nums[R])
{
if(nums[R] < nums[M]) return R;
else return nums[L] > nums[M] ? L : M;
}
else //R < L
{
if(nums[L] < nums[M]) return L;
else return nums[M] > nums[R] ? M : R;
}
}
这个思路很重要,后面的wiggle还要用
二刷。
这个题还有印象,用的quick select,quick sort的一部分。
时间上是需要证明为什么是O(n)。。
最坏的情况是pivot每次选极端值,最后就是n2,这个不用说了。
最好的情况是我们每次都选的是中间值,那最终结果就是:
n/2 + n/4 + n/8 + .. + n/n = n - 1
就是O(n)。
为了保证我们取值合适,得用适当的方式来选取pivot,而不是随机从里面区间抓一个。
public class Solution {
public int findKthLargest(int[] nums, int k) {
if (nums.length == 1) return nums[0];
return quickSelect(0,nums.length-1,k,nums);
}
public int quickSelect(int l, int r, int k, int[] nums) {
int m = getMid(l,r,nums);
int target = nums[m];
swap(nums,r,m);
int left = l;
int i = l;
while (i < r) {
if (nums[i] < target) {
swap(nums, i, left);
left++;
}
i++;
}
swap(nums,left,r);
if (left == nums.length - k) {
return nums[left];
} else if (left > nums.length - k) {
return quickSelect(l, left-1, k, nums);
} else {
return quickSelect(left+1, r, k, nums);
}
}
public int getMid(int l, int r, int[] nums) {
int a = nums[l];
int b = nums[r];
int m = l + (r - l) / 2;
int c = nums[m];
if (a > b) {
if (b > c) return r;
else return a > c? m : l;
} else {
if (a > c) return l;
else return b > c? m: r;
}
}
public void swap(int[] nums, int a, int b) {
int temp = nums[a];
nums[a] = nums[b];
nums[b] = temp;
}
}
看讨论区发现自己发的POST,我他妈二刷居然不如一刷写的好。。。
https://discuss.leetcode.com/topic/55501/2ms-java-quick-select-only-2-points-to-mention
三刷。
首先用正常的PQ来做,维持大小为K的minHeap,所有元素往里面加,多了就POLL出去。。
最后顶上的就是要求的。
Time: O(n lgk)
Space: O(k)
public class Solution {
public int findKthLargest(int[] nums, int k) {
if (nums.length == 1) return nums[0];
PriorityQueue<Integer> pq = new PriorityQueue<Integer>(k);
for (int i : nums) {
pq.offer(i);
if (pq.size() > k) {
pq.poll();
}
}
return pq.poll();
}
}
然后quick select, 基本想法是,只继续SORT可能存在K的那半部分。
每次“尽量"选一个合适的pivot值,然后进行partition,这个地方刚卡了一下= = 记录的应该是小于pivot的而不是大于pivot的。。
Time: O(n) average.. O(n²) worst..
public class Solution {
public int findKthLargest(int[] nums, int k) {
if (nums.length == 1) return nums[0];
return quickSelect(nums, k, 0, nums.length - 1);
}
public int quickSelect(int[] nums, int k, int l, int r) {
int m = betterMid(nums, l ,r);
int pivot = nums[m];
swap(nums, m, r);
int smaller = l;
for (int i = l; i < r; i++) {
if (nums[i] < pivot) {
swap(nums, i, smaller ++);
}
}
swap(nums, smaller, r);
if (smaller + k == nums.length) {
return nums[smaller];
} else if (smaller + k > nums.length) {
return quickSelect(nums, k, l, smaller - 1);
} else {
return quickSelect(nums, k, smaller + 1, r);
}
}
public void swap(int[] nums, int indexA, int indexB) {
int temp = nums[indexA];
nums[indexA] = nums[indexB];
nums[indexB] = temp;
}
public int betterMid(int[] nums, int l, int r) {
int m = l + (r - l) / 2;
if (nums[l] > nums[r]) {
if (nums[m] > nums[r]) {
return nums[l] > nums[m] ? m : l;
} else {
return r;
}
} else {
if (nums[m] > nums[l]) {
return nums[r] > nums[m] ? m : r;
} else {
return l;
}
}
}
}