Split LabelRepository off the MessageRepository
This commit is contained in:
		| @@ -0,0 +1,123 @@ | ||||
| /* | ||||
|  * Copyright 2017 Christian Basler | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
|  | ||||
| package ch.dissem.bitmessage.repository | ||||
|  | ||||
| import ch.dissem.bitmessage.entity.valueobject.Label | ||||
| import ch.dissem.bitmessage.ports.AbstractLabelRepository | ||||
| import ch.dissem.bitmessage.ports.LabelRepository | ||||
| import org.slf4j.LoggerFactory | ||||
| import java.sql.Connection | ||||
| import java.sql.ResultSet | ||||
| import java.sql.SQLException | ||||
| import java.util.* | ||||
|  | ||||
| class JdbcLabelRepository(private val config: JdbcConfig) : AbstractLabelRepository(), LabelRepository { | ||||
|  | ||||
|     override fun find(where: String): List<Label> { | ||||
|         try { | ||||
|             config.getConnection().use { connection -> | ||||
|                 return findLabels(connection, where) | ||||
|             } | ||||
|         } catch (e: SQLException) { | ||||
|             LOG.error(e.message, e) | ||||
|             return ArrayList() | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     override fun save(label: Label) { | ||||
|         config.getConnection().use { connection -> | ||||
|             if (label.id != null) { | ||||
|                 connection.prepareStatement("UPDATE Label SET label=?, type=?, color=?, ord=? WHERE id=?").use { ps -> | ||||
|                     ps.setString(1, label.toString()) | ||||
|                     ps.setString(2, label.type?.name) | ||||
|                     ps.setInt(3, label.color) | ||||
|                     ps.setInt(4, label.ord) | ||||
|                     ps.setInt(5, label.id as Int) | ||||
|                     ps.executeUpdate() | ||||
|                 } | ||||
|             } else { | ||||
|                 try { | ||||
|                     connection.autoCommit = false | ||||
|                     var exists = false | ||||
|                     connection.prepareStatement("SELECT COUNT(1) FROM Label WHERE label=?").use { ps -> | ||||
|                         ps.setString(1, label.toString()) | ||||
|                         val rs = ps.executeQuery() | ||||
|                         if (rs.next()) { | ||||
|                             exists = rs.getInt(1) > 0 | ||||
|                         } | ||||
|                     } | ||||
|  | ||||
|                     if (exists) { | ||||
|                         connection.prepareStatement("UPDATE Label SET type=?, color=?, ord=? WHERE label=?").use { ps -> | ||||
|                             ps.setString(1, label.type?.name) | ||||
|                             ps.setInt(2, label.color) | ||||
|                             ps.setInt(3, label.ord) | ||||
|                             ps.setString(4, label.toString()) | ||||
|                             ps.executeUpdate() | ||||
|                         } | ||||
|                     } else { | ||||
|                         connection.prepareStatement("INSERT INTO Label (label, type, color, ord) VALUES (?, ?, ?, ?)").use { ps -> | ||||
|                             ps.setString(1, label.toString()) | ||||
|                             ps.setString(2, label.type?.name) | ||||
|                             ps.setInt(3, label.color) | ||||
|                             ps.setInt(4, label.ord) | ||||
|                             ps.executeUpdate() | ||||
|                         } | ||||
|                     } | ||||
|                     connection.commit() | ||||
|                 } catch (e: Exception) { | ||||
|                     connection.rollback() | ||||
|                     throw e | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private fun findLabels(connection: Connection, where: String): List<Label> { | ||||
|         val result = ArrayList<Label>() | ||||
|         try { | ||||
|             connection.createStatement().use { stmt -> | ||||
|                 stmt.executeQuery("SELECT id, label, type, color, ord FROM Label WHERE $where").use { rs -> | ||||
|                     while (rs.next()) { | ||||
|                         result.add(getLabel(rs)) | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } catch (e: SQLException) { | ||||
|             LOG.error(e.message, e) | ||||
|         } | ||||
|  | ||||
|         return result | ||||
|     } | ||||
|  | ||||
|     companion object { | ||||
|         private val LOG = LoggerFactory.getLogger(JdbcLabelRepository::class.java) | ||||
|  | ||||
|         internal fun getLabel(rs: ResultSet): Label { | ||||
|             val typeName = rs.getString("type") | ||||
|             val type = if (typeName == null) { | ||||
|                 null | ||||
|             } else { | ||||
|                 Label.Type.valueOf(typeName) | ||||
|             } | ||||
|             val label = Label(rs.getString("label"), type, rs.getInt("color"), rs.getInt("ord")) | ||||
|             label.id = rs.getLong("id") | ||||
|  | ||||
|             return label | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -16,6 +16,7 @@ | ||||
|  | ||||
| package ch.dissem.bitmessage.repository | ||||
|  | ||||
| import ch.dissem.bitmessage.entity.BitmessageAddress | ||||
| import ch.dissem.bitmessage.entity.Plaintext | ||||
| import ch.dissem.bitmessage.entity.valueobject.InventoryVector | ||||
| import ch.dissem.bitmessage.entity.valueobject.Label | ||||
| @@ -29,79 +30,6 @@ import java.util.* | ||||
|  | ||||
| class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRepository(), MessageRepository { | ||||
|  | ||||
|     override fun findLabels(where: String): List<Label> { | ||||
|         try { | ||||
|             config.getConnection().use { connection -> | ||||
|                 return findLabels(connection, where) | ||||
|             } | ||||
|         } catch (e: SQLException) { | ||||
|             LOG.error(e.message, e) | ||||
|             return ArrayList() | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private fun getLabel(rs: ResultSet): Label { | ||||
|         val typeName = rs.getString("type") | ||||
|         val type = if (typeName == null) { | ||||
|             null | ||||
|         } else { | ||||
|             Label.Type.valueOf(typeName) | ||||
|         } | ||||
|         val label = Label(rs.getString("label"), type, rs.getInt("color"), rs.getInt("ord")) | ||||
|         label.id = rs.getLong("id") | ||||
|  | ||||
|         return label | ||||
|     } | ||||
|  | ||||
|     override fun save(label: Label) { | ||||
|         config.getConnection().use { connection -> | ||||
|             if (label.id != null) { | ||||
|                 connection.prepareStatement("UPDATE Label SET label=?, type=?, color=?, ord=? WHERE id=?").use { ps -> | ||||
|                     ps.setString(1, label.toString()) | ||||
|                     ps.setString(2, label.type?.name) | ||||
|                     ps.setInt(3, label.color) | ||||
|                     ps.setInt(4, label.ord) | ||||
|                     ps.setInt(5, label.id as Int) | ||||
|                     ps.executeUpdate() | ||||
|                 } | ||||
|             } else { | ||||
|                 try { | ||||
|                     connection.autoCommit = false | ||||
|                     var exists = false | ||||
|                     connection.prepareStatement("SELECT COUNT(1) FROM Label WHERE label=?").use { ps -> | ||||
|                         ps.setString(1, label.toString()) | ||||
|                         val rs = ps.executeQuery() | ||||
|                         if (rs.next()) { | ||||
|                             exists = rs.getInt(1) > 0 | ||||
|                         } | ||||
|                     } | ||||
|  | ||||
|                     if (exists) { | ||||
|                         connection.prepareStatement("UPDATE Label SET type=?, color=?, ord=? WHERE label=?").use { ps -> | ||||
|                             ps.setString(1, label.type?.name) | ||||
|                             ps.setInt(2, label.color) | ||||
|                             ps.setInt(3, label.ord) | ||||
|                             ps.setString(4, label.toString()) | ||||
|                             ps.executeUpdate() | ||||
|                         } | ||||
|                     } else { | ||||
|                         connection.prepareStatement("INSERT INTO Label (label, type, color, ord) VALUES (?, ?, ?, ?)").use { ps -> | ||||
|                             ps.setString(1, label.toString()) | ||||
|                             ps.setString(2, label.type?.name) | ||||
|                             ps.setInt(3, label.color) | ||||
|                             ps.setInt(4, label.ord) | ||||
|                             ps.executeUpdate() | ||||
|                         } | ||||
|                     } | ||||
|                     connection.commit() | ||||
|                 } catch (e: Exception) { | ||||
|                     connection.rollback() | ||||
|                     throw e | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     override fun countUnread(label: Label?): Int { | ||||
|         val where = if (label == null) { | ||||
|             "" | ||||
| @@ -136,26 +64,7 @@ class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRep | ||||
|                         """SELECT id, iv, type, sender, recipient, data, ack_data, sent, received, initial_hash, status, ttl, retries, next_try, conversation | ||||
|                            FROM Message WHERE $where $limit""").use { rs -> | ||||
|                         while (rs.next()) { | ||||
|                             val iv = rs.getBytes("iv") | ||||
|                             val data = rs.getBinaryStream("data") | ||||
|                             val type = Plaintext.Type.valueOf(rs.getString("type")) | ||||
|                             val builder = Plaintext.readWithoutSignature(type, data) | ||||
|                             val id = rs.getLong("id") | ||||
|                             builder.id(id) | ||||
|                             builder.IV(InventoryVector.fromHash(iv)) | ||||
|                             builder.from(ctx.addressRepository.getAddress(rs.getString("sender"))!!) | ||||
|                             rs.getString("recipient")?.let { builder.to(ctx.addressRepository.getAddress(it)) } | ||||
|                             builder.ackData(rs.getBytes("ack_data")) | ||||
|                             builder.sent(rs.getObject("sent") as Long?) | ||||
|                             builder.received(rs.getObject("received") as Long?) | ||||
|                             builder.status(Plaintext.Status.valueOf(rs.getString("status"))) | ||||
|                             builder.ttl(rs.getLong("ttl")) | ||||
|                             builder.retries(rs.getInt("retries")) | ||||
|                             builder.nextTry(rs.getObject("next_try") as Long?) | ||||
|                             builder.conversation(rs.getObject("conversation") as UUID? ?: UUID.randomUUID()) | ||||
|                             builder.labels(findLabels(connection, | ||||
|                                 "id IN (SELECT label_id FROM Message_Label WHERE message_id=$id) ORDER BY ord")) | ||||
|                             val message = builder.build() | ||||
|                             val message = getMessage(connection, rs) | ||||
|                             message.initialHash = rs.getBytes("initial_hash") | ||||
|                             result.add(message) | ||||
|                         } | ||||
| @@ -165,17 +74,38 @@ class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRep | ||||
|         } catch (e: SQLException) { | ||||
|             LOG.error(e.message, e) | ||||
|         } | ||||
|  | ||||
|         return result | ||||
|     } | ||||
|  | ||||
|     private fun getMessage(connection: Connection, rs: ResultSet): Plaintext { | ||||
|         return Plaintext.readWithoutSignature( | ||||
|             Plaintext.Type.valueOf(rs.getString("type")), | ||||
|             rs.getBinaryStream("data") | ||||
|         ).build { | ||||
|             id = rs.getLong("id") | ||||
|             inventoryVector = InventoryVector.fromHash(rs.getBytes("iv")) | ||||
|             from = rs.getString("sender")?.let { ctx.addressRepository.getAddress(it) ?: BitmessageAddress(it) } | ||||
|             to = rs.getString("recipient")?.let { ctx.addressRepository.getAddress(it) ?: BitmessageAddress(it) } | ||||
|             ackData = rs.getBytes("ack_data") | ||||
|             sent = rs.getObject("sent") as Long? | ||||
|             received = rs.getObject("received") as Long? | ||||
|             status = Plaintext.Status.valueOf(rs.getString("status")) | ||||
|             ttl = rs.getLong("ttl") | ||||
|             retries = rs.getInt("retries") | ||||
|             nextTry = rs.getObject("next_try") as Long? | ||||
|             conversation = rs.getObject("conversation") as UUID? ?: UUID.randomUUID() | ||||
|             labels = findLabels(connection, | ||||
|                 "id IN (SELECT label_id FROM Message_Label WHERE message_id=$id) ORDER BY ord") | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private fun findLabels(connection: Connection, where: String): List<Label> { | ||||
|         val result = ArrayList<Label>() | ||||
|         try { | ||||
|             connection.createStatement().use { stmt -> | ||||
|                 stmt.executeQuery("SELECT id, label, type, color, ord FROM Label WHERE $where").use { rs -> | ||||
|                     while (rs.next()) { | ||||
|                         result.add(getLabel(rs)) | ||||
|                         result.add(JdbcLabelRepository.getLabel(rs)) | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| @@ -258,19 +188,7 @@ class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRep | ||||
|                 "status, initial_hash, ttl, retries, next_try, conversation) " + | ||||
|                 "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", | ||||
|             Statement.RETURN_GENERATED_KEYS).use { ps -> | ||||
|             ps.setBytes(1, if (message.inventoryVector == null) null else message.inventoryVector!!.hash) | ||||
|             ps.setString(2, message.type.name) | ||||
|             ps.setString(3, message.from.address) | ||||
|             ps.setString(4, if (message.to == null) null else message.to!!.address) | ||||
|             writeBlob(ps, 5, message) | ||||
|             ps.setBytes(6, message.ackData) | ||||
|             ps.setObject(7, message.sent) | ||||
|             ps.setObject(8, message.received) | ||||
|             ps.setString(9, message.status.name) | ||||
|             ps.setBytes(10, message.initialHash) | ||||
|             ps.setLong(11, message.ttl) | ||||
|             ps.setInt(12, message.retries) | ||||
|             ps.setObject(13, message.nextTry) | ||||
|             prepare(ps, message) | ||||
|             ps.setObject(14, message.conversationId) | ||||
|  | ||||
|             try { | ||||
| @@ -291,24 +209,29 @@ class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRep | ||||
|             "UPDATE Message SET iv=?, type=?, sender=?, recipient=?, data=?, ack_data=?, sent=?, received=?, " + | ||||
|                 "status=?, initial_hash=?, ttl=?, retries=?, next_try=? " + | ||||
|                 "WHERE id=?").use { ps -> | ||||
|             ps.setBytes(1, if (message.inventoryVector == null) null else message.inventoryVector!!.hash) | ||||
|             ps.setString(2, message.type.name) | ||||
|             ps.setString(3, message.from.address) | ||||
|             ps.setString(4, if (message.to == null) null else message.to!!.address) | ||||
|             writeBlob(ps, 5, message) | ||||
|             ps.setBytes(6, message.ackData) | ||||
|             ps.setObject(7, message.sent) | ||||
|             ps.setObject(8, message.received) | ||||
|             ps.setString(9, message.status.name) | ||||
|             ps.setBytes(10, message.initialHash) | ||||
|             ps.setLong(11, message.ttl) | ||||
|             ps.setInt(12, message.retries) | ||||
|             ps.setObject(13, message.nextTry) | ||||
|             prepare(ps, message) | ||||
|             ps.setLong(14, (message.id as Long?)!!) | ||||
|             ps.executeUpdate() | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private fun prepare(ps: PreparedStatement, message: Plaintext): Int{ | ||||
|         ps.setBytes(1, if (message.inventoryVector == null) null else message.inventoryVector!!.hash) | ||||
|         ps.setString(2, message.type.name) | ||||
|         ps.setString(3, message.from.address) | ||||
|         ps.setString(4, if (message.to == null) null else message.to!!.address) | ||||
|         writeBlob(ps, 5, message) | ||||
|         ps.setBytes(6, message.ackData) | ||||
|         ps.setObject(7, message.sent) | ||||
|         ps.setObject(8, message.received) | ||||
|         ps.setString(9, message.status.name) | ||||
|         ps.setBytes(10, message.initialHash) | ||||
|         ps.setLong(11, message.ttl) | ||||
|         ps.setInt(12, message.retries) | ||||
|         ps.setObject(13, message.nextTry) | ||||
|         return 14 | ||||
|     } | ||||
|  | ||||
|     override fun remove(message: Plaintext) { | ||||
|         try { | ||||
|             config.getConnection().use { connection -> | ||||
| @@ -332,7 +255,6 @@ class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRep | ||||
|         } catch (e: SQLException) { | ||||
|             LOG.error(e.message, e) | ||||
|         } | ||||
|  | ||||
|     } | ||||
|  | ||||
|     override fun findConversations(label: Label?): List<UUID> { | ||||
|   | ||||
| @@ -0,0 +1,49 @@ | ||||
| /* | ||||
|  * Copyright 2017 Christian Basler | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
|  | ||||
| package ch.dissem.bitmessage.repository | ||||
|  | ||||
| import ch.dissem.bitmessage.entity.valueobject.Label | ||||
| import ch.dissem.bitmessage.ports.LabelRepository | ||||
| import org.junit.Assert.assertEquals | ||||
| import org.junit.Before | ||||
| import org.junit.Test | ||||
|  | ||||
| class JdbcLabelRepositoryTest : TestBase() { | ||||
|  | ||||
|     private lateinit var repo: LabelRepository | ||||
|  | ||||
|     @Before | ||||
|     fun setUp() { | ||||
|         val config = TestJdbcConfig() | ||||
|         config.reset() | ||||
|         repo = JdbcLabelRepository(config) | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     fun `ensure labels are retrieved`() { | ||||
|         val labels = repo.getLabels() | ||||
|         assertEquals(5, labels.size.toLong()) | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     fun `ensure labels can be retrieved by type`() { | ||||
|         val labels = repo.getLabels(Label.Type.INBOX) | ||||
|         assertEquals(1, labels.size.toLong()) | ||||
|         assertEquals("Inbox", labels[0].toString()) | ||||
|     } | ||||
|  | ||||
| } | ||||
| @@ -26,6 +26,7 @@ import ch.dissem.bitmessage.entity.valueobject.ExtendedEncoding | ||||
| import ch.dissem.bitmessage.entity.valueobject.Label | ||||
| import ch.dissem.bitmessage.entity.valueobject.PrivateKey | ||||
| import ch.dissem.bitmessage.entity.valueobject.extended.Message | ||||
| import ch.dissem.bitmessage.ports.LabelRepository | ||||
| import ch.dissem.bitmessage.ports.MessageRepository | ||||
| import ch.dissem.bitmessage.utils.TestUtils | ||||
| import ch.dissem.bitmessage.utils.TestUtils.mockedInternalContext | ||||
| @@ -46,6 +47,7 @@ class JdbcMessageRepositoryTest : TestBase() { | ||||
|     private lateinit var identity: BitmessageAddress | ||||
|  | ||||
|     private lateinit var repo: MessageRepository | ||||
|     private lateinit var labelRepo: LabelRepository | ||||
|  | ||||
|     private lateinit var inbox: Label | ||||
|     private lateinit var sent: Label | ||||
| @@ -58,6 +60,7 @@ class JdbcMessageRepositoryTest : TestBase() { | ||||
|         config.reset() | ||||
|         val addressRepo = JdbcAddressRepository(config) | ||||
|         repo = JdbcMessageRepository(config) | ||||
|         labelRepo = JdbcLabelRepository(config) | ||||
|         mockedInternalContext( | ||||
|             cryptography = BouncyCryptography(), | ||||
|             addressRepository = addressRepo, | ||||
| @@ -76,29 +79,16 @@ class JdbcMessageRepositoryTest : TestBase() { | ||||
|         identity = BitmessageAddress(PrivateKey(false, 1, 1000, 1000, DOES_ACK)) | ||||
|         addressRepo.save(identity) | ||||
|  | ||||
|         inbox = repo.getLabels(Label.Type.INBOX)[0] | ||||
|         sent = repo.getLabels(Label.Type.SENT)[0] | ||||
|         drafts = repo.getLabels(Label.Type.DRAFT)[0] | ||||
|         unread = repo.getLabels(Label.Type.UNREAD)[0] | ||||
|         inbox = labelRepo.getLabels(Label.Type.INBOX)[0] | ||||
|         sent = labelRepo.getLabels(Label.Type.SENT)[0] | ||||
|         drafts = labelRepo.getLabels(Label.Type.DRAFT)[0] | ||||
|         unread = labelRepo.getLabels(Label.Type.UNREAD)[0] | ||||
|  | ||||
|         addMessage(contactA, identity, Plaintext.Status.RECEIVED, inbox, unread) | ||||
|         addMessage(identity, contactA, Plaintext.Status.DRAFT, drafts) | ||||
|         addMessage(identity, contactB, Plaintext.Status.DRAFT, unread) | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     fun `ensure labels are retrieved`() { | ||||
|         val labels = repo.getLabels() | ||||
|         assertEquals(5, labels.size.toLong()) | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     fun `ensure labels can be retrieved by type`() { | ||||
|         val labels = repo.getLabels(Label.Type.INBOX) | ||||
|         assertEquals(1, labels.size.toLong()) | ||||
|         assertEquals("Inbox", labels[0].toString()) | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     fun `ensure messages can be found by label`() { | ||||
|         val messages = repo.findMessages(inbox) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user