YuXin_Liu/图像分割/segment/generate_contour.py

117 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
# # 创建一个空白图像，假设大小为100x100
# height, width = 100, 100
# segmentation_result = np.zeros((height, width), dtype=np.uint8)
# # 创建一些分割区域
# segmentation_result[10:30, 10:30] = 50 # 区域1
# segmentation_result[40:60, 40:60] = 100 # 区域2
# segmentation_result[70:90, 70:90] = 150 # 区域3
# segmentation_result[20:50, 70:90] = 200 # 区域4
# # 保存图像
# cv2.imwrite('segmentation_result.png', segmentation_result)
# segmentation_result = cv2.imread('segmentation_result.png', cv2.IMREAD_GRAYSCALE)
# height, width = segmentation_result.shape
# fig = plt.figure()
# ax = fig.add_subplot(111, projection='3d')
# unique_labels = np.unique(segmentation_result)
# heights = np.linspace(1, 10, len(unique_labels)) # 高度范围从1到10，可以根据需要调整
# for i, label in enumerate(unique_labels):
# mask = (segmentation_result == label)
# x, y = np.meshgrid(np.arange(width), np.arange(height))
# x = x[mask]
# y = y[mask]
# z = np.zeros_like(x)
# dz = np.full_like(x, heights[i])
# ax.bar3d(x, y, z, 1, 1, dz, shade=True)
# ax.set_xlabel('X axis')
# ax.set_ylabel('Y axis')
# ax.set_zlabel('Height')
# plt.show()
def _classify_pixels(image):
    """Return a uint8 label mask for an (H, W, 3) image, one label value per pixel.

    The three channel values of each pixel are summed (the channel order does
    not matter for a sum, so BGR input from OpenCV is fine) and bucketed:
    50 for sum <= 0, 100 for sum <= 120, 150 for sum < 200, else 200.
    """
    # Widen before summing: three uint8 channels can total 765, which would wrap in uint8.
    channel_sum = image.astype(np.int32).sum(axis=2)
    # np.select takes the first matching condition, mirroring an if/elif chain.
    conditions = [channel_sum <= 0, channel_sum <= 120, channel_sum < 200]
    labels = [50, 100, 150]
    return np.select(conditions, labels, default=200).astype(np.uint8)


def _plot_categories(segmentation_result):
    """Draw a label mask as a 3-D bar chart with one unit bar per pixel and show it."""
    height, width = segmentation_result.shape
    # Bar height per label value. Keyed by label (not by position in np.unique)
    # so that a category absent from the image does not shift the others' heights.
    bar_heights = {50: 10, 100: 5, 150: 3, 200: 1}
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # The pixel coordinate grids are identical for every label, so build them once.
    grid_x, grid_y = np.meshgrid(np.arange(width), np.arange(height))
    for label in np.unique(segmentation_result):
        mask = segmentation_result == label
        x = grid_x[mask]
        y = grid_y[mask]
        z = np.zeros_like(x)
        dz = np.full_like(x, bar_heights[int(label)])
        ax.bar3d(x, y, z, 1, 1, dz, shade=True)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('Height')
    plt.show()


def categorize_pixels(image_path, output_dir):
    """Bucket an image's pixels by channel-sum brightness and plot them in 3-D.

    Each pixel's three channel values are summed and mapped to a label value:
    50 (sum == 0, pure black), 100 (sum <= 120), 150 (sum < 200) or 200
    (otherwise). The single-channel label mask is saved as ``deeptmp.png``
    inside ``output_dir`` and then displayed as a 3-D bar chart in which each
    label gets a fixed bar height (50 -> 10, 100 -> 5, 150 -> 3, 200 -> 1).

    Args:
        image_path: Path of the image to load with ``cv2.imread`` (colour mode).
        output_dir: Directory that receives ``deeptmp.png``; created if missing.
            An empty string means the current working directory.

    Returns:
        None. If the image cannot be read, an error is printed and the
        function returns without writing or plotting anything.
    """
    image = cv2.imread(image_path)
    if image is None:
        print("Error: Unable to read image.")
        return
    # os.makedirs("") raises, so only create a directory when one was named.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    categories = _classify_pixels(image)
    # Save the mask into output_dir. The original always wrote to the CWD,
    # which ignored the output_dir argument.
    mask_path = os.path.join(output_dir, "deeptmp.png")
    if not cv2.imwrite(mask_path, categories):
        print(f"Warning: failed to save label mask to {mask_path}")
    print("正在生成三维图...")
    # Plot the in-memory mask directly. Re-reading the PNG could pick up a
    # stale file if the write above failed.
    _plot_categories(categories)
if __name__ == "__main__":
    # Script entry point: categorize the demo image and report where output goes.
    results_dir = "./"
    input_image = "./fill.png"  # replace with the path of your image
    categorize_pixels(input_image, results_dir)
    print(f"Pixel categorization complete. Results saved in {results_dir}")