Skip to content

Commit 7a341b2

Browse files
Merge pull request #67 from jonathanrocher/fix/stage6_models
Fix: stage6 models
2 parents e1e2038 + 8ee7d38 commit 7a341b2

10 files changed

Lines changed: 138 additions & 124 deletions

File tree

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
# General imports
2-
import os
2+
from os.path import splitext
33
import PIL.Image
44
from PIL.ExifTags import TAGS
55
from skimage import data
66
from skimage.feature import Cascade
77
import numpy as np
88

99
# ETS imports
10-
from traits.api import Array, cached_property, Dict, File, HasStrictTraits, \
11-
List, Property
10+
from traits.api import (
11+
Array, cached_property, Dict, File, HasStrictTraits, List, Property
12+
)
1213

13-
SUPPORTED_FORMATS = [".png", ".jpg", ".jpeg"]
14+
SUPPORTED_FORMATS = [".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG"]
1415

1516

1617
class ImageFile(HasStrictTraits):
@@ -24,43 +25,40 @@ class ImageFile(HasStrictTraits):
2425

2526
faces = List
2627

27-
def to_array(self):
28-
file_ext = os.path.splitext(self.filepath)[1].lower()
29-
if not self.filepath or file_ext not in SUPPORTED_FORMATS:
30-
return np.array([])
31-
32-
with PIL.Image.open(self.filepath) as img:
33-
return np.asarray(img)
28+
def _is_valid_file(self):
29+
return (
30+
bool(self.filepath) and
31+
splitext(self.filepath)[1].lower() in SUPPORTED_FORMATS
32+
)
3433

3534
@cached_property
3635
def _get_data(self):
37-
return self.to_array()
36+
if not self._is_valid_file():
37+
return np.array([])
38+
with PIL.Image.open(self.filepath) as img:
39+
return np.asarray(img)
3840

3941
@cached_property
4042
def _get_metadata(self):
41-
file_ext = os.path.splitext(self.filepath)[1].lower()
42-
if not self.filepath or file_ext not in SUPPORTED_FORMATS:
43+
if not self._is_valid_file():
4344
return {}
44-
4545
with PIL.Image.open(self.filepath) as img:
4646
exif = img._getexif()
47-
48-
if exif:
49-
return {TAGS[k]: v for k, v in exif.items()
50-
if k in TAGS}
51-
else:
47+
if not exif:
5248
return {}
49+
return {TAGS[k]: v for k, v in exif.items() if k in TAGS}
5350

54-
def detect_faces(self, scale_factor=1.2, step_ratio=1, min_size=60,
55-
max_size=600):
56-
""" Detect faces in the image.
57-
"""
51+
def detect_faces(self):
52+
# Load the trained file from the module root.
5853
trained_file = data.lbp_frontal_face_cascade_filename()
54+
55+
# Initialize the detector cascade.
5956
detector = Cascade(trained_file)
60-
faces = detector.detect_multi_scale(img=self.data,
61-
scale_factor=scale_factor,
62-
step_ratio=step_ratio,
63-
min_size=(min_size, min_size),
64-
max_size=(max_size, max_size))
65-
self.faces = faces
66-
return faces
57+
58+
detected = detector.detect_multi_scale(img=self.data,
59+
scale_factor=1.2,
60+
step_ratio=1,
61+
min_size=(60, 60),
62+
max_size=(600, 600))
63+
self.faces = detected
64+
return self.faces
Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
11
# General imports
2-
import os
2+
import glob
3+
from os.path import basename, expanduser, isdir
34

4-
import pandas as pd
55
import numpy as np
6+
import pandas as pd
67

78
# ETS imports
8-
from traits.api import Directory, Event, HasStrictTraits, Instance
9+
from traits.api import (
10+
Directory, Event, HasStrictTraits, Instance, List, observe,
11+
)
912

1013
# Local imports
11-
from .image_file import ImageFile, SUPPORTED_FORMATS
14+
from pycasa.model.image_file import ImageFile, SUPPORTED_FORMATS
1215

