题目链接
/*
这个方案的核心思想是通过计算不同部分连续字符对最终价值的贡献来得到最大的价值。
以下是这个实现的核心思路:
首先,使用 count 数组来记录字符 '0' 和 '1' 的个数。
对于每个字符 '0' 和 '1',都分别计算以其为分界的情况下的价值。
对于以字符 '0' 为分界的情况:
1. 找到第一个字符 '0' 出现的位置 start 和最后一个字符 '0' 出现的位置 end。
2. 计算连续字符 '0' 所对应的价值:(1 + count[c - '0']) * count[c - '0'] / 2,其中 count[c - '0'] 是字符 '0' 的个数。也就是将 '0' 之间的所有 ‘1’ 逻辑上删除。
3. 计算字符 '0' 之前连续字符 '1' 所对应的价值:(1 + start) * start / 2。
4. 计算字符 '0' 之后连续字符 '1' 所对应的价值:(s.length() - end) * (s.length() - end - 1L) / 2,这里要注意长度减1。
同样的,对于以字符 '1' 为分界的情况,重复上述步骤。
最后,从这两种情况中选择较大的一个作为最终的最大价值。
*/
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
String s = scanner.nextLine();
scanner.close();
int[] count = new int[2];
for (char c : s.toCharArray()) {
count[c - '0']++;
}
System.out.println(Math.max(getValue('0', s, count), getValue('1', s, count)));
}
public static long getValue(char c, String s, int[] count) {
int start = s.indexOf(c);
int end = s.lastIndexOf(c);
long num = (1L + count[c - '0']) * count[c - '0'] / 2;
num += (1L + start) * start / 2;
num += (s.length() - end) * (s.length() - end - 1L) / 2;
return num;
}
}
###################
# 第34题
###################
#
# 给定一个字符串 S(S.lenth < 5000),只包含 0 和 1 两个数字,下标从 1 开始,设第 i 位的价值为 vali,则vali的定义如下:
# 当 i = 1 时:val1 = 1;
# 当 i > 1 时:
# 若Si != S(i - 1):vali = 1,
# 若Si == S(i - 1):vali = val(i - 1) + 1;
# 字符串 S 的总价值等于 val1 + val2 + ... + valn(n为字符串 S 的长度)。
# 你可以任意删除字符串 S 中的任意字符,使得字符串 S 的总价值能够达到最大。
#
# 想利用一个状态转移数组用语计算,但是O(n2)超时了
# def cal_val(current, p, value):
# if current != p:
# return 1
# else:
# return value + 1
# def dp(nums):
# dp = [0]*len(nums)
# dp[0] = nums[0]
# for i in range(1, len(nums)):
# for k in range(0, i):
# dp[i] = max(dp[i], cal_val(current=nums[i], p=nums[k], value=dp[k])+dp[k])
# return max(dp)
# 参考题解java做法,因为题目中给出的条件对于连续的串奖励是非常高的,如果第三个连续则加3,因此我们鼓励更长的连续串,
# 那么只存在两种可能,最长的1串和最长的0串,我们分别计算然后计算最长串两侧的字串长度即可
def find_first_last(nums, p):
first_index = None
last_index = None
n = 0
for i in range(len(nums)):
if nums[i] == p:
if first_index is None:
first_index = i
last_index = i
n += 1
return first_index, last_index, n
def cal_num(nums, first, last, p):
past = 0
next = 0
if first == None:
return 0, 0
for i in range(0, first):
if nums[i] == p:
past += 1
for i in range(last, len(nums)):
if nums[i] == p:
next += 1
return past, next
def cal_sum(n):
return int(n*(n+1)/2)
def cal(nums):
first_0, last_0, num_0 = find_first_last(nums=nums, p=0)
first_1, last_1, num_1 = find_first_last(nums=nums, p=1)
past_0, next_0 = cal_num(nums=nums, first=first_0, last=last_0, p=1)
past_1, next_1 = cal_num(nums=nums, first=first_1, last=last_1, p=0)
sum_0 = cal_sum(num_0)+cal_sum(past_0)+cal_sum(next_0)
sum_1 = cal_sum(num_1)+cal_sum(past_1)+cal_sum(next_1)
return max(sum_0, sum_1)
s = input()
nums = []
for i in s:
try:
nums.append(int(i))
except:
print("含有非法字符")
print(cal(nums=nums))