1316
FILENAME_COL = "filename"
14-
1517
NUM_FACE_COL = "Num. faces"
1618

1719

1820
class ImageFolder(HasStrictTraits):
19-
""" Model to hold an image folder.
21+
""" Model for a folder of images.
2022
"""
21-
path = Directory
23+
directory = Directory(expanduser("~"))
24+
25+
images = List(Instance(ImageFile))
2226

2327
data = Instance(pd.DataFrame)
2428

@@ -27,39 +31,39 @@ class ImageFolder(HasStrictTraits):
2731
def __init__(self, **traits):
2832
# Don't forget this!
2933
super(ImageFolder, self).__init__(**traits)
30-
if not os.path.isdir(self.path):
31-
msg = f"Unable to create an ImageFolder from {self.path} since" \
32-
f" it is not a valid directory."
34+
if not isdir(self.directory):
35+
msg = f"The provided directory isn't a real directory: " \
36+
f"{self.directory}"
3337
raise ValueError(msg)
38+
self.data = self._create_metadata_df()
3439

35-
self.data = self.to_dataframe()
40+
@observe("directory")
41+
def _update_images(self, event):
42+
self.images = [
43+
ImageFile(filepath=file)
44+
for fmt in SUPPORTED_FORMATS
45+
for file in glob.glob(f"{self.directory}/*{fmt}")
46+
]
3647

37-
def to_dataframe(self):
38-
if not self.path:
39-
return pd.DataFrame({FILENAME_COL: [], NUM_FACE_COL: []})
48+
@observe("images.items")
49+
def _update_metadata(self, event):
50+
self.data = self._create_metadata_df()
4051

41-
data = []
42-
for filename in os.listdir(self.path):
43-
file_ext = os.path.splitext(filename)[1].lower()
44-
if file_ext in SUPPORTED_FORMATS:
45-
filepath = os.path.join(self.path, filename)
46-
img_file = ImageFile(filepath=filepath)
47-
file_data = {FILENAME_COL: filename, NUM_FACE_COL: np.nan}
48-
try:
49-
file_data.update(img_file.metadata)
50-
except Exception:
51-
pass
52-
data.append(file_data)
52+
def _create_metadata_df(self):
53+
if not self.images:
54+
return pd.DataFrame({FILENAME_COL: [], NUM_FACE_COL: []})
55+
return pd.DataFrame([
56+
{
57+
FILENAME_COL: basename(img.filepath),
58+
NUM_FACE_COL: np.nan,
59+
**img.metadata
5360

54-
return pd.DataFrame(data)
61+
}
62+
for img in self.images
63+
])
5564

5665
def compute_num_faces(self, **kwargs):
57-
cols = list(self.data.columns)
58-
for i, filename in enumerate(self.data[FILENAME_COL]):
59-
print(filename)
60-
filepath = os.path.join(self.path, filename)
61-
img_file = ImageFile(filepath=filepath)
62-
faces = img_file.detect_faces(**kwargs)
63-
j = cols.index(NUM_FACE_COL)
64-
self.data.iloc[i, j] = len(faces)
66+
for i, image in enumerate(self.images):
67+
faces = image.detect_faces(**kwargs)
68+
self.data[NUM_FACE_COL].iat[i] = len(faces)
6569
self.data_updated = True

stage6_branded_application/pycasa/model/tests/test_image_file.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,28 @@ class TestImageFile(TestCase):
1818
def test_no_image_file(self):
1919
img = ImageFile()
2020
self.assertEqual(img.metadata, {})
21-
data = img.to_array()
22-
self.assertIsInstance(data, np.ndarray)
23-
self.assertEqual(data.shape, (0,))
21+
self.assertIsInstance(img.data, np.ndarray)
22+
self.assertEqual(img.data.shape, (0,))
2423

2524
def test_bad_type_image_file(self):
2625
img = ImageFile(filepath=__file__)
2726
self.assertEqual(img.metadata, {})
28-
data = img.to_array()
29-
self.assertIsInstance(data, np.ndarray)
30-
self.assertEqual(data.shape, (0,))
27+
self.assertIsInstance(img.data, np.ndarray)
28+
self.assertEqual(img.data.shape, (0,))
3129

3230
def test_image_metadata(self):
3331
img = ImageFile(filepath=SAMPLE_IMG1)
3432
self.assertNotEqual(img.metadata, {})
3533
for key in ['ExifVersion', 'ExifImageWidth', 'ExifImageHeight']:
3634
self.assertIn(key, img.metadata.keys())
37-
data = img.to_array()
3835
expected_shape = (img.metadata['ExifImageHeight'],
3936
img.metadata['ExifImageWidth'], 3)
40-
self.assertEqual(data.shape, expected_shape)
37+
self.assertEqual(img.data.shape, expected_shape)
4138

4239
def test_image_data(self):
4340
img = ImageFile(filepath=SAMPLE_IMG1)
4441
self.assertNotIn(0, img.data.shape)
45-
np.testing.assert_almost_equal(img.data, img.to_array())
42+
np.testing.assert_almost_equal(img.data, img.data)
4643
self.assertIsInstance(img.data, np.ndarray)
4744
self.assertNotEqual(img.data.mean(), 0)
4845

stage6_branded_application/pycasa/model/tests/test_image_folder.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,19 @@
1717
class TestImageFolder(TestCase):
1818
def test_no_folder(self):
1919
with self.assertRaises(ValueError):
20-
ImageFolder()
20+
ImageFolder(directory="path/to/nonexistent/dir")
2121

2222
def test_with_file(self):
2323
with self.assertRaises(ValueError):
24-
ImageFolder(path=__file__)
24+
ImageFolder(directory=__file__)
2525

2626
def test_empty_folder(self):
27-
img = ImageFolder(path=HERE)
28-
data = img.to_dataframe()
29-
self.assertIsInstance(data, pd.DataFrame)
30-
self.assertEqual(len(data), 0)
27+
img_folder = ImageFolder(directory=HERE)
28+
self.assertIsInstance(img_folder.data, pd.DataFrame)
29+
self.assertEqual(len(img_folder.data), 0)
3130

3231
def test_real_folder(self):
33-
img = ImageFolder(path=SAMPLE_IMG_DIR)
34-
data = img.to_dataframe()
35-
self.assertEqual(len(data), 2)
32+
img_folder = ImageFolder(directory=SAMPLE_IMG_DIR)
33+
self.assertEqual(len(img_folder.data), 2)
3634
for key in ['ExifVersion', 'ExifImageWidth', 'ExifImageHeight']:
37-
self.assertIn(key, data.columns)
35+
self.assertIn(key, img_folder.data.columns)

stage6_branded_application/pycasa/ui/image_file_view.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,18 @@ class ImageFileView(ModelView):
3434
def build_mpl_figure(self, event):
3535
figure = Figure()
3636
axes = figure.add_subplot(111)
37-
axes.imshow(self.model.to_array())
37+
axes.imshow(self.model.data)
3838
self.figure = figure
3939

40-
def _detect_button_fired(self):
40+
@observe("detect_button")
41+
def _detect_button_fired(self, event):
4142
self.model.detect_faces()
4243

4344
@observe("model.faces")
4445
def update_mpl_figure_with_faces(self, events):
4546
figure = Figure()
4647
axes = figure.add_subplot(111)
47-
axes.imshow(self.model.to_array())
48+
axes.imshow(self.model.data)
4849

4950
for face in self.model.faces:
5051
axes.add_patch(

stage6_branded_application/pycasa/ui/image_folder_editor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def create(self, parent):
3030
# -------------------------------------------------------------------------
3131

3232
def _get_name(self):
33-
return self.obj.path[:25]
33+
return self.obj.directory[:25]
3434

3535
def _get_tooltip(self):
36-
return self.obj.path
36+
return self.obj.directory

0 commit comments

Comments
 (0